feat(Google): Support all Text/Chat Models, Response streaming, PaLM -> Google 🤖 (#1316)

* feat: update PaLM icons

* feat: add additional google models

* POC: formatting inputs for Vertex AI streaming

* refactor: move endpoints services outside of /routes dir to /services/Endpoints

* refactor: shorten schemas import

* refactor: rename PALM to GOOGLE

* feat: make Google editable endpoint

* feat: reusable Ask and Edit controllers based off Anthropic

* chore: organize imports/logic

* fix(parseConvo): include examples in googleSchema

* fix: google only allows odd number of messages to be sent

* fix: pass proxy to AnthropicClient

* refactor: change `google` altName to `Google`

* refactor: update getModelMaxTokens and related functions to handle maxTokensMap with nested endpoint model key/values

* refactor: google Icon and response sender changes (Codey and Google logo instead of PaLM in all cases)

* feat: google support for maxTokensMap

* feat: google updated endpoints with Ask/Edit controllers, buildOptions, and initializeClient

* feat(GoogleClient): now builds prompt for text models and supports real streaming from Vertex AI through langchain

* chore(GoogleClient): remove comments, left before for reference in git history

* docs: update google instructions (WIP)

* docs(apis_and_tokens.md): add images to google instructions

* docs: remove typo apis_and_tokens.md

* Update apis_and_tokens.md

* feat(Google): use default settings map, fully support context for both text and chat models, fully support examples for chat models

* chore: update more PaLM references to Google

* chore: move playwright out of workflows to avoid failing tests
This commit is contained in:
Danny Avila 2023-12-10 14:54:13 -05:00 committed by GitHub
parent 8a1968b2f8
commit 583e978a82
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
90 changed files with 1613 additions and 784 deletions

View file

@ -1,137 +1,19 @@
const express = require('express');
const router = express.Router();
const { getResponseSender } = require('../endpoints/schemas');
const { initializeClient } = require('../endpoints/anthropic');
const AskController = require('~/server/controllers/AskController');
const { initializeClient } = require('~/server/services/Endpoints/anthropic');
const {
handleAbort,
createAbortController,
handleAbortError,
setHeaders,
handleAbort,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { sendMessage, createOnProgress } = require('~/server/utils');
const router = express.Router();
router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
let {
text,
endpointOption,
conversationId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('ask log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
let userMessage;
let promptTokens;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
const user = req.user.id;
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
} else if (!conversationId && key === 'conversationId') {
conversationId = data[key];
}
}
};
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
unfinished: true,
cancelled: false,
error: false,
user,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
try {
const getAbortData = () => ({
conversationId,
messageId: responseMessageId,
sender,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
const { client } = await initializeClient({ req, res, endpointOption });
let response = await client.sendMessage(text, {
getReqData,
// debug: true,
user,
conversationId,
parentMessageId,
overrideParentMessageId,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId ?? userMessageId,
}),
onStart,
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
await saveMessage({ ...response, user });
await saveMessage(userMessage);
// TODO: add anthropic titling
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender,
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
await AskController(req, res, next, initializeClient);
});
module.exports = router;

View file

@ -1,181 +1,19 @@
const express = require('express');
const AskController = require('~/server/controllers/AskController');
const { initializeClient } = require('~/server/services/Endpoints/google');
const {
setHeaders,
handleAbort,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const router = express.Router();
const crypto = require('crypto');
const { GoogleClient } = require('../../../app');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress } = require('../../utils');
const { getUserKey, checkUserKeyExpiry } = require('../../services/UserService');
const { setHeaders } = require('../../middleware');
router.post('/', setHeaders, async (req, res) => {
const { endpoint, text, parentMessageId, conversationId: oldConversationId } = req.body;
if (text.length === 0) {
return handleError(res, { text: 'Prompt empty or too short' });
}
if (endpoint !== 'google') {
return handleError(res, { text: 'Illegal request' });
}
router.post('/abort', handleAbort());
// build endpoint option
const endpointOption = {
examples: req.body?.examples ?? [{ input: { content: '' }, output: { content: '' } }],
promptPrefix: req.body?.promptPrefix ?? null,
key: req.body?.key ?? null,
modelOptions: {
model: req.body?.model ?? 'chat-bison',
modelLabel: req.body?.modelLabel ?? null,
temperature: req.body?.temperature ?? 0.2,
maxOutputTokens: req.body?.maxOutputTokens ?? 1024,
topP: req.body?.topP ?? 0.95,
topK: req.body?.topK ?? 40,
},
};
const availableModels = ['chat-bison', 'text-bison', 'codechat-bison'];
if (availableModels.find((model) => model === endpointOption.modelOptions.model) === undefined) {
return handleError(res, { text: 'Illegal request: model' });
}
const conversationId = oldConversationId || crypto.randomUUID();
// eslint-disable-next-line no-use-before-define
return await ask({
text,
endpointOption,
conversationId,
parentMessageId,
req,
res,
});
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
await AskController(req, res, next, initializeClient);
});
const ask = async ({ text, endpointOption, parentMessageId = null, conversationId, req, res }) => {
let userMessage;
let userMessageId;
// let promptTokens;
let responseMessageId;
let lastSavedTimestamp = 0;
const { overrideParentMessageId = null } = req.body;
const user = req.user.id;
try {
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
// } else if (key === 'promptTokens') {
// promptTokens = data[key];
} else if (!conversationId && key === 'conversationId') {
conversationId = data[key];
}
}
sendMessage(res, { message: userMessage, created: true });
};
const { onProgress: progressCallback } = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > 500) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: 'PaLM2',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
unfinished: true,
cancelled: false,
error: false,
user,
});
}
},
});
const abortController = new AbortController();
const isUserProvided = process.env.PALM_KEY === 'user_provided';
let key;
if (endpointOption.key && isUserProvided) {
checkUserKeyExpiry(
endpointOption.key,
'Your GOOGLE_TOKEN has expired. Please provide your token again.',
);
key = await getUserKey({ userId: user, name: 'google' });
key = JSON.parse(key);
delete endpointOption.key;
console.log('Using service account key provided by User for PaLM models');
}
try {
key = require('../../../data/auth.json');
} catch (e) {
console.log('No \'auth.json\' file (service account key) found in /api/data/ for PaLM models');
}
const clientOptions = {
// debug: true, // for testing
reverseProxyUrl: process.env.GOOGLE_REVERSE_PROXY || null,
proxy: process.env.PROXY || null,
...endpointOption,
};
const client = new GoogleClient(key, clientOptions);
let response = await client.sendMessage(text, {
getReqData,
user,
conversationId,
parentMessageId,
overrideParentMessageId,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId || userMessageId,
}),
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
await saveConvo(user, {
...endpointOption,
...endpointOption.modelOptions,
conversationId,
endpoint: 'google',
});
await saveMessage({ ...response, user });
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
} catch (error) {
console.error(error);
const errorMessage = {
messageId: responseMessageId,
sender: 'PaLM2',
conversationId,
parentMessageId,
unfinished: false,
cancelled: false,
error: true,
text: error.message,
};
await saveMessage({ ...errorMessage, user });
handleError(res, errorMessage);
}
};
module.exports = router;

