feat(Google): Support all Text/Chat Models, Response streaming, PaLM -> Google 🤖 (#1316)

* feat: update PaLM icons

* feat: add additional google models

* POC: formatting inputs for Vertex AI streaming

* refactor: move endpoints services outside of /routes dir to /services/Endpoints

* refactor: shorten schemas import

* refactor: rename PALM to GOOGLE

* feat: make Google editable endpoint

* feat: reusable Ask and Edit controllers based off Anthropic

* chore: organize imports/logic

* fix(parseConvo): include examples in googleSchema

* fix: google only allows odd number of messages to be sent

* fix: pass proxy to AnthropicClient

* refactor: change `google` altName to `Google`

* refactor: update getModelMaxTokens and related functions to handle maxTokensMap with nested endpoint/model keys and values (see the sketch after this list)

* refactor: google Icon and response sender changes (Codey and Google logo instead of PaLM in all cases)

* feat: google support for maxTokensMap

* feat: google updated endpoints with Ask/Edit controllers, buildOptions, and initializeClient

* feat(GoogleClient): now builds prompt for text models and supports real streaming from Vertex AI through langchain

* chore(GoogleClient): remove comments, left before for reference in git history

* docs: update google instructions (WIP)

* docs(apis_and_tokens.md): add images to google instructions

* docs: remove typo apis_and_tokens.md

* Update apis_and_tokens.md

* feat(Google): use default settings map, fully support context for both text and chat models, fully support examples for chat models

* chore: update more PaLM references to Google

* chore: move playwright out of workflows to avoid failing tests
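
The nested endpoint/model map mentioned above can be pictured with a minimal sketch. The `google` entries and all token values here are illustrative assumptions; only the `getModelMaxTokens(model, endpoint)` signature and the Anthropic 100,000-token fallback appear in the diff below.

    // Hypothetical shape of an endpoint-keyed maxTokensMap (values are placeholders).
    const maxTokensMap = {
      anthropic: { 'claude-2': 100000 },
      google: { 'chat-bison': 4096, 'chat-bison-32k': 32768 },
    };

    // Sketch of the nested lookup: resolve the endpoint first, then the model.
    function getModelMaxTokens(model, endpoint) {
      const endpointMap = maxTokensMap[endpoint];
      return endpointMap ? endpointMap[model] : undefined;
    }

    // getModelMaxTokens('chat-bison-32k', 'google'); // => 32768 in this sketch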
Danny Avila, 2023-12-10 14:54:13 -05:00 (committed by GitHub)
commit 583e978a82, parent 8a1968b2f8
90 changed files with 1613 additions and 784 deletions

View file

@ -1,6 +1,6 @@
const Anthropic = require('@anthropic-ai/sdk');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const { getResponseSender, EModelEndpoint } = require('~/server/routes/endpoints/schemas');
const { getResponseSender, EModelEndpoint } = require('~/server/services/Endpoints');
const { getModelMaxTokens } = require('~/utils');
const BaseClient = require('./BaseClient');
@ -46,7 +46,8 @@ class AnthropicClient extends BaseClient {
stop: modelOptions.stop, // no stop method for now
};
this.maxContextTokens = getModelMaxTokens(this.modelOptions.model) ?? 100000;
this.maxContextTokens =
getModelMaxTokens(this.modelOptions.model, EModelEndpoint.anthropic) ?? 100000;
this.maxResponseTokens = this.modelOptions.maxOutputTokens || 1500;
this.maxPromptTokens =
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;

View file

@ -445,6 +445,7 @@ class BaseClient {
amount: promptTokens,
debug: this.options.debug,
model: this.modelOptions.model,
endpoint: this.options.endpoint,
},
});
}

View file

@ -1,23 +1,43 @@
const BaseClient = require('./BaseClient');
const { google } = require('googleapis');
const { Agent, ProxyAgent } = require('undici');
const { GoogleVertexAI } = require('langchain/llms/googlevertexai');
const { ChatGoogleVertexAI } = require('langchain/chat_models/googlevertexai');
const { AIMessage, HumanMessage, SystemMessage } = require('langchain/schema');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const {
getResponseSender,
EModelEndpoint,
endpointSettings,
} = require('~/server/services/Endpoints');
const { getModelMaxTokens } = require('~/utils');
const { formatMessage } = require('./prompts');
const BaseClient = require('./BaseClient');
const loc = 'us-central1';
const publisher = 'google';
const endpointPrefix = `https://${loc}-aiplatform.googleapis.com`;
// const apiEndpoint = loc + '-aiplatform.googleapis.com';
const tokenizersCache = {};
const settings = endpointSettings[EModelEndpoint.google];
class GoogleClient extends BaseClient {
constructor(credentials, options = {}) {
super('apiKey', options);
this.credentials = credentials;
this.client_email = credentials.client_email;
this.project_id = credentials.project_id;
this.private_key = credentials.private_key;
this.sender = 'PaLM2';
this.access_token = null;
if (options.skipSetOptions) {
return;
}
this.setOptions(options);
}
/* Google/PaLM2 specific methods */
/* Google specific methods */
constructUrl() {
return `https://us-central1-aiplatform.googleapis.com/v1/projects/${this.project_id}/locations/us-central1/publishers/google/models/${this.modelOptions.model}:predict`;
return `${endpointPrefix}/v1/projects/${this.project_id}/locations/${loc}/publishers/${publisher}/models/${this.modelOptions.model}:serverStreamingPredict`;
}
async getClient() {
@ -35,6 +55,24 @@ class GoogleClient extends BaseClient {
return jwtClient;
}
async getAccessToken() {
const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);
return new Promise((resolve, reject) => {
jwtClient.authorize((err, tokens) => {
if (err) {
console.error('Error: jwtClient failed to authorize');
console.error(err.message);
reject(err);
} else {
console.log('Access Token:', tokens.access_token);
resolve(tokens.access_token);
}
});
});
}
/* Required Client methods */
setOptions(options) {
if (this.options && !this.options.replaceOptions) {
@ -53,30 +91,33 @@ class GoogleClient extends BaseClient {
this.options = options;
}
this.options.examples = this.options.examples.filter(
(obj) => obj.input.content !== '' && obj.output.content !== '',
);
this.options.examples = this.options.examples
.filter((ex) => ex)
.filter((obj) => obj.input.content !== '' && obj.output.content !== '');
const modelOptions = this.options.modelOptions || {};
this.modelOptions = {
...modelOptions,
// set some good defaults (check for undefined in some cases because they may be 0)
model: modelOptions.model || 'chat-bison',
temperature: typeof modelOptions.temperature === 'undefined' ? 0.2 : modelOptions.temperature, // 0 - 1, 0.2 is recommended
topP: typeof modelOptions.topP === 'undefined' ? 0.95 : modelOptions.topP, // 0 - 1, default: 0.95
topK: typeof modelOptions.topK === 'undefined' ? 40 : modelOptions.topK, // 1-40, default: 40
model: modelOptions.model || settings.model.default,
temperature:
typeof modelOptions.temperature === 'undefined'
? settings.temperature.default
: modelOptions.temperature,
topP: typeof modelOptions.topP === 'undefined' ? settings.topP.default : modelOptions.topP,
topK: typeof modelOptions.topK === 'undefined' ? settings.topK.default : modelOptions.topK,
// stop: modelOptions.stop // no stop method for now
};
this.isChatModel = this.modelOptions.model.startsWith('chat-');
this.isChatModel = this.modelOptions.model.includes('chat');
const { isChatModel } = this;
this.isTextModel = this.modelOptions.model.startsWith('text-');
this.isTextModel = !isChatModel && /code|text/.test(this.modelOptions.model);
const { isTextModel } = this;
this.maxContextTokens = this.options.maxContextTokens || (isTextModel ? 8000 : 4096);
this.maxContextTokens = getModelMaxTokens(this.modelOptions.model, EModelEndpoint.google);
// The max prompt tokens is determined by the max context tokens minus the max response tokens.
// Earlier messages will be dropped until the prompt is within the limit.
this.maxResponseTokens = this.modelOptions.maxOutputTokens || 1024;
this.maxResponseTokens = this.modelOptions.maxOutputTokens || settings.maxOutputTokens.default;
this.maxPromptTokens =
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
@ -88,6 +129,14 @@ class GoogleClient extends BaseClient {
);
}
this.sender =
this.options.sender ??
getResponseSender({
model: this.modelOptions.model,
endpoint: EModelEndpoint.google,
modelLabel: this.options.modelLabel,
});
this.userLabel = this.options.userLabel || 'User';
this.modelLabel = this.options.modelLabel || 'Assistant';
@ -99,8 +148,8 @@ class GoogleClient extends BaseClient {
this.endToken = '';
this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
} else if (isTextModel) {
this.startToken = '<|im_start|>';
this.endToken = '<|im_end|>';
this.startToken = '||>';
this.endToken = '';
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
'<|im_start|>': 100264,
'<|im_end|>': 100265,
@ -138,15 +187,18 @@ class GoogleClient extends BaseClient {
return this;
}
getMessageMapMethod() {
formatMessages() {
return ((message) => ({
author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
content: message?.content ?? message.text,
})).bind(this);
}
buildMessages(messages = []) {
const formattedMessages = messages.map(this.getMessageMapMethod());
buildMessages(messages = [], parentMessageId) {
if (this.isTextModel) {
return this.buildMessagesPrompt(messages, parentMessageId);
}
const formattedMessages = messages.map(this.formatMessages());
let payload = {
instances: [
{
@ -164,15 +216,6 @@ class GoogleClient extends BaseClient {
payload.instances[0].examples = this.options.examples;
}
/* TO-DO: text model needs more context since it can't process an array of messages */
if (this.isTextModel) {
payload.instances = [
{
prompt: messages[messages.length - 1].content,
},
];
}
if (this.options.debug) {
console.debug('GoogleClient buildMessages');
console.dir(payload, { depth: null });
@ -181,7 +224,157 @@ class GoogleClient extends BaseClient {
return { prompt: payload };
}
async getCompletion(payload, abortController = null) {
async buildMessagesPrompt(messages, parentMessageId) {
const orderedMessages = this.constructor.getMessagesForConversation({
messages,
parentMessageId,
});
if (this.options.debug) {
console.debug('GoogleClient: orderedMessages', orderedMessages, parentMessageId);
}
const formattedMessages = orderedMessages.map((message) => ({
author: message.isCreatedByUser ? this.userLabel : this.modelLabel,
content: message?.content ?? message.text,
}));
let lastAuthor = '';
let groupedMessages = [];
for (let message of formattedMessages) {
// If last author is not same as current author, add to new group
if (lastAuthor !== message.author) {
groupedMessages.push({
author: message.author,
content: [message.content],
});
lastAuthor = message.author;
// If same author, append content to the last group
} else {
groupedMessages[groupedMessages.length - 1].content.push(message.content);
}
}
let identityPrefix = '';
if (this.options.userLabel) {
identityPrefix = `\nHuman's name: ${this.options.userLabel}`;
}
if (this.options.modelLabel) {
identityPrefix = `${identityPrefix}\nYou are ${this.options.modelLabel}`;
}
let promptPrefix = (this.options.promptPrefix || '').trim();
if (promptPrefix) {
// If the prompt prefix doesn't end with the end token, add it.
if (!promptPrefix.endsWith(`${this.endToken}`)) {
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
}
promptPrefix = `\nContext:\n${promptPrefix}`;
}
if (identityPrefix) {
promptPrefix = `${identityPrefix}${promptPrefix}`;
}
// Prompt AI to respond, empty if last message was from AI
let isEdited = lastAuthor === this.modelLabel;
const promptSuffix = isEdited ? '' : `${promptPrefix}\n\n${this.modelLabel}:\n`;
let currentTokenCount = isEdited
? this.getTokenCount(promptPrefix)
: this.getTokenCount(promptSuffix);
let promptBody = '';
const maxTokenCount = this.maxPromptTokens;
const context = [];
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
// Do this within a recursive async function so that it doesn't block the event loop for too long.
// Also, remove the next message when the message that puts us over the token limit is created by the user.
// Otherwise, remove only the exceeding message. This is due to Anthropic's strict payload rule to start with "Human:".
const nextMessage = {
remove: false,
tokenCount: 0,
messageString: '',
};
const buildPromptBody = async () => {
if (currentTokenCount < maxTokenCount && groupedMessages.length > 0) {
const message = groupedMessages.pop();
const isCreatedByUser = message.author === this.userLabel;
// Use promptPrefix if the message is an edited assistant message
const messagePrefix =
isCreatedByUser || !isEdited
? `\n\n${message.author}:`
: `${promptPrefix}\n\n${message.author}:`;
const messageString = `${messagePrefix}\n${message.content}${this.endToken}\n`;
let newPromptBody = `${messageString}${promptBody}`;
context.unshift(message);
const tokenCountForMessage = this.getTokenCount(messageString);
const newTokenCount = currentTokenCount + tokenCountForMessage;
if (!isCreatedByUser) {
nextMessage.messageString = messageString;
nextMessage.tokenCount = tokenCountForMessage;
}
if (newTokenCount > maxTokenCount) {
if (!promptBody) {
// This is the first message, so we can't add it. Just throw an error.
throw new Error(
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
);
}
// Otherwise, this message would put us over the token limit, so don't add it.
// if created by user, remove next message, otherwise remove only this message
if (isCreatedByUser) {
nextMessage.remove = true;
}
return false;
}
promptBody = newPromptBody;
currentTokenCount = newTokenCount;
// Switch off isEdited after using it for the first time
if (isEdited) {
isEdited = false;
}
// wait for next tick to avoid blocking the event loop
await new Promise((resolve) => setImmediate(resolve));
return buildPromptBody();
}
return true;
};
await buildPromptBody();
if (nextMessage.remove) {
promptBody = promptBody.replace(nextMessage.messageString, '');
currentTokenCount -= nextMessage.tokenCount;
context.shift();
}
let prompt = `${promptBody}${promptSuffix}`;
// Add 2 tokens for metadata after all messages have been counted.
currentTokenCount += 2;
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
this.modelOptions.maxOutputTokens = Math.min(
this.maxContextTokens - currentTokenCount,
this.maxResponseTokens,
);
return { prompt, context };
}
async _getCompletion(payload, abortController = null) {
if (!abortController) {
abortController = new AbortController();
}
@ -212,6 +405,72 @@ class GoogleClient extends BaseClient {
return res.data;
}
async getCompletion(_payload, options = {}) {
const { onProgress, abortController } = options;
const { parameters, instances } = _payload;
const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {};
let examples;
let clientOptions = {
authOptions: {
credentials: {
...this.credentials,
},
projectId: this.project_id,
},
...parameters,
};
if (!parameters) {
clientOptions = { ...clientOptions, ...this.modelOptions };
}
if (_examples && _examples.length) {
examples = _examples
.map((ex) => {
const { input, output } = ex;
if (!input || !output) {
return undefined;
}
return {
input: new HumanMessage(input.content),
output: new AIMessage(output.content),
};
})
.filter((ex) => ex);
clientOptions.examples = examples;
}
const model = this.isTextModel
? new GoogleVertexAI(clientOptions)
: new ChatGoogleVertexAI(clientOptions);
let reply = '';
const messages = this.isTextModel
? _payload.trim()
: _messages
.map((msg) => ({ ...msg, role: msg.author === 'User' ? 'user' : 'assistant' }))
.map((message) => formatMessage({ message, langChain: true }));
if (context && messages?.length > 0) {
messages.unshift(new SystemMessage(context));
}
const stream = await model.stream(messages, {
signal: abortController.signal,
timeout: 7000,
});
for await (const chunk of stream) {
await this.generateTextStream(chunk?.content ?? chunk, onProgress, { delay: 7 });
reply += chunk?.content ?? chunk;
}
return reply;
}
getSaveOptions() {
return {
promptPrefix: this.options.promptPrefix,
@ -225,34 +484,18 @@ class GoogleClient extends BaseClient {
}
async sendCompletion(payload, opts = {}) {
console.log('GoogleClient: sendcompletion', payload, opts);
let reply = '';
let blocked = false;
try {
const result = await this.getCompletion(payload, opts.abortController);
blocked = result?.predictions?.[0]?.safetyAttributes?.blocked;
reply =
result?.predictions?.[0]?.candidates?.[0]?.content ||
result?.predictions?.[0]?.content ||
'';
if (blocked === true) {
reply = `Google blocked a proper response to your message:\n${JSON.stringify(
result.predictions[0].safetyAttributes,
)}${reply.length > 0 ? `\nAI Response:\n${reply}` : ''}`;
}
reply = await this.getCompletion(payload, opts);
if (this.options.debug) {
console.debug('result');
console.debug(result);
console.debug(reply);
}
} catch (err) {
console.error('Error: failed to send completion to Google');
console.error(err);
console.error(err.message);
}
if (!blocked) {
await this.generateTextStream(reply, opts.onProgress, { delay: 0.5 });
}
return reply.trim();
}
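
As a worked illustration of buildMessagesPrompt above: assuming the defaults shown earlier (userLabel 'User', modelLabel 'Assistant', empty endToken for text models) and a set promptPrefix, the rendered text-model prompt looks roughly like the sketch below. Note that the context/prefix is appended after the conversation, just before the final model cue.

    User:
    hi

    Assistant:
    hello!

    User:
    how are you?

    Context:
    <your promptPrefix>

    Assistant: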

View file

@ -1,10 +1,10 @@
const OpenAI = require('openai');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const { getResponseSender, EModelEndpoint } = require('~/server/services/Endpoints');
const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images');
const { getModelMaxTokens, genAzureChatCompletion, extractBaseURL } = require('~/utils');
const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts');
const { getResponseSender, EModelEndpoint } = require('~/server/routes/endpoints/schemas');
const { handleOpenAIErrors } = require('./tools/util');
const spendTokens = require('~/models/spendTokens');
const { createLLM, RunManager } = require('./llm');

View file

@ -3,11 +3,12 @@ const { CallbackManager } = require('langchain/callbacks');
const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
const checkBalance = require('../../models/checkBalance');
const { EModelEndpoint } = require('~/server/services/Endpoints');
const { formatLangChainMessages } = require('./prompts');
const { isEnabled } = require('../../server/utils');
const { extractBaseURL } = require('../../utils');
const checkBalance = require('~/models/checkBalance');
const { SelfReflectionTool } = require('./tools');
const { isEnabled } = require('~/server/utils');
const { extractBaseURL } = require('~/utils');
const { loadTools } = require('./tools/util');
class PluginsClient extends OpenAIClient {
@ -304,6 +305,7 @@ class PluginsClient extends OpenAIClient {
amount: promptTokens,
debug: this.options.debug,
model: this.modelOptions.model,
endpoint: EModelEndpoint.openAI,
},
});
}

View file

@ -1,7 +1,8 @@
const { promptTokensEstimate } = require('openai-chat-tokens');
const checkBalance = require('../../../models/checkBalance');
const { isEnabled } = require('../../../server/utils');
const { formatFromLangChain } = require('../prompts');
const { EModelEndpoint } = require('~/server/services/Endpoints');
const { formatFromLangChain } = require('~/app/clients/prompts');
const checkBalance = require('~/models/checkBalance');
const { isEnabled } = require('~/server/utils');
const createStartHandler = ({
context,
@ -55,6 +56,7 @@ const createStartHandler = ({
debug: manager.debug,
generations,
model,
endpoint: EModelEndpoint.openAI,
},
});
}

View file

@ -0,0 +1,42 @@
/**
* Formats an object to match the struct_val, list_val, string_val, float_val, and int_val format.
*
* @param {Object} obj - The object to be formatted.
* @returns {Object} The formatted object.
*
* Handles different types:
* - Arrays are wrapped in list_val and each element is processed.
* - Objects are recursively processed.
* - Strings are wrapped in string_val.
* - Numbers are wrapped in float_val or int_val depending on whether they are floating-point or integers.
*/
function formatGoogleInputs(obj) {
const formattedObj = {};
for (const key in obj) {
if (Object.prototype.hasOwnProperty.call(obj, key)) {
const value = obj[key];
// Handle arrays
if (Array.isArray(value)) {
formattedObj[key] = { list_val: value.map((item) => formatGoogleInputs(item)) };
}
// Handle objects
else if (typeof value === 'object' && value !== null) {
formattedObj[key] = formatGoogleInputs(value);
}
// Handle numbers
else if (typeof value === 'number') {
formattedObj[key] = Number.isInteger(value) ? { int_val: value } : { float_val: value };
}
// Handle other types (e.g., strings)
else {
formattedObj[key] = { string_val: [value] };
}
}
}
return { struct_val: formattedObj };
}
module.exports = formatGoogleInputs;
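
A quick note on the mapping above, since it is easy to miss in the tests that follow: strings are wrapped in a single-element array (string_val: [value]) while numbers are not, and integers and floats branch into different keys. For example (derived directly from the function):

    // formatGoogleInputs({ temperature: 0.2, topK: 40, context: 'hi' }) returns:
    // {
    //   struct_val: {
    //     temperature: { float_val: 0.2 },
    //     topK: { int_val: 40 },
    //     context: { string_val: ['hi'] },
    //   },
    // }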

View file

@ -0,0 +1,274 @@
const formatGoogleInputs = require('./formatGoogleInputs');
describe('formatGoogleInputs', () => {
it('formats message correctly', () => {
const input = {
messages: [
{
content: 'hi',
author: 'user',
},
],
context: 'context',
examples: [
{
input: {
author: 'user',
content: 'user input',
},
output: {
author: 'bot',
content: 'bot output',
},
},
],
parameters: {
temperature: 0.2,
topP: 0.8,
topK: 40,
maxOutputTokens: 1024,
},
};
const expectedOutput = {
struct_val: {
messages: {
list_val: [
{
struct_val: {
content: {
string_val: ['hi'],
},
author: {
string_val: ['user'],
},
},
},
],
},
context: {
string_val: ['context'],
},
examples: {
list_val: [
{
struct_val: {
input: {
struct_val: {
author: {
string_val: ['user'],
},
content: {
string_val: ['user input'],
},
},
},
output: {
struct_val: {
author: {
string_val: ['bot'],
},
content: {
string_val: ['bot output'],
},
},
},
},
},
],
},
parameters: {
struct_val: {
temperature: {
float_val: 0.2,
},
topP: {
float_val: 0.8,
},
topK: {
int_val: 40,
},
maxOutputTokens: {
int_val: 1024,
},
},
},
},
};
const result = formatGoogleInputs(input);
expect(JSON.stringify(result)).toEqual(JSON.stringify(expectedOutput));
});
it('formats real payload parts', () => {
const input = {
instances: [
{
context: 'context',
examples: [
{
input: {
author: 'user',
content: 'user input',
},
output: {
author: 'bot',
content: 'user output',
},
},
],
messages: [
{
author: 'user',
content: 'hi',
},
],
},
],
parameters: {
candidateCount: 1,
maxOutputTokens: 1024,
temperature: 0.2,
topP: 0.8,
topK: 40,
},
};
const expectedOutput = {
struct_val: {
instances: {
list_val: [
{
struct_val: {
context: { string_val: ['context'] },
examples: {
list_val: [
{
struct_val: {
input: {
struct_val: {
author: { string_val: ['user'] },
content: { string_val: ['user input'] },
},
},
output: {
struct_val: {
author: { string_val: ['bot'] },
content: { string_val: ['user output'] },
},
},
},
},
],
},
messages: {
list_val: [
{
struct_val: {
author: { string_val: ['user'] },
content: { string_val: ['hi'] },
},
},
],
},
},
},
],
},
parameters: {
struct_val: {
candidateCount: { int_val: 1 },
maxOutputTokens: { int_val: 1024 },
temperature: { float_val: 0.2 },
topP: { float_val: 0.8 },
topK: { int_val: 40 },
},
},
},
};
const result = formatGoogleInputs(input);
expect(JSON.stringify(result)).toEqual(JSON.stringify(expectedOutput));
});
it('helps create valid payload parts', () => {
const instances = {
context: 'context',
examples: [
{
input: {
author: 'user',
content: 'user input',
},
output: {
author: 'bot',
content: 'user output',
},
},
],
messages: [
{
author: 'user',
content: 'hi',
},
],
};
const expectedInstances = {
struct_val: {
context: { string_val: ['context'] },
examples: {
list_val: [
{
struct_val: {
input: {
struct_val: {
author: { string_val: ['user'] },
content: { string_val: ['user input'] },
},
},
output: {
struct_val: {
author: { string_val: ['bot'] },
content: { string_val: ['user output'] },
},
},
},
},
],
},
messages: {
list_val: [
{
struct_val: {
author: { string_val: ['user'] },
content: { string_val: ['hi'] },
},
},
],
},
},
};
const parameters = {
candidateCount: 1,
maxOutputTokens: 1024,
temperature: 0.2,
topP: 0.8,
topK: 40,
};
const expectedParameters = {
struct_val: {
candidateCount: { int_val: 1 },
maxOutputTokens: { int_val: 1024 },
temperature: { float_val: 0.2 },
topP: { float_val: 0.8 },
topK: { int_val: 40 },
},
};
const instancesResult = formatGoogleInputs(instances);
const parametersResult = formatGoogleInputs(parameters);
expect(JSON.stringify(instancesResult)).toEqual(JSON.stringify(expectedInstances));
expect(JSON.stringify(parametersResult)).toEqual(JSON.stringify(expectedParameters));
});
});

View file

@ -2,8 +2,16 @@ const mongoose = require('mongoose');
const balanceSchema = require('./schema/balance');
const { getMultiplier } = require('./tx');
balanceSchema.statics.check = async function ({ user, model, valueKey, tokenType, amount, debug }) {
const multiplier = getMultiplier({ valueKey, tokenType, model });
balanceSchema.statics.check = async function ({
user,
model,
endpoint,
valueKey,
tokenType,
amount,
debug,
}) {
const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint });
const tokenCost = amount * multiplier;
const { tokenCredits: balance } = (await this.findOne({ user }, 'tokenCredits').lean()) ?? {};
@ -11,6 +19,7 @@ balanceSchema.statics.check = async function ({ user, model, valueKey, tokenType
console.log('balance check', {
user,
model,
endpoint,
valueKey,
tokenType,
amount,

View file

@ -18,10 +18,11 @@ const tokenValues = {
* Retrieves the key associated with a given model name.
*
* @param {string} model - The model name to match.
* @param {string} endpoint - The endpoint name to match.
* @returns {string|undefined} The key corresponding to the model name, or undefined if no match is found.
*/
const getValueKey = (model) => {
const modelName = matchModelName(model);
const getValueKey = (model, endpoint) => {
const modelName = matchModelName(model, endpoint);
if (!modelName) {
return undefined;
}
@ -51,9 +52,10 @@ const getValueKey = (model) => {
* @param {string} [params.valueKey] - The key corresponding to the model name.
* @param {string} [params.tokenType] - The type of token (e.g., 'prompt' or 'completion').
* @param {string} [params.model] - The model name to derive the value key from if not provided.
* @param {string} [params.endpoint] - The endpoint name to derive the value key from if not provided.
* @returns {number} The multiplier for the given parameters, or a default value if not found.
*/
const getMultiplier = ({ valueKey, tokenType, model }) => {
const getMultiplier = ({ valueKey, tokenType, model, endpoint }) => {
if (valueKey && tokenType) {
return tokenValues[valueKey][tokenType] ?? defaultRate;
}
@ -62,7 +64,7 @@ const getMultiplier = ({ valueKey, tokenType, model }) => {
return 1;
}
valueKey = getValueKey(model);
valueKey = getValueKey(model, endpoint);
if (!valueKey) {
return defaultRate;
}
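
A hedged usage sketch of the endpoint-aware lookup above (the rates and value keys inside tokenValues are not shown in this diff, so the resolved key is illustrative):

    // 'chat-bison' only resolves once the endpoint disambiguates it:
    const promptMultiplier = getMultiplier({
      model: 'chat-bison',
      endpoint: 'google',
      tokenType: 'prompt',
    });
    // Internally: valueKey = getValueKey('chat-bison', 'google'), then
    // tokenValues[valueKey][tokenType], falling back to defaultRate when unmatched.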

View file

@ -0,0 +1,132 @@
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { getResponseSender } = require('~/server/services/Endpoints');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const AskController = async (req, res, next, initializeClient) => {
let {
text,
endpointOption,
conversationId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('ask log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let promptTokens;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
const user = req.user.id;
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
} else if (!conversationId && key === 'conversationId') {
conversationId = data[key];
}
}
};
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
unfinished: true,
cancelled: false,
error: false,
user,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
try {
const addMetadata = (data) => (metadata = data);
const getAbortData = () => ({
conversationId,
messageId: responseMessageId,
sender,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
const { client } = await initializeClient({ req, res, endpointOption });
let response = await client.sendMessage(text, {
// debug: true,
user,
conversationId,
parentMessageId,
overrideParentMessageId,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId ?? userMessageId,
}),
onStart,
getReqData,
addMetadata,
abortController,
});
if (metadata) {
response = { ...response, ...metadata };
}
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
await saveMessage({ ...response, user });
await saveMessage(userMessage);
// TODO: add title service
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender,
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
};
module.exports = AskController;

View file

@ -0,0 +1,135 @@
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { getResponseSender } = require('~/server/services/Endpoints');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const EditController = async (req, res, next, initializeClient) => {
let {
text,
generation,
endpointOption,
conversationId,
responseMessageId,
isContinued = false,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('edit log');
console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let promptTokens;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
const userMessageId = parentMessageId;
const user = req.user.id;
const addMetadata = (data) => (metadata = data);
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
}
}
};
const { onProgress: progressCallback, getPartialText } = createOnProgress({
generation,
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
unfinished: true,
cancelled: false,
isEdited: true,
error: false,
user,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
try {
const getAbortData = () => ({
conversationId,
messageId: responseMessageId,
sender,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
const { client } = await initializeClient({ req, res, endpointOption });
let response = await client.sendMessage(text, {
user,
generation,
isContinued,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId ?? userMessageId,
}),
getReqData,
onStart,
addMetadata,
abortController,
});
if (metadata) {
response = { ...response, ...metadata };
}
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
await saveMessage({ ...response, user });
await saveMessage(userMessage);
// TODO: add title service
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender,
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
};
module.exports = EditController;

View file

@ -1,14 +1,16 @@
const openAI = require('~/server/routes/endpoints/openAI');
const gptPlugins = require('~/server/routes/endpoints/gptPlugins');
const anthropic = require('~/server/routes/endpoints/anthropic');
const { parseConvo, EModelEndpoint } = require('~/server/routes/endpoints/schemas');
const { processFiles } = require('~/server/services/Files');
const openAI = require('~/server/services/Endpoints/openAI');
const google = require('~/server/services/Endpoints/google');
const anthropic = require('~/server/services/Endpoints/anthropic');
const gptPlugins = require('~/server/services/Endpoints/gptPlugins');
const { parseConvo, EModelEndpoint } = require('~/server/services/Endpoints');
const buildFunction = {
[EModelEndpoint.openAI]: openAI.buildOptions,
[EModelEndpoint.google]: google.buildOptions,
[EModelEndpoint.azureOpenAI]: openAI.buildOptions,
[EModelEndpoint.gptPlugins]: gptPlugins.buildOptions,
[EModelEndpoint.anthropic]: anthropic.buildOptions,
[EModelEndpoint.gptPlugins]: gptPlugins.buildOptions,
};
function buildEndpointOption(req, res, next) {

View file

@ -1,7 +1,7 @@
const crypto = require('crypto');
const { sendMessage, sendError } = require('../utils');
const { getResponseSender } = require('../routes/endpoints/schemas');
const { saveMessage } = require('../../models');
const { saveMessage } = require('~/models');
const { sendMessage, sendError } = require('~/server/utils');
const { getResponseSender } = require('~/server/services/Endpoints');
/**
* Denies a request by sending an error message and optionally saves the user's message.

View file

@ -1,137 +1,19 @@
const express = require('express');
const router = express.Router();
const { getResponseSender } = require('../endpoints/schemas');
const { initializeClient } = require('../endpoints/anthropic');
const AskController = require('~/server/controllers/AskController');
const { initializeClient } = require('~/server/services/Endpoints/anthropic');
const {
handleAbort,
createAbortController,
handleAbortError,
setHeaders,
handleAbort,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { sendMessage, createOnProgress } = require('~/server/utils');
const router = express.Router();
router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
let {
text,
endpointOption,
conversationId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('ask log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
let userMessage;
let promptTokens;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
const user = req.user.id;
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
} else if (!conversationId && key === 'conversationId') {
conversationId = data[key];
}
}
};
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
unfinished: true,
cancelled: false,
error: false,
user,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
try {
const getAbortData = () => ({
conversationId,
messageId: responseMessageId,
sender,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
const { client } = await initializeClient({ req, res, endpointOption });
let response = await client.sendMessage(text, {
getReqData,
// debug: true,
user,
conversationId,
parentMessageId,
overrideParentMessageId,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId ?? userMessageId,
}),
onStart,
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
await saveMessage({ ...response, user });
await saveMessage(userMessage);
// TODO: add anthropic titling
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender,
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
await AskController(req, res, next, initializeClient);
});
module.exports = router;

View file

@ -1,181 +1,19 @@
const express = require('express');
const AskController = require('~/server/controllers/AskController');
const { initializeClient } = require('~/server/services/Endpoints/google');
const {
setHeaders,
handleAbort,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const router = express.Router();
const crypto = require('crypto');
const { GoogleClient } = require('../../../app');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress } = require('../../utils');
const { getUserKey, checkUserKeyExpiry } = require('../../services/UserService');
const { setHeaders } = require('../../middleware');
router.post('/', setHeaders, async (req, res) => {
const { endpoint, text, parentMessageId, conversationId: oldConversationId } = req.body;
if (text.length === 0) {
return handleError(res, { text: 'Prompt empty or too short' });
}
if (endpoint !== 'google') {
return handleError(res, { text: 'Illegal request' });
}
router.post('/abort', handleAbort());
// build endpoint option
const endpointOption = {
examples: req.body?.examples ?? [{ input: { content: '' }, output: { content: '' } }],
promptPrefix: req.body?.promptPrefix ?? null,
key: req.body?.key ?? null,
modelOptions: {
model: req.body?.model ?? 'chat-bison',
modelLabel: req.body?.modelLabel ?? null,
temperature: req.body?.temperature ?? 0.2,
maxOutputTokens: req.body?.maxOutputTokens ?? 1024,
topP: req.body?.topP ?? 0.95,
topK: req.body?.topK ?? 40,
},
};
const availableModels = ['chat-bison', 'text-bison', 'codechat-bison'];
if (availableModels.find((model) => model === endpointOption.modelOptions.model) === undefined) {
return handleError(res, { text: 'Illegal request: model' });
}
const conversationId = oldConversationId || crypto.randomUUID();
// eslint-disable-next-line no-use-before-define
return await ask({
text,
endpointOption,
conversationId,
parentMessageId,
req,
res,
});
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
await AskController(req, res, next, initializeClient);
});
const ask = async ({ text, endpointOption, parentMessageId = null, conversationId, req, res }) => {
let userMessage;
let userMessageId;
// let promptTokens;
let responseMessageId;
let lastSavedTimestamp = 0;
const { overrideParentMessageId = null } = req.body;
const user = req.user.id;
try {
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
// } else if (key === 'promptTokens') {
// promptTokens = data[key];
} else if (!conversationId && key === 'conversationId') {
conversationId = data[key];
}
}
sendMessage(res, { message: userMessage, created: true });
};
const { onProgress: progressCallback } = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > 500) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: 'PaLM2',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
unfinished: true,
cancelled: false,
error: false,
user,
});
}
},
});
const abortController = new AbortController();
const isUserProvided = process.env.PALM_KEY === 'user_provided';
let key;
if (endpointOption.key && isUserProvided) {
checkUserKeyExpiry(
endpointOption.key,
'Your GOOGLE_TOKEN has expired. Please provide your token again.',
);
key = await getUserKey({ userId: user, name: 'google' });
key = JSON.parse(key);
delete endpointOption.key;
console.log('Using service account key provided by User for PaLM models');
}
try {
key = require('../../../data/auth.json');
} catch (e) {
console.log('No \'auth.json\' file (service account key) found in /api/data/ for PaLM models');
}
const clientOptions = {
// debug: true, // for testing
reverseProxyUrl: process.env.GOOGLE_REVERSE_PROXY || null,
proxy: process.env.PROXY || null,
...endpointOption,
};
const client = new GoogleClient(key, clientOptions);
let response = await client.sendMessage(text, {
getReqData,
user,
conversationId,
parentMessageId,
overrideParentMessageId,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId || userMessageId,
}),
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
await saveConvo(user, {
...endpointOption,
...endpointOption.modelOptions,
conversationId,
endpoint: 'google',
});
await saveMessage({ ...response, user });
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
} catch (error) {
console.error(error);
const errorMessage = {
messageId: responseMessageId,
sender: 'PaLM2',
conversationId,
parentMessageId,
unfinished: false,
cancelled: false,
error: true,
text: error.message,
};
await saveMessage({ ...errorMessage, user });
handleError(res, errorMessage);
}
};
module.exports = router;

View file

@ -1,11 +1,11 @@
const express = require('express');
const router = express.Router();
const { getResponseSender } = require('../endpoints/schemas');
const { validateTools } = require('../../../app');
const { addTitle } = require('../endpoints/openAI');
const { initializeClient } = require('../endpoints/gptPlugins');
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
const { sendMessage, createOnProgress } = require('../../utils');
const { getResponseSender } = require('~/server/services/Endpoints');
const { validateTools } = require('~/app');
const { addTitle } = require('~/server/services/Endpoints/openAI');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { sendMessage, createOnProgress } = require('~/server/utils');
const {
handleAbort,
createAbortController,
@ -13,7 +13,7 @@ const {
setHeaders,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
} = require('~/server/middleware');
router.post('/abort', handleAbort());

View file

@ -1,11 +1,12 @@
const express = require('express');
const router = express.Router();
const openAI = require('./openAI');
const google = require('./google');
const bingAI = require('./bingAI');
const anthropic = require('./anthropic');
const gptPlugins = require('./gptPlugins');
const askChatGPTBrowser = require('./askChatGPTBrowser');
const anthropic = require('./anthropic');
const { isEnabled } = require('~/server/utils');
const { EModelEndpoint } = require('~/server/services/Endpoints');
const {
uaParser,
checkBan,
@ -13,12 +14,12 @@ const {
concurrentLimiter,
messageIpLimiter,
messageUserLimiter,
} = require('../../middleware');
const { isEnabled } = require('../../utils');
const { EModelEndpoint } = require('../endpoints/schemas');
} = require('~/server/middleware');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
const router = express.Router();
router.use(requireJwtAuth);
router.use(checkBan);
router.use(uaParser);
@ -36,10 +37,10 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
}
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
router.use(`/${EModelEndpoint.google}`, google);
router.use(`/${EModelEndpoint.bingAI}`, bingAI);
router.use(`/${EModelEndpoint.chatGPTBrowser}`, askChatGPTBrowser);
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
router.use(`/${EModelEndpoint.anthropic}`, anthropic);
router.use(`/${EModelEndpoint.google}`, google);
router.use(`/${EModelEndpoint.bingAI}`, bingAI);
module.exports = router;

View file

@ -2,8 +2,8 @@ const express = require('express');
const router = express.Router();
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { getResponseSender } = require('~/server/routes/endpoints/schemas');
const { addTitle, initializeClient } = require('~/server/routes/endpoints/openAI');
const { getResponseSender } = require('~/server/services/Endpoints');
const { addTitle, initializeClient } = require('~/server/services/Endpoints/openAI');
const {
handleAbort,
createAbortController,

View file

@ -1,147 +1,19 @@
const express = require('express');
const router = express.Router();
const { getResponseSender } = require('../endpoints/schemas');
const { initializeClient } = require('../endpoints/anthropic');
const EditController = require('~/server/controllers/EditController');
const { initializeClient } = require('~/server/services/Endpoints/anthropic');
const {
handleAbort,
createAbortController,
handleAbortError,
setHeaders,
handleAbort,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { sendMessage, createOnProgress } = require('~/server/utils');
const router = express.Router();
router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
let {
text,
generation,
endpointOption,
conversationId,
responseMessageId,
isContinued = false,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('edit log');
console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let promptTokens;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
const userMessageId = parentMessageId;
const user = req.user.id;
const addMetadata = (data) => (metadata = data);
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
}
}
};
const { onProgress: progressCallback, getPartialText } = createOnProgress({
generation,
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
unfinished: true,
cancelled: false,
isEdited: true,
error: false,
user,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
try {
const getAbortData = () => ({
conversationId,
messageId: responseMessageId,
sender,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
const { client } = await initializeClient({ req, res, endpointOption });
let response = await client.sendMessage(text, {
user,
generation,
isContinued,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId ?? userMessageId,
}),
getReqData,
onStart,
addMetadata,
abortController,
});
if (metadata) {
response = { ...response, ...metadata };
}
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
await saveMessage({ ...response, user });
await saveMessage(userMessage);
// TODO: add anthropic titling
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender,
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
await EditController(req, res, next, initializeClient);
});
module.exports = router;

View file

@ -0,0 +1,19 @@
const express = require('express');
const EditController = require('~/server/controllers/EditController');
const { initializeClient } = require('~/server/services/Endpoints/google');
const {
setHeaders,
handleAbort,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const router = express.Router();
router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
await EditController(req, res, next, initializeClient);
});
module.exports = router;

View file

@ -1,10 +1,10 @@
const express = require('express');
const router = express.Router();
const { getResponseSender } = require('../endpoints/schemas');
const { validateTools } = require('../../../app');
const { initializeClient } = require('../endpoints/gptPlugins');
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('../../utils');
const { validateTools } = require('~/app');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { getResponseSender } = require('~/server/services/Endpoints');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
const {
handleAbort,
createAbortController,
@ -12,7 +12,7 @@ const {
setHeaders,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
} = require('~/server/middleware');
router.post('/abort', handleAbort());

View file

@ -1,20 +1,23 @@
const express = require('express');
const router = express.Router();
const openAI = require('./openAI');
const gptPlugins = require('./gptPlugins');
const google = require('./google');
const anthropic = require('./anthropic');
const gptPlugins = require('./gptPlugins');
const { isEnabled } = require('~/server/utils');
const { EModelEndpoint } = require('~/server/services/Endpoints');
const {
checkBan,
uaParser,
requireJwtAuth,
concurrentLimiter,
messageIpLimiter,
concurrentLimiter,
messageUserLimiter,
} = require('../../middleware');
const { isEnabled } = require('../../utils');
} = require('~/server/middleware');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
const router = express.Router();
router.use(requireJwtAuth);
router.use(checkBan);
router.use(uaParser);
@ -31,8 +34,9 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
router.use(messageUserLimiter);
}
router.use(['/azureOpenAI', '/openAI'], openAI);
router.use('/gptPlugins', gptPlugins);
router.use('/anthropic', anthropic);
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
router.use(`/${EModelEndpoint.anthropic}`, anthropic);
router.use(`/${EModelEndpoint.google}`, google);
module.exports = router;

View file

@ -1,9 +1,9 @@
const express = require('express');
const router = express.Router();
const { getResponseSender } = require('../endpoints/schemas');
const { initializeClient } = require('../endpoints/openAI');
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
const { sendMessage, createOnProgress } = require('../../utils');
const { getResponseSender } = require('~/server/services/Endpoints');
const { initializeClient } = require('~/server/services/Endpoints/openAI');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { sendMessage, createOnProgress } = require('~/server/utils');
const {
handleAbort,
createAbortController,
@ -11,7 +11,7 @@ const {
setHeaders,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
} = require('~/server/middleware');
router.post('/abort', handleAbort());

View file

@ -1,4 +1,4 @@
const { EModelEndpoint } = require('~/server/routes/endpoints/schemas');
const { EModelEndpoint } = require('~/server/services/Endpoints');
const {
OPENAI_API_KEY: openAIApiKey,
@ -7,7 +7,7 @@ const {
CHATGPT_TOKEN: chatGPTToken,
BINGAI_TOKEN: bingToken,
PLUGINS_USE_AZURE,
PALM_KEY: palmKey,
GOOGLE_KEY: googleKey,
} = process.env ?? {};
const useAzurePlugins = !!PLUGINS_USE_AZURE;
@ -26,7 +26,7 @@ module.exports = {
azureOpenAIApiKey,
useAzurePlugins,
userProvidedOpenAI,
palmKey,
googleKey,
[EModelEndpoint.openAI]: isUserProvided(openAIApiKey),
[EModelEndpoint.assistant]: isUserProvided(openAIApiKey),
[EModelEndpoint.azureOpenAI]: isUserProvided(azureOpenAIApiKey),

View file

@ -1,6 +1,6 @@
const { availableTools } = require('~/app/clients/tools');
const { addOpenAPISpecs } = require('~/app/clients/tools/util/addOpenAPISpecs');
const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, palmKey } =
const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, googleKey } =
require('./EndpointService').config;
/**
@ -8,7 +8,7 @@ const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, pa
*/
async function loadAsyncEndpoints() {
let i = 0;
let key, palmUser;
let key, googleUserProvides;
try {
key = require('~/data/auth.json');
} catch (e) {
@ -17,8 +17,8 @@ async function loadAsyncEndpoints() {
}
}
if (palmKey === 'user_provided') {
palmUser = true;
if (googleKey === 'user_provided') {
googleUserProvides = true;
if (i <= 1) {
i++;
}
@ -33,7 +33,7 @@ async function loadAsyncEndpoints() {
}
const plugins = transformToolsToMap(tools);
const google = key || palmUser ? { userProvide: palmUser } : false;
const google = key || googleUserProvides ? { userProvide: googleUserProvides } : false;
const gptPlugins =
openAIApiKey || azureOpenAIApiKey

View file

@ -1,4 +1,4 @@
const { EModelEndpoint } = require('~/server/routes/endpoints/schemas');
const { EModelEndpoint } = require('~/server/services/Endpoints');
const loadAsyncEndpoints = require('./loadAsyncEndpoints');
const { config } = require('./EndpointService');

View file

@ -3,7 +3,7 @@ const {
getChatGPTBrowserModels,
getAnthropicModels,
} = require('~/server/services/ModelService');
const { EModelEndpoint } = require('~/server/routes/endpoints/schemas');
const { EModelEndpoint } = require('~/server/services/Endpoints');
const { useAzurePlugins } = require('~/server/services/Config/EndpointService').config;
const filterAssistantModels = (str) => {
@@ -21,7 +21,18 @@ async function loadDefaultModels() {
[EModelEndpoint.openAI]: openAI,
[EModelEndpoint.azureOpenAI]: azureOpenAI,
[EModelEndpoint.assistant]: openAI.filter(filterAssistantModels),
[EModelEndpoint.google]: ['chat-bison', 'text-bison', 'codechat-bison'],
[EModelEndpoint.google]: [
'chat-bison',
'chat-bison-32k',
'codechat-bison',
'codechat-bison-32k',
'text-bison',
'text-bison-32k',
'text-unicorn',
'code-gecko',
'code-bison',
'code-bison-32k',
],
[EModelEndpoint.bingAI]: ['BingAI', 'Sydney'],
[EModelEndpoint.chatGPTBrowser]: chatGPTBrowser,
[EModelEndpoint.gptPlugins]: gptPlugins,

View file

@@ -2,7 +2,7 @@ const { AnthropicClient } = require('~/app');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const initializeClient = async ({ req, res, endpointOption }) => {
const { ANTHROPIC_API_KEY, ANTHROPIC_REVERSE_PROXY } = process.env;
const { ANTHROPIC_API_KEY, ANTHROPIC_REVERSE_PROXY, PROXY } = process.env;
const expiresAt = req.body.key;
const isUserProvided = ANTHROPIC_API_KEY === 'user_provided';
@@ -21,6 +21,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
req,
res,
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
proxy: PROXY ?? null,
...endpointOption,
});

View file

@@ -0,0 +1,16 @@
const buildOptions = (endpoint, parsedBody) => {
const { examples, modelLabel, promptPrefix, ...rest } = parsedBody;
const endpointOption = {
examples,
endpoint,
modelLabel,
promptPrefix,
modelOptions: {
...rest,
},
};
return endpointOption;
};
module.exports = buildOptions;
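A quick sketch of the shape this helper produces (hypothetical input, not from the diff):

// buildOptions splits the parsed body into named fields plus modelOptions:
const endpointOption = buildOptions('google', {
  examples: [{ input: { content: 'hi' }, output: { content: 'hello' } }],
  modelLabel: 'My Codey',
  promptPrefix: 'Be concise.',
  model: 'codechat-bison',
  temperature: 0.2,
});
// => { examples, endpoint: 'google', modelLabel, promptPrefix,
//      modelOptions: { model: 'codechat-bison', temperature: 0.2 } }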

View file

@@ -0,0 +1,8 @@
const buildOptions = require('./buildOptions');
const initializeClient = require('./initializeClient');
module.exports = {
// addTitle, // todo
buildOptions,
initializeClient,
};

View file

@@ -0,0 +1,35 @@
const { GoogleClient } = require('~/app');
const { EModelEndpoint } = require('~/server/services/Endpoints');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const initializeClient = async ({ req, res, endpointOption }) => {
const { GOOGLE_KEY, GOOGLE_REVERSE_PROXY, PROXY } = process.env;
const isUserProvided = GOOGLE_KEY === 'user_provided';
const { key: expiresAt } = req.body;
let userKey = null;
if (expiresAt && isUserProvided) {
checkUserKeyExpiry(
expiresAt,
'Your Google key has expired. Please provide your JSON credentials again.',
);
userKey = await getUserKey({ userId: req.user.id, name: EModelEndpoint.google });
}
const apiKey = isUserProvided ? userKey : require('~/data/auth.json');
const client = new GoogleClient(apiKey, {
req,
res,
reverseProxyUrl: GOOGLE_REVERSE_PROXY ?? null,
proxy: PROXY ?? null,
...endpointOption,
});
return {
client,
apiKey,
};
};
module.exports = initializeClient;
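A hedged usage sketch; the sendMessage call is an assumption based on the shared BaseClient interface and is not shown in this diff:

// Hypothetical call site inside an ask controller:
async function handleAsk(req, res, endpointOption) {
  const { client } = await initializeClient({ req, res, endpointOption });
  // Assumption: GoogleClient inherits sendMessage(text, opts) from BaseClient.
  return client.sendMessage(req.body.text, { user: req.user.id });
}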

View file

@@ -1,7 +1,7 @@
const { PluginsClient } = require('../../../../app');
const { isEnabled } = require('../../../utils');
const { getAzureCredentials } = require('../../../../utils');
const { getUserKey, checkUserKeyExpiry } = require('../../../services/UserService');
const { PluginsClient } = require('~/app');
const { isEnabled } = require('~/server/utils');
const { getAzureCredentials } = require('~/utils');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const initializeClient = async ({ req, res, endpointOption }) => {
const {

View file

@@ -1,12 +1,12 @@
// gptPlugins/initializeClient.spec.js
const { PluginsClient } = require('~/app');
const initializeClient = require('./initializeClient');
const { PluginsClient } = require('../../../../app');
const { getUserKey } = require('../../../services/UserService');
const { getUserKey } = require('../../UserService');
// Mock getUserKey since it's the only function we want to mock
jest.mock('../../../services/UserService', () => ({
jest.mock('~/server/services/UserService', () => ({
getUserKey: jest.fn(),
checkUserKeyExpiry: jest.requireActual('../../../services/UserService').checkUserKeyExpiry,
checkUserKeyExpiry: jest.requireActual('~/server/services/UserService').checkUserKeyExpiry,
}));
describe('gptPlugins/initializeClient', () => {

View file

@@ -0,0 +1,5 @@
const schemas = require('./schemas');
module.exports = {
...schemas,
};

View file

@@ -1,11 +1,11 @@
const { OpenAIClient } = require('~/app');
const initializeClient = require('./initializeClient');
const { OpenAIClient } = require('../../../../app');
const { getUserKey } = require('../../../services/UserService');
const { getUserKey } = require('~/server/services/UserService');
// Mock getUserKey since it's the only function we want to mock
jest.mock('../../../services/UserService', () => ({
jest.mock('~/server/services/UserService', () => ({
getUserKey: jest.fn(),
checkUserKeyExpiry: jest.requireActual('../../../services/UserService').checkUserKeyExpiry,
checkUserKeyExpiry: jest.requireActual('~/server/services/UserService').checkUserKeyExpiry,
}));
describe('initializeClient', () => {

View file

@@ -18,10 +18,44 @@ const alternateName = {
[EModelEndpoint.bingAI]: 'Bing',
[EModelEndpoint.chatGPTBrowser]: 'ChatGPT',
[EModelEndpoint.gptPlugins]: 'Plugins',
[EModelEndpoint.google]: 'PaLM',
[EModelEndpoint.google]: 'Google',
[EModelEndpoint.anthropic]: 'Anthropic',
};
const endpointSettings = {
[EModelEndpoint.google]: {
model: {
default: 'chat-bison',
},
maxOutputTokens: {
min: 1,
max: 2048,
step: 1,
default: 1024,
},
temperature: {
min: 0,
max: 1,
step: 0.01,
default: 0.2,
},
topP: {
min: 0,
max: 1,
step: 0.01,
default: 0.8,
},
topK: {
min: 1,
max: 40,
step: 0.01,
default: 40,
},
},
};
const google = endpointSettings[EModelEndpoint.google];
const supportsFiles = {
[EModelEndpoint.openAI]: true,
[EModelEndpoint.assistant]: true,
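These ranges back both the client sliders and the schema defaults below; a sketch of one way a consumer might use them (clampSetting is a hypothetical helper, not part of the diff):

const { endpointSettings } = require('~/server/services/Endpoints');

// Hypothetical helper: coerce a user-supplied value into the configured range.
function clampSetting(endpoint, name, value) {
  const setting = endpointSettings[endpoint]?.[name];
  if (!setting) {
    return value;
  }
  if (value == null || Number.isNaN(value)) {
    return setting.default;
  }
  return Math.min(setting.max, Math.max(setting.min, value));
}

clampSetting('google', 'temperature', 1.7); // => 1 (clamped to max)
clampSetting('google', 'topK', undefined); // => 40 (default)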
@@ -158,22 +192,24 @@ const googleSchema = tConversationSchema
})
.transform((obj) => ({
...obj,
model: obj.model ?? 'chat-bison',
model: obj.model ?? google.model.default,
modelLabel: obj.modelLabel ?? null,
promptPrefix: obj.promptPrefix ?? null,
temperature: obj.temperature ?? 0.2,
maxOutputTokens: obj.maxOutputTokens ?? 1024,
topP: obj.topP ?? 0.95,
topK: obj.topK ?? 40,
examples: obj.examples ?? [{ input: { content: '' }, output: { content: '' } }],
temperature: obj.temperature ?? google.temperature.default,
maxOutputTokens: obj.maxOutputTokens ?? google.maxOutputTokens.default,
topP: obj.topP ?? google.topP.default,
topK: obj.topK ?? google.topK.default,
}))
.catch(() => ({
model: 'chat-bison',
model: google.model.default,
modelLabel: null,
promptPrefix: null,
temperature: 0.2,
maxOutputTokens: 1024,
topP: 0.95,
topK: 40,
examples: [{ input: { content: '' }, output: { content: '' } }],
temperature: google.temperature.default,
maxOutputTokens: google.maxOutputTokens.default,
topP: google.topP.default,
topK: google.topK.default,
}));
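The practical effect: a conversation with unset Google fields now parses to the centralized defaults rather than scattered literals, e.g. (input abbreviated; the exact required fields come from tConversationSchema):

// Sketch: unset fields fall back to the endpointSettings defaults above.
const result = googleSchema.parse({ endpoint: 'google' });
// result.model           === 'chat-bison' (google.model.default)
// result.temperature     === 0.2
// result.maxOutputTokens === 1024
// result.topP            === 0.8  (previously 0.95)
// result.topK            === 40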
const bingAISchema = tConversationSchema
@@ -385,7 +421,13 @@ const getResponseSender = (endpointOption) => {
}
if (endpoint === EModelEndpoint.google) {
return modelLabel ?? 'PaLM2';
if (modelLabel) {
return modelLabel;
} else if (model && model.includes('code')) {
return 'Codey';
}
return 'PaLM2';
}
return '';
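So the Google sender resolves in priority order: an explicit modelLabel, then 'Codey' for code models, then 'PaLM2'. For example:

getResponseSender({ endpoint: EModelEndpoint.google, modelLabel: 'My Bot' }); // => 'My Bot'
getResponseSender({ endpoint: EModelEndpoint.google, model: 'codechat-bison' }); // => 'Codey'
getResponseSender({ endpoint: EModelEndpoint.google, model: 'chat-bison' }); // => 'PaLM2'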
@@ -399,4 +441,5 @@ module.exports = {
openAIModels,
visionModels,
alternateName,
endpointSettings,
};

View file

@@ -1,4 +1,4 @@
const { visionModels } = require('~/server/routes/endpoints/schemas');
const { visionModels } = require('~/server/services/Endpoints');
function validateVisionModel(model) {
if (!model) {

View file

@@ -247,7 +247,7 @@
* @property {string} azureOpenAIApiKey - The API key for Azure OpenAI.
* @property {boolean} useAzurePlugins - Flag to indicate if Azure plugins are used.
* @property {boolean} userProvidedOpenAI - Flag to indicate if OpenAI API key is user provided.
* @property {string} palmKey - The Palm key.
* @property {string} googleKey - The Google key.
* @property {boolean|{userProvide: boolean}} [openAI] - Flag to indicate if OpenAI endpoint is user provided, or its configuration.
* @property {boolean|{userProvide: boolean}} [assistant] - Flag to indicate if Assistant endpoint is user provided, or its configuration.
* @property {boolean|{userProvide: boolean}} [azureOpenAI] - Flag to indicate if Azure OpenAI endpoint is user provided, or its configuration.

View file

@@ -1,3 +1,5 @@
const { EModelEndpoint } = require('~/server/services/Endpoints');
const models = [
'text-davinci-003',
'text-davinci-002',
@@ -39,20 +41,37 @@ const models = [
// Order is important here: by model series and context size (gpt-4 then gpt-3, ascending)
const maxTokensMap = {
'gpt-4': 8191,
'gpt-4-0613': 8191,
'gpt-4-32k': 32767,
'gpt-4-32k-0314': 32767,
'gpt-4-32k-0613': 32767,
'gpt-3.5-turbo': 4095,
'gpt-3.5-turbo-0613': 4095,
'gpt-3.5-turbo-0301': 4095,
'gpt-3.5-turbo-16k': 15999,
'gpt-3.5-turbo-16k-0613': 15999,
'gpt-3.5-turbo-1106': 16380, // -5 from max
'gpt-4-1106': 127995, // -5 from max
'claude-2.1': 200000,
'claude-': 100000,
[EModelEndpoint.openAI]: {
'gpt-4': 8191,
'gpt-4-0613': 8191,
'gpt-4-32k': 32767,
'gpt-4-32k-0314': 32767,
'gpt-4-32k-0613': 32767,
'gpt-3.5-turbo': 4095,
'gpt-3.5-turbo-0613': 4095,
'gpt-3.5-turbo-0301': 4095,
'gpt-3.5-turbo-16k': 15999,
'gpt-3.5-turbo-16k-0613': 15999,
'gpt-3.5-turbo-1106': 16380, // -5 from max
'gpt-4-1106': 127995, // -5 from max
},
[EModelEndpoint.google]: {
/* Max I/O is 32k combined, so -1000 to leave room for response */
'text-bison-32k': 31000,
'chat-bison-32k': 31000,
'code-bison-32k': 31000,
'codechat-bison-32k': 31000,
/* PaLM2, -5 from max: 8192 */
'text-': 8187,
'chat-': 8187,
/* Codey, -5 from max: 6144; listed last so the reverse-order partial
match checks 'codechat-' before the broader 'chat-' prefix */
'code-': 6139,
'codechat-': 6139,
},
[EModelEndpoint.anthropic]: {
'claude-2.1': 200000,
'claude-': 100000,
},
};
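A few illustrative lookups against this map (getModelMaxTokens is defined below; exact keys win, partial keys are tried in reverse declaration order):

getModelMaxTokens('codechat-bison-32k', EModelEndpoint.google); // => 31000 (exact key)
getModelMaxTokens('codechat-bison', EModelEndpoint.google); // => 6139 (matches 'codechat-')
getModelMaxTokens('text-unicorn@001', EModelEndpoint.google); // => 8187 (matches 'text-')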
/**
@@ -60,6 +79,7 @@ const maxTokensMap = {
* it searches for partial matches within the model name, checking keys in reverse order.
*
* @param {string} modelName - The name of the model to look up.
* @param {string} endpoint - The endpoint (default is 'openAI').
* @returns {number|undefined} The maximum tokens for the given model or undefined if no match is found.
*
* @example
@@ -67,19 +87,24 @@ const maxTokensMap = {
* getModelMaxTokens('gpt-4-32k-unknown'); // Returns 32767
* getModelMaxTokens('unknown-model'); // Returns undefined
*/
function getModelMaxTokens(modelName) {
function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI) {
if (typeof modelName !== 'string') {
return undefined;
}
if (maxTokensMap[modelName]) {
return maxTokensMap[modelName];
const tokensMap = maxTokensMap[endpoint];
if (!tokensMap) {
return undefined;
}
const keys = Object.keys(maxTokensMap);
if (tokensMap[modelName]) {
return tokensMap[modelName];
}
const keys = Object.keys(tokensMap);
for (let i = keys.length - 1; i >= 0; i--) {
if (modelName.includes(keys[i])) {
return maxTokensMap[keys[i]];
return tokensMap[keys[i]];
}
}
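Because the endpoint argument defaults to 'openAI', existing call sites keep their old behavior; other providers opt in explicitly:

getModelMaxTokens('gpt-4-32k-unknown'); // => 32767 (openAI map, partial match)
getModelMaxTokens('claude-2.1', EModelEndpoint.anthropic); // => 200000 (exact)
getModelMaxTokens('claude-instant-1', EModelEndpoint.anthropic); // => 100000 (matches 'claude-')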
@@ -91,6 +116,7 @@ function getModelMaxTokens(modelName) {
* it searches for partial matches within the model name, checking keys in reverse order.
*
* @param {string} modelName - The name of the model to look up.
* @param {string} endpoint - The endpoint (default is 'openAI').
* @returns {string|undefined} The model name key for the given model; returns the input if it is a string and no match is found.
*
* @example
@@ -98,16 +124,21 @@ function getModelMaxTokens(modelName) {
* matchModelName('gpt-4-32k-unknown'); // Returns 'gpt-4-32k'
* matchModelName('unknown-model'); // Returns 'unknown-model'
*/
function matchModelName(modelName) {
function matchModelName(modelName, endpoint = EModelEndpoint.openAI) {
if (typeof modelName !== 'string') {
return undefined;
}
if (maxTokensMap[modelName]) {
const tokensMap = maxTokensMap[endpoint];
if (!tokensMap) {
return modelName;
}
const keys = Object.keys(maxTokensMap);
if (tokensMap[modelName]) {
return modelName;
}
const keys = Object.keys(tokensMap);
for (let i = keys.length - 1; i >= 0; i--) {
if (modelName.includes(keys[i])) {
return keys[i];

View file

@@ -1,16 +1,23 @@
const { EModelEndpoint } = require('~/server/services/Endpoints');
const { getModelMaxTokens, matchModelName, maxTokensMap } = require('./tokens');
describe('getModelMaxTokens', () => {
test('should return correct tokens for exact match', () => {
expect(getModelMaxTokens('gpt-4-32k-0613')).toBe(maxTokensMap['gpt-4-32k-0613']);
expect(getModelMaxTokens('gpt-4-32k-0613')).toBe(
maxTokensMap[EModelEndpoint.openAI]['gpt-4-32k-0613'],
);
});
test('should return correct tokens for partial match', () => {
expect(getModelMaxTokens('gpt-4-32k-unknown')).toBe(maxTokensMap['gpt-4-32k']);
expect(getModelMaxTokens('gpt-4-32k-unknown')).toBe(
maxTokensMap[EModelEndpoint.openAI]['gpt-4-32k'],
);
});
test('should return correct tokens for partial match (OpenRouter)', () => {
expect(getModelMaxTokens('openai/gpt-4-32k')).toBe(maxTokensMap['gpt-4-32k']);
expect(getModelMaxTokens('openai/gpt-4-32k')).toBe(
maxTokensMap[EModelEndpoint.openAI]['gpt-4-32k'],
);
});
test('should return undefined for no match', () => {
@@ -19,12 +26,14 @@ describe('getModelMaxTokens', () => {
test('should return correct tokens for another exact match', () => {
expect(getModelMaxTokens('gpt-3.5-turbo-16k-0613')).toBe(
maxTokensMap['gpt-3.5-turbo-16k-0613'],
maxTokensMap[EModelEndpoint.openAI]['gpt-3.5-turbo-16k-0613'],
);
});
test('should return correct tokens for another partial match', () => {
expect(getModelMaxTokens('gpt-3.5-turbo-unknown')).toBe(maxTokensMap['gpt-3.5-turbo']);
expect(getModelMaxTokens('gpt-3.5-turbo-unknown')).toBe(
maxTokensMap[EModelEndpoint.openAI]['gpt-3.5-turbo'],
);
});
test('should return undefined for undefined input', () => {
@@ -41,26 +50,34 @@ describe('getModelMaxTokens', () => {
// 11/06 Update
test('should return correct tokens for gpt-3.5-turbo-1106 exact match', () => {
expect(getModelMaxTokens('gpt-3.5-turbo-1106')).toBe(maxTokensMap['gpt-3.5-turbo-1106']);
expect(getModelMaxTokens('gpt-3.5-turbo-1106')).toBe(
maxTokensMap[EModelEndpoint.openAI]['gpt-3.5-turbo-1106'],
);
});
test('should return correct tokens for gpt-4-1106 exact match', () => {
expect(getModelMaxTokens('gpt-4-1106')).toBe(maxTokensMap['gpt-4-1106']);
expect(getModelMaxTokens('gpt-4-1106')).toBe(maxTokensMap[EModelEndpoint.openAI]['gpt-4-1106']);
});
test('should return correct tokens for gpt-3.5-turbo-1106 partial match', () => {
expect(getModelMaxTokens('something-/gpt-3.5-turbo-1106')).toBe(
maxTokensMap['gpt-3.5-turbo-1106'],
maxTokensMap[EModelEndpoint.openAI]['gpt-3.5-turbo-1106'],
);
expect(getModelMaxTokens('gpt-3.5-turbo-1106/something-/')).toBe(
maxTokensMap['gpt-3.5-turbo-1106'],
maxTokensMap[EModelEndpoint.openAI]['gpt-3.5-turbo-1106'],
);
});
test('should return correct tokens for gpt-4-1106 partial match', () => {
expect(getModelMaxTokens('gpt-4-1106/something')).toBe(maxTokensMap['gpt-4-1106']);
expect(getModelMaxTokens('gpt-4-1106-preview')).toBe(maxTokensMap['gpt-4-1106']);
expect(getModelMaxTokens('gpt-4-1106-vision-preview')).toBe(maxTokensMap['gpt-4-1106']);
expect(getModelMaxTokens('gpt-4-1106/something')).toBe(
maxTokensMap[EModelEndpoint.openAI]['gpt-4-1106'],
);
expect(getModelMaxTokens('gpt-4-1106-preview')).toBe(
maxTokensMap[EModelEndpoint.openAI]['gpt-4-1106'],
);
expect(getModelMaxTokens('gpt-4-1106-vision-preview')).toBe(
maxTokensMap[EModelEndpoint.openAI]['gpt-4-1106'],
);
});
test('should return correct tokens for Anthropic models', () => {
@@ -74,13 +91,36 @@ describe('getModelMaxTokens', () => {
'claude-instant-1-100k',
];
const claude21MaxTokens = maxTokensMap['claude-2.1'];
const claudeMaxTokens = maxTokensMap['claude-'];
const claudeMaxTokens = maxTokensMap[EModelEndpoint.anthropic]['claude-'];
const claude21MaxTokens = maxTokensMap[EModelEndpoint.anthropic]['claude-2.1'];
models.forEach((model) => {
const expectedTokens = model === 'claude-2.1' ? claude21MaxTokens : claudeMaxTokens;
expect(getModelMaxTokens(model)).toEqual(expectedTokens);
expect(getModelMaxTokens(model, EModelEndpoint.anthropic)).toEqual(expectedTokens);
});
});
// Tests for Google models
test('should return correct tokens for exact match - Google models', () => {
expect(getModelMaxTokens('text-bison-32k', EModelEndpoint.google)).toBe(
maxTokensMap[EModelEndpoint.google]['text-bison-32k'],
);
expect(getModelMaxTokens('codechat-bison-32k', EModelEndpoint.google)).toBe(
maxTokensMap[EModelEndpoint.google]['codechat-bison-32k'],
);
});
test('should return undefined for no match - Google models', () => {
expect(getModelMaxTokens('unknown-google-model', EModelEndpoint.google)).toBeUndefined();
});
test('should return correct tokens for partial match - Google models', () => {
expect(getModelMaxTokens('code-', EModelEndpoint.google)).toBe(
maxTokensMap[EModelEndpoint.google]['code-'],
);
expect(getModelMaxTokens('chat-', EModelEndpoint.google)).toBe(
maxTokensMap[EModelEndpoint.google]['chat-'],
);
});
});
describe('matchModelName', () => {
@@ -122,4 +162,21 @@ describe('matchModelName', () => {
expect(matchModelName('gpt-4-1106-preview')).toBe('gpt-4-1106');
expect(matchModelName('gpt-4-1106-vision-preview')).toBe('gpt-4-1106');
});
// Tests for Google models
it('should return the exact model name if it exists in maxTokensMap - Google models', () => {
expect(matchModelName('text-bison-32k', EModelEndpoint.google)).toBe('text-bison-32k');
expect(matchModelName('codechat-bison-32k', EModelEndpoint.google)).toBe('codechat-bison-32k');
});
it('should return the input model name if no match is found - Google models', () => {
expect(matchModelName('unknown-google-model', EModelEndpoint.google)).toBe(
'unknown-google-model',
);
});
it('should return the closest matching key for partial matches - Google models', () => {
expect(matchModelName('code-', EModelEndpoint.google)).toBe('code-');
expect(matchModelName('chat-', EModelEndpoint.google)).toBe('chat-');
});
});