View file

@ -1,11 +1,11 @@
const express = require('express');
const router = express.Router();
const { getResponseSender } = require('../endpoints/schemas');
const { validateTools } = require('../../../app');
const { addTitle } = require('../endpoints/openAI');
const { initializeClient } = require('../endpoints/gptPlugins');
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
const { sendMessage, createOnProgress } = require('../../utils');
const { getResponseSender } = require('~/server/services/Endpoints');
const { validateTools } = require('~/app');
const { addTitle } = require('~/server/services/Endpoints/openAI');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { sendMessage, createOnProgress } = require('~/server/utils');
const {
handleAbort,
createAbortController,
@ -13,7 +13,7 @@ const {
setHeaders,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
} = require('~/server/middleware');
router.post('/abort', handleAbort());

View file

@ -1,11 +1,12 @@
const express = require('express');
const router = express.Router();
const openAI = require('./openAI');
const google = require('./google');
const bingAI = require('./bingAI');
const anthropic = require('./anthropic');
const gptPlugins = require('./gptPlugins');
const askChatGPTBrowser = require('./askChatGPTBrowser');
const anthropic = require('./anthropic');
const { isEnabled } = require('~/server/utils');
const { EModelEndpoint } = require('~/server/services/Endpoints');
const {
uaParser,
checkBan,
@ -13,12 +14,12 @@ const {
concurrentLimiter,
messageIpLimiter,
messageUserLimiter,
} = require('../../middleware');
const { isEnabled } = require('../../utils');
const { EModelEndpoint } = require('../endpoints/schemas');
} = require('~/server/middleware');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
const router = express.Router();
router.use(requireJwtAuth);
router.use(checkBan);
router.use(uaParser);
@ -36,10 +37,10 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
}
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
router.use(`/${EModelEndpoint.google}`, google);
router.use(`/${EModelEndpoint.bingAI}`, bingAI);
router.use(`/${EModelEndpoint.chatGPTBrowser}`, askChatGPTBrowser);
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
router.use(`/${EModelEndpoint.anthropic}`, anthropic);
router.use(`/${EModelEndpoint.google}`, google);
router.use(`/${EModelEndpoint.bingAI}`, bingAI);
module.exports = router;

View file

@ -2,8 +2,8 @@ const express = require('express');
const router = express.Router();
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { getResponseSender } = require('~/server/routes/endpoints/schemas');
const { addTitle, initializeClient } = require('~/server/routes/endpoints/openAI');
const { getResponseSender } = require('~/server/services/Endpoints');
const { addTitle, initializeClient } = require('~/server/services/Endpoints/openAI');
const {
handleAbort,
createAbortController,

View file

@ -1,147 +1,19 @@
const express = require('express');
const router = express.Router();
const { getResponseSender } = require('../endpoints/schemas');
const { initializeClient } = require('../endpoints/anthropic');
const EditController = require('~/server/controllers/EditController');
const { initializeClient } = require('~/server/services/Endpoints/anthropic');
const {
handleAbort,
createAbortController,
handleAbortError,
setHeaders,
handleAbort,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { sendMessage, createOnProgress } = require('~/server/utils');
const router = express.Router();
router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
let {
text,
generation,
endpointOption,
conversationId,
responseMessageId,
isContinued = false,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('edit log');
console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let promptTokens;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
const userMessageId = parentMessageId;
const user = req.user.id;
const addMetadata = (data) => (metadata = data);
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
}
}
};
const { onProgress: progressCallback, getPartialText } = createOnProgress({
generation,
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
unfinished: true,
cancelled: false,
isEdited: true,
error: false,
user,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
try {
const getAbortData = () => ({
conversationId,
messageId: responseMessageId,
sender,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
const { client } = await initializeClient({ req, res, endpointOption });
let response = await client.sendMessage(text, {
user,
generation,
isContinued,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId ?? userMessageId,
}),
getReqData,
onStart,
addMetadata,
abortController,
});
if (metadata) {
response = { ...response, ...metadata };
}
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
await saveMessage({ ...response, user });
await saveMessage(userMessage);
// TODO: add anthropic titling
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender,
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
await EditController(req, res, next, initializeClient);
});
module.exports = router;

View file

@ -0,0 +1,19 @@
const express = require('express');
const EditController = require('~/server/controllers/EditController');
const { initializeClient } = require('~/server/services/Endpoints/google');
const {
setHeaders,
handleAbort,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const router = express.Router();
router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
await EditController(req, res, next, initializeClient);
});
module.exports = router;

View file

@ -1,10 +1,10 @@
const express = require('express');
const router = express.Router();
const { getResponseSender } = require('../endpoints/schemas');
const { validateTools } = require('../../../app');
const { initializeClient } = require('../endpoints/gptPlugins');
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('../../utils');
const { validateTools } = require('~/app');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { getResponseSender } = require('~/server/services/Endpoints');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
const {
handleAbort,
createAbortController,
@ -12,7 +12,7 @@ const {
setHeaders,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
} = require('~/server/middleware');
router.post('/abort', handleAbort());

View file

@ -1,20 +1,23 @@
const express = require('express');
const router = express.Router();
const openAI = require('./openAI');
const gptPlugins = require('./gptPlugins');
const google = require('./google');
const anthropic = require('./anthropic');
const gptPlugins = require('./gptPlugins');
const { isEnabled } = require('~/server/utils');
const { EModelEndpoint } = require('~/server/services/Endpoints');
const {
checkBan,
uaParser,
requireJwtAuth,
concurrentLimiter,
messageIpLimiter,
concurrentLimiter,
messageUserLimiter,
} = require('../../middleware');
const { isEnabled } = require('../../utils');
} = require('~/server/middleware');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
const router = express.Router();
router.use(requireJwtAuth);
router.use(checkBan);
router.use(uaParser);
@ -31,8 +34,9 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
router.use(messageUserLimiter);
}
router.use(['/azureOpenAI', '/openAI'], openAI);
router.use('/gptPlugins', gptPlugins);
router.use('/anthropic', anthropic);
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
router.use(`/${EModelEndpoint.anthropic}`, anthropic);
router.use(`/${EModelEndpoint.google}`, google);
module.exports = router;

View file

@ -1,9 +1,9 @@
const express = require('express');
const router = express.Router();
const { getResponseSender } = require('../endpoints/schemas');
const { initializeClient } = require('../endpoints/openAI');
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
const { sendMessage, createOnProgress } = require('../../utils');
const { getResponseSender } = require('~/server/services/Endpoints');
const { initializeClient } = require('~/server/services/Endpoints/openAI');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { sendMessage, createOnProgress } = require('~/server/utils');
const {
handleAbort,
createAbortController,
@ -11,7 +11,7 @@ const {
setHeaders,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
} = require('~/server/middleware');
router.post('/abort', handleAbort());

View file

@ -1,15 +0,0 @@
const buildOptions = (endpoint, parsedBody) => {
const { modelLabel, promptPrefix, ...rest } = parsedBody;
const endpointOption = {
endpoint,
modelLabel,
promptPrefix,
modelOptions: {
...rest,
},
};
return endpointOption;
};
module.exports = buildOptions;

View file

@ -1,8 +0,0 @@
const buildOptions = require('./buildOptions');
const initializeClient = require('./initializeClient');
module.exports = {
// addTitle, // todo
buildOptions,
initializeClient,
};

View file

@ -1,37 +0,0 @@
const { AnthropicClient } = require('~/app');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const initializeClient = async ({ req, res, endpointOption }) => {
const { ANTHROPIC_API_KEY, ANTHROPIC_REVERSE_PROXY } = process.env;
const expiresAt = req.body.key;
const isUserProvided = ANTHROPIC_API_KEY === 'user_provided';
const anthropicApiKey = isUserProvided
? await getAnthropicUserKey(req.user.id)
: ANTHROPIC_API_KEY;
if (expiresAt && isUserProvided) {
checkUserKeyExpiry(
expiresAt,
'Your ANTHROPIC_API_KEY has expired. Please provide your API key again.',
);
}
const client = new AnthropicClient(anthropicApiKey, {
req,
res,
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
...endpointOption,
});
return {
client,
anthropicApiKey,
};
};
const getAnthropicUserKey = async (userId) => {
return await getUserKey({ userId, name: 'anthropic' });
};
module.exports = initializeClient;

View file

@ -1,31 +0,0 @@
const buildOptions = (endpoint, parsedBody) => {
const {
chatGptLabel,
promptPrefix,
agentOptions,
tools,
model,
temperature,
top_p,
presence_penalty,
frequency_penalty,
} = parsedBody;
const endpointOption = {
endpoint,
tools: tools.map((tool) => tool.pluginKey) ?? [],
chatGptLabel,
promptPrefix,
agentOptions,
modelOptions: {
model,
temperature,
top_p,
presence_penalty,
frequency_penalty,
},
};
return endpointOption;
};
module.exports = buildOptions;

View file

@ -1,7 +0,0 @@
const buildOptions = require('./buildOptions');
const initializeClient = require('./initializeClient');
module.exports = {
buildOptions,
initializeClient,
};

View file

@ -1,65 +0,0 @@
const { PluginsClient } = require('../../../../app');
const { isEnabled } = require('../../../utils');
const { getAzureCredentials } = require('../../../../utils');
const { getUserKey, checkUserKeyExpiry } = require('../../../services/UserService');
const initializeClient = async ({ req, res, endpointOption }) => {
const {
PROXY,
OPENAI_API_KEY,
AZURE_API_KEY,
PLUGINS_USE_AZURE,
OPENAI_REVERSE_PROXY,
OPENAI_SUMMARIZE,
DEBUG_PLUGINS,
} = process.env;
const { key: expiresAt } = req.body;
const contextStrategy = isEnabled(OPENAI_SUMMARIZE) ? 'summarize' : null;
const clientOptions = {
contextStrategy,
debug: isEnabled(DEBUG_PLUGINS),
reverseProxyUrl: OPENAI_REVERSE_PROXY ?? null,
proxy: PROXY ?? null,
req,
res,
...endpointOption,
};
const useAzure = isEnabled(PLUGINS_USE_AZURE);
const isUserProvided = useAzure
? AZURE_API_KEY === 'user_provided'
: OPENAI_API_KEY === 'user_provided';
let userKey = null;
if (expiresAt && isUserProvided) {
checkUserKeyExpiry(
expiresAt,
'Your OpenAI API key has expired. Please provide your API key again.',
);
userKey = await getUserKey({
userId: req.user.id,
name: useAzure ? 'azureOpenAI' : 'openAI',
});
}
let apiKey = isUserProvided ? userKey : OPENAI_API_KEY;
if (useAzure || (apiKey && apiKey.includes('azure') && !clientOptions.azure)) {
clientOptions.azure = isUserProvided ? JSON.parse(userKey) : getAzureCredentials();
apiKey = clientOptions.azure.azureOpenAIApiKey;
}
if (!apiKey) {
throw new Error('API key not provided.');
}
const client = new PluginsClient(apiKey, clientOptions);
return {
client,
azure: clientOptions.azure,
openAIApiKey: apiKey,
};
};
module.exports = initializeClient;

View file

@ -1,218 +0,0 @@
// gptPlugins/initializeClient.spec.js
const initializeClient = require('./initializeClient');
const { PluginsClient } = require('../../../../app');
const { getUserKey } = require('../../../services/UserService');
// Mock getUserKey since it's the only function we want to mock
jest.mock('../../../services/UserService', () => ({
getUserKey: jest.fn(),
checkUserKeyExpiry: jest.requireActual('../../../services/UserService').checkUserKeyExpiry,
}));
describe('gptPlugins/initializeClient', () => {
// Set up environment variables
const originalEnvironment = process.env;
beforeEach(() => {
jest.resetModules(); // Clears the cache
process.env = { ...originalEnvironment }; // Make a copy
});
afterAll(() => {
process.env = originalEnvironment; // Restore original env vars
});
test('should initialize PluginsClient with OpenAI API key and default options', async () => {
process.env.OPENAI_API_KEY = 'test-openai-api-key';
process.env.PLUGINS_USE_AZURE = 'false';
process.env.DEBUG_PLUGINS = 'false';
process.env.OPENAI_SUMMARIZE = 'false';
const req = {
body: { key: null },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
const { client, openAIApiKey } = await initializeClient({ req, res, endpointOption });
expect(openAIApiKey).toBe('test-openai-api-key');
expect(client).toBeInstanceOf(PluginsClient);
});
test('should initialize PluginsClient with Azure credentials when PLUGINS_USE_AZURE is true', async () => {
process.env.AZURE_API_KEY = 'test-azure-api-key';
(process.env.AZURE_OPENAI_API_INSTANCE_NAME = 'some-value'),
(process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'some-value'),
(process.env.AZURE_OPENAI_API_VERSION = 'some-value'),
(process.env.AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME = 'some-value'),
(process.env.AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME = 'some-value'),
(process.env.PLUGINS_USE_AZURE = 'true');
process.env.DEBUG_PLUGINS = 'false';
process.env.OPENAI_SUMMARIZE = 'false';
const req = {
body: { key: null },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'test-model' } };
const { client, azure } = await initializeClient({ req, res, endpointOption });
expect(azure.azureOpenAIApiKey).toBe('test-azure-api-key');
expect(client).toBeInstanceOf(PluginsClient);
});
test('should use the debug option when DEBUG_PLUGINS is enabled', async () => {
process.env.OPENAI_API_KEY = 'test-openai-api-key';
process.env.DEBUG_PLUGINS = 'true';
const req = {
body: { key: null },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
const { client } = await initializeClient({ req, res, endpointOption });
expect(client.options.debug).toBe(true);
});
test('should set contextStrategy to summarize when OPENAI_SUMMARIZE is enabled', async () => {
process.env.OPENAI_API_KEY = 'test-openai-api-key';
process.env.OPENAI_SUMMARIZE = 'true';
const req = {
body: { key: null },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
const { client } = await initializeClient({ req, res, endpointOption });
expect(client.options.contextStrategy).toBe('summarize');
});
// ... additional tests for reverseProxyUrl, proxy, user-provided keys, etc.
test('should throw an error if no API keys are provided in the environment', async () => {
// Clear the environment variables for API keys
delete process.env.OPENAI_API_KEY;
delete process.env.AZURE_API_KEY;
const req = {
body: { key: null },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
'API key not provided.',
);
});
// Additional tests for gptPlugins/initializeClient.spec.js
// ... (previous test setup code)
test('should handle user-provided OpenAI keys and check expiry', async () => {
process.env.OPENAI_API_KEY = 'user_provided';
process.env.PLUGINS_USE_AZURE = 'false';
const futureDate = new Date(Date.now() + 10000).toISOString();
const req = {
body: { key: futureDate },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
getUserKey.mockResolvedValue('test-user-provided-openai-api-key');
const { openAIApiKey } = await initializeClient({ req, res, endpointOption });
expect(openAIApiKey).toBe('test-user-provided-openai-api-key');
});
test('should handle user-provided Azure keys and check expiry', async () => {
process.env.AZURE_API_KEY = 'user_provided';
process.env.PLUGINS_USE_AZURE = 'true';
const futureDate = new Date(Date.now() + 10000).toISOString();
const req = {
body: { key: futureDate },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'test-model' } };
getUserKey.mockResolvedValue(
JSON.stringify({
azureOpenAIApiKey: 'test-user-provided-azure-api-key',
azureOpenAIApiDeploymentName: 'test-deployment',
}),
);
const { azure } = await initializeClient({ req, res, endpointOption });
expect(azure.azureOpenAIApiKey).toBe('test-user-provided-azure-api-key');
});
test('should throw an error if the user-provided key has expired', async () => {
process.env.OPENAI_API_KEY = 'user_provided';
process.env.PLUGINS_USE_AZURE = 'FALSE';
const expiresAt = new Date(Date.now() - 10000).toISOString(); // Expired
const req = {
body: { key: expiresAt },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
/Your OpenAI API key has expired/,
);
});
test('should throw an error if the user-provided Azure key is invalid JSON', async () => {
process.env.AZURE_API_KEY = 'user_provided';
process.env.PLUGINS_USE_AZURE = 'true';
const req = {
body: { key: new Date(Date.now() + 10000).toISOString() },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
// Simulate an invalid JSON string returned from getUserKey
getUserKey.mockResolvedValue('invalid-json');
await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
/Unexpected token/,
);
});
test('should correctly handle the presence of a reverse proxy', async () => {
process.env.OPENAI_REVERSE_PROXY = 'http://reverse.proxy';
process.env.PROXY = 'http://proxy';
process.env.OPENAI_API_KEY = 'test-openai-api-key';
const req = {
body: { key: null },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
const { client } = await initializeClient({ req, res, endpointOption });
expect(client.options.reverseProxyUrl).toBe('http://reverse.proxy');
expect(client.options.proxy).toBe('http://proxy');
});
});

View file

@ -1,22 +0,0 @@
const { saveConvo } = require('~/models');
const { isEnabled } = require('~/server/utils');
const addTitle = async (req, { text, response, client }) => {
const { TITLE_CONVO = 'true' } = process.env ?? {};
if (!isEnabled(TITLE_CONVO)) {
return;
}
// If the request was aborted, don't generate the title.
if (client.abortController.signal.aborted) {
return;
}
const title = await client.titleConvo({ text, responseText: response?.text });
await saveConvo(req.user.id, {
conversationId: response.conversationId,
title,
});
};
module.exports = addTitle;

View file

@ -1,15 +0,0 @@
const buildOptions = (endpoint, parsedBody) => {
const { chatGptLabel, promptPrefix, ...rest } = parsedBody;
const endpointOption = {
endpoint,
chatGptLabel,
promptPrefix,
modelOptions: {
...rest,
},
};
return endpointOption;
};
module.exports = buildOptions;

View file

@ -1,9 +0,0 @@
const addTitle = require('./addTitle');
const buildOptions = require('./buildOptions');
const initializeClient = require('./initializeClient');
module.exports = {
addTitle,
buildOptions,
initializeClient,
};

View file

@ -1,61 +0,0 @@
const { OpenAIClient } = require('~/app');
const { isEnabled } = require('~/server/utils');
const { getAzureCredentials } = require('~/utils');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const initializeClient = async ({ req, res, endpointOption }) => {
const {
PROXY,
OPENAI_API_KEY,
AZURE_API_KEY,
OPENAI_REVERSE_PROXY,
OPENAI_SUMMARIZE,
DEBUG_OPENAI,
} = process.env;
const { key: expiresAt, endpoint } = req.body;
const contextStrategy = isEnabled(OPENAI_SUMMARIZE) ? 'summarize' : null;
const clientOptions = {
debug: isEnabled(DEBUG_OPENAI),
contextStrategy,
reverseProxyUrl: OPENAI_REVERSE_PROXY ?? null,
proxy: PROXY ?? null,
req,
res,
...endpointOption,
};
const credentials = {
openAI: OPENAI_API_KEY,
azureOpenAI: AZURE_API_KEY,
};
const isUserProvided = credentials[endpoint] === 'user_provided';
let userKey = null;
if (expiresAt && isUserProvided) {
checkUserKeyExpiry(
expiresAt,
'Your OpenAI API key has expired. Please provide your API key again.',
);
userKey = await getUserKey({ userId: req.user.id, name: endpoint });
}
let apiKey = isUserProvided ? userKey : credentials[endpoint];
if (endpoint === 'azureOpenAI') {
clientOptions.azure = isUserProvided ? JSON.parse(userKey) : getAzureCredentials();
apiKey = clientOptions.azure.azureOpenAIApiKey;
}
if (!apiKey) {
throw new Error('API key not provided.');
}
const client = new OpenAIClient(apiKey, clientOptions);
return {
client,
openAIApiKey: apiKey,
};
};
module.exports = initializeClient;

View file

@ -1,199 +0,0 @@
const initializeClient = require('./initializeClient');
const { OpenAIClient } = require('../../../../app');
const { getUserKey } = require('../../../services/UserService');
// Mock getUserKey since it's the only function we want to mock
jest.mock('../../../services/UserService', () => ({
getUserKey: jest.fn(),
checkUserKeyExpiry: jest.requireActual('../../../services/UserService').checkUserKeyExpiry,
}));
describe('initializeClient', () => {
// Set up environment variables
const originalEnvironment = process.env;
beforeEach(() => {
jest.resetModules(); // Clears the cache
process.env = { ...originalEnvironment }; // Make a copy
});
afterAll(() => {
process.env = originalEnvironment; // Restore original env vars
});
test('should initialize client with OpenAI API key and default options', async () => {
process.env.OPENAI_API_KEY = 'test-openai-api-key';
process.env.DEBUG_OPENAI = 'false';
process.env.OPENAI_SUMMARIZE = 'false';
const req = {
body: { key: null, endpoint: 'openAI' },
user: { id: '123' },
};
const res = {};
const endpointOption = {};
const client = await initializeClient({ req, res, endpointOption });
expect(client.openAIApiKey).toBe('test-openai-api-key');
expect(client.client).toBeInstanceOf(OpenAIClient);
});
test('should initialize client with Azure credentials when endpoint is azureOpenAI', async () => {
process.env.AZURE_API_KEY = 'test-azure-api-key';
(process.env.AZURE_OPENAI_API_INSTANCE_NAME = 'some-value'),
(process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'some-value'),
(process.env.AZURE_OPENAI_API_VERSION = 'some-value'),
(process.env.AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME = 'some-value'),
(process.env.AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME = 'some-value'),
(process.env.OPENAI_API_KEY = 'test-openai-api-key');
process.env.DEBUG_OPENAI = 'false';
process.env.OPENAI_SUMMARIZE = 'false';
const req = {
body: { key: null, endpoint: 'azureOpenAI' },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'test-model' } };
const client = await initializeClient({ req, res, endpointOption });
expect(client.openAIApiKey).toBe('test-azure-api-key');
expect(client.client).toBeInstanceOf(OpenAIClient);
});
test('should use the debug option when DEBUG_OPENAI is enabled', async () => {
process.env.OPENAI_API_KEY = 'test-openai-api-key';
process.env.DEBUG_OPENAI = 'true';
const req = {
body: { key: null, endpoint: 'openAI' },
user: { id: '123' },
};
const res = {};
const endpointOption = {};
const client = await initializeClient({ req, res, endpointOption });
expect(client.client.options.debug).toBe(true);
});
test('should set contextStrategy to summarize when OPENAI_SUMMARIZE is enabled', async () => {
process.env.OPENAI_API_KEY = 'test-openai-api-key';
process.env.OPENAI_SUMMARIZE = 'true';
const req = {
body: { key: null, endpoint: 'openAI' },
user: { id: '123' },
};
const res = {};
const endpointOption = {};
const client = await initializeClient({ req, res, endpointOption });
expect(client.client.options.contextStrategy).toBe('summarize');
});
test('should set reverseProxyUrl and proxy when they are provided in the environment', async () => {
process.env.OPENAI_API_KEY = 'test-openai-api-key';
process.env.OPENAI_REVERSE_PROXY = 'http://reverse.proxy';
process.env.PROXY = 'http://proxy';
const req = {
body: { key: null, endpoint: 'openAI' },
user: { id: '123' },
};
const res = {};
const endpointOption = {};
const client = await initializeClient({ req, res, endpointOption });
expect(client.client.options.reverseProxyUrl).toBe('http://reverse.proxy');
expect(client.client.options.proxy).toBe('http://proxy');
});
test('should throw an error if the user-provided key has expired', async () => {
process.env.OPENAI_API_KEY = 'user_provided';
process.env.AZURE_API_KEY = 'user_provided';
process.env.DEBUG_OPENAI = 'false';
process.env.OPENAI_SUMMARIZE = 'false';
const expiresAt = new Date(Date.now() - 10000).toISOString(); // Expired
const req = {
body: { key: expiresAt, endpoint: 'openAI' },
user: { id: '123' },
};
const res = {};
const endpointOption = {};
await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
'Your OpenAI API key has expired. Please provide your API key again.',
);
});
test('should throw an error if no API keys are provided in the environment', async () => {
// Clear the environment variables for API keys
delete process.env.OPENAI_API_KEY;
delete process.env.AZURE_API_KEY;
const req = {
body: { key: null, endpoint: 'openAI' },
user: { id: '123' },
};
const res = {};
const endpointOption = {};
await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
'API key not provided.',
);
});
it('should handle user-provided keys and check expiry', async () => {
// Set up the req.body to simulate user-provided key scenario
const req = {
body: {
key: new Date(Date.now() + 10000).toISOString(),
endpoint: 'openAI',
},
user: {
id: '123',
},
};
const res = {};
const endpointOption = {};
// Ensure the environment variable is set to 'user_provided' to match the isUserProvided condition
process.env.OPENAI_API_KEY = 'user_provided';
// Mock getUserKey to return the expected key
getUserKey.mockResolvedValue('test-user-provided-openai-api-key');
// Call the initializeClient function
const result = await initializeClient({ req, res, endpointOption });
// Assertions
expect(result.openAIApiKey).toBe('test-user-provided-openai-api-key');
});
test('should throw an error if the user-provided key is invalid', async () => {
const invalidKey = new Date(Date.now() - 100000).toISOString();
const req = {
body: { key: invalidKey, endpoint: 'openAI' },
user: { id: '123' },
};
const res = {};
const endpointOption = {};
// Ensure the environment variable is set to 'user_provided' to match the isUserProvided condition
process.env.OPENAI_API_KEY = 'user_provided';
// Mock getUserKey to return an invalid key
getUserKey.mockResolvedValue(invalidKey);
await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
/Your OpenAI API key has expired/,
);
});
});

View file

@ -1,402 +0,0 @@
const { z } = require('zod');
const EModelEndpoint = {
azureOpenAI: 'azureOpenAI',
openAI: 'openAI',
bingAI: 'bingAI',
chatGPTBrowser: 'chatGPTBrowser',
google: 'google',
gptPlugins: 'gptPlugins',
anthropic: 'anthropic',
assistant: 'assistant',
};
const alternateName = {
[EModelEndpoint.openAI]: 'OpenAI',
[EModelEndpoint.assistant]: 'Assistants',
[EModelEndpoint.azureOpenAI]: 'Azure OpenAI',
[EModelEndpoint.bingAI]: 'Bing',
[EModelEndpoint.chatGPTBrowser]: 'ChatGPT',
[EModelEndpoint.gptPlugins]: 'Plugins',
[EModelEndpoint.google]: 'PaLM',
[EModelEndpoint.anthropic]: 'Anthropic',
};
const supportsFiles = {
[EModelEndpoint.openAI]: true,
[EModelEndpoint.assistant]: true,
};
const openAIModels = [
'gpt-3.5-turbo-16k-0613',
'gpt-3.5-turbo-16k',
'gpt-4-1106-preview',
'gpt-3.5-turbo',
'gpt-3.5-turbo-1106',
'gpt-4-vision-preview',
'gpt-4',
'gpt-3.5-turbo-instruct-0914',
'gpt-3.5-turbo-0613',
'gpt-3.5-turbo-0301',
'gpt-3.5-turbo-instruct',
'gpt-4-0613',
'text-davinci-003',
'gpt-4-0314',
];
const visionModels = ['gpt-4-vision', 'llava-13b'];
const eModelEndpointSchema = z.nativeEnum(EModelEndpoint);
const tPluginAuthConfigSchema = z.object({
authField: z.string(),
label: z.string(),
description: z.string(),
});
const tPluginSchema = z.object({
name: z.string(),
pluginKey: z.string(),
description: z.string(),
icon: z.string(),
authConfig: z.array(tPluginAuthConfigSchema),
authenticated: z.boolean().optional(),
isButton: z.boolean().optional(),
});
const tExampleSchema = z.object({
input: z.object({
content: z.string(),
}),
output: z.object({
content: z.string(),
}),
});
const tAgentOptionsSchema = z.object({
agent: z.string(),
skipCompletion: z.boolean(),
model: z.string(),
temperature: z.number(),
});
const tConversationSchema = z.object({
conversationId: z.string().nullable(),
title: z.string().nullable().or(z.literal('New Chat')).default('New Chat'),
user: z.string().optional(),
endpoint: eModelEndpointSchema.nullable(),
suggestions: z.array(z.string()).optional(),
messages: z.array(z.string()).optional(),
tools: z.array(tPluginSchema).optional(),
createdAt: z.string(),
updatedAt: z.string(),
systemMessage: z.string().nullable().optional(),
modelLabel: z.string().nullable().optional(),
examples: z.array(tExampleSchema).optional(),
chatGptLabel: z.string().nullable().optional(),
userLabel: z.string().optional(),
model: z.string().nullable().optional(),
promptPrefix: z.string().nullable().optional(),
temperature: z.number().optional(),
topP: z.number().optional(),
topK: z.number().optional(),
context: z.string().nullable().optional(),
top_p: z.number().optional(),
frequency_penalty: z.number().optional(),
presence_penalty: z.number().optional(),
jailbreak: z.boolean().optional(),
jailbreakConversationId: z.string().nullable().optional(),
conversationSignature: z.string().nullable().optional(),
parentMessageId: z.string().optional(),
clientId: z.string().nullable().optional(),
invocationId: z.number().nullable().optional(),
toneStyle: z.string().nullable().optional(),
maxOutputTokens: z.number().optional(),
agentOptions: tAgentOptionsSchema.nullable().optional(),
});
const openAISchema = tConversationSchema
.pick({
model: true,
chatGptLabel: true,
promptPrefix: true,
temperature: true,
top_p: true,
presence_penalty: true,
frequency_penalty: true,
})
.transform((obj) => ({
...obj,
model: obj.model ?? 'gpt-3.5-turbo',
chatGptLabel: obj.chatGptLabel ?? null,
promptPrefix: obj.promptPrefix ?? null,
temperature: obj.temperature ?? 1,
top_p: obj.top_p ?? 1,
presence_penalty: obj.presence_penalty ?? 0,
frequency_penalty: obj.frequency_penalty ?? 0,
}))
.catch(() => ({
model: 'gpt-3.5-turbo',
chatGptLabel: null,
promptPrefix: null,
temperature: 1,
top_p: 1,
presence_penalty: 0,
frequency_penalty: 0,
}));
const googleSchema = tConversationSchema
.pick({
model: true,
modelLabel: true,
promptPrefix: true,
examples: true,
temperature: true,
maxOutputTokens: true,
topP: true,
topK: true,
})
.transform((obj) => ({
...obj,
model: obj.model ?? 'chat-bison',
modelLabel: obj.modelLabel ?? null,
promptPrefix: obj.promptPrefix ?? null,
temperature: obj.temperature ?? 0.2,
maxOutputTokens: obj.maxOutputTokens ?? 1024,
topP: obj.topP ?? 0.95,
topK: obj.topK ?? 40,
}))
.catch(() => ({
model: 'chat-bison',
modelLabel: null,
promptPrefix: null,
temperature: 0.2,
maxOutputTokens: 1024,
topP: 0.95,
topK: 40,
}));
const bingAISchema = tConversationSchema
.pick({
jailbreak: true,
systemMessage: true,
context: true,
toneStyle: true,
jailbreakConversationId: true,
conversationSignature: true,
clientId: true,
invocationId: true,
})
.transform((obj) => ({
...obj,
model: '',
jailbreak: obj.jailbreak ?? false,
systemMessage: obj.systemMessage ?? null,
context: obj.context ?? null,
toneStyle: obj.toneStyle ?? 'creative',
jailbreakConversationId: obj.jailbreakConversationId ?? null,
conversationSignature: obj.conversationSignature ?? null,
clientId: obj.clientId ?? null,
invocationId: obj.invocationId ?? 1,
}))
.catch(() => ({
model: '',
jailbreak: false,
systemMessage: null,
context: null,
toneStyle: 'creative',
jailbreakConversationId: null,
conversationSignature: null,
clientId: null,
invocationId: 1,
}));
const anthropicSchema = tConversationSchema
.pick({
model: true,
modelLabel: true,
promptPrefix: true,
temperature: true,
maxOutputTokens: true,
topP: true,
topK: true,
})
.transform((obj) => ({
...obj,
model: obj.model ?? 'claude-1',
modelLabel: obj.modelLabel ?? null,
promptPrefix: obj.promptPrefix ?? null,
temperature: obj.temperature ?? 1,
maxOutputTokens: obj.maxOutputTokens ?? 4000,
topP: obj.topP ?? 0.7,
topK: obj.topK ?? 5,
}))
.catch(() => ({
model: 'claude-1',
modelLabel: null,
promptPrefix: null,
temperature: 1,
maxOutputTokens: 4000,
topP: 0.7,
topK: 5,
}));
const chatGPTBrowserSchema = tConversationSchema
.pick({
model: true,
})
.transform((obj) => ({
...obj,
model: obj.model ?? 'text-davinci-002-render-sha',
}))
.catch(() => ({
model: 'text-davinci-002-render-sha',
}));
const gptPluginsSchema = tConversationSchema
.pick({
model: true,
chatGptLabel: true,
promptPrefix: true,
temperature: true,
top_p: true,
presence_penalty: true,
frequency_penalty: true,
tools: true,
agentOptions: true,
})
.transform((obj) => ({
...obj,
model: obj.model ?? 'gpt-3.5-turbo',
chatGptLabel: obj.chatGptLabel ?? null,
promptPrefix: obj.promptPrefix ?? null,
temperature: obj.temperature ?? 0.8,
top_p: obj.top_p ?? 1,
presence_penalty: obj.presence_penalty ?? 0,
frequency_penalty: obj.frequency_penalty ?? 0,
tools: obj.tools ?? [],
agentOptions: obj.agentOptions ?? {
agent: 'functions',
skipCompletion: true,
model: 'gpt-3.5-turbo',
temperature: 0,
},
}))
.catch(() => ({
model: 'gpt-3.5-turbo',
chatGptLabel: null,
promptPrefix: null,
temperature: 0.8,
top_p: 1,
presence_penalty: 0,
frequency_penalty: 0,
tools: [],
agentOptions: {
agent: 'functions',
skipCompletion: true,
model: 'gpt-3.5-turbo',
temperature: 0,
},
}));
const assistantSchema = tConversationSchema
.pick({
model: true,
assistant_id: true,
thread_id: true,
})
.transform((obj) => {
const newObj = { ...obj };
Object.keys(newObj).forEach((key) => {
const value = newObj[key];
if (value === undefined || value === null) {
delete newObj[key];
}
});
return newObj;
})
.catch(() => ({}));
const endpointSchemas = {
[EModelEndpoint.openAI]: openAISchema,
[EModelEndpoint.assistant]: assistantSchema,
[EModelEndpoint.azureOpenAI]: openAISchema,
[EModelEndpoint.google]: googleSchema,
[EModelEndpoint.bingAI]: bingAISchema,
[EModelEndpoint.anthropic]: anthropicSchema,
[EModelEndpoint.chatGPTBrowser]: chatGPTBrowserSchema,
[EModelEndpoint.gptPlugins]: gptPluginsSchema,
};
function getFirstDefinedValue(possibleValues) {
let returnValue;
for (const value of possibleValues) {
if (value) {
returnValue = value;
break;
}
}
return returnValue;
}
const parseConvo = (endpoint, conversation, possibleValues) => {
const schema = endpointSchemas[endpoint];
if (!schema) {
throw new Error(`Unknown endpoint: ${endpoint}`);
}
const convo = schema.parse(conversation);
if (possibleValues && convo) {
convo.model = getFirstDefinedValue(possibleValues.model) ?? convo.model;
}
return convo;
};
const getResponseSender = (endpointOption) => {
const { model, endpoint, chatGptLabel, modelLabel, jailbreak } = endpointOption;
if (
[
EModelEndpoint.openAI,
EModelEndpoint.azureOpenAI,
EModelEndpoint.gptPlugins,
EModelEndpoint.chatGPTBrowser,
].includes(endpoint)
) {
if (chatGptLabel) {
return chatGptLabel;
} else if (model && model.includes('gpt-3')) {
return 'GPT-3.5';
} else if (model && model.includes('gpt-4')) {
return 'GPT-4';
}
return alternateName[endpoint] ?? 'ChatGPT';
}
if (endpoint === EModelEndpoint.bingAI) {
return jailbreak ? 'Sydney' : 'BingAI';
}
if (endpoint === EModelEndpoint.anthropic) {
return modelLabel ?? 'Claude';
}
if (endpoint === EModelEndpoint.google) {
return modelLabel ?? 'PaLM2';
}
return '';
};
module.exports = {
parseConvo,
getResponseSender,
EModelEndpoint,
supportsFiles,
openAIModels,
visionModels,
alternateName,
};