mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-17 17:00:15 +01:00
feat(Google): Support all Text/Chat Models, Response streaming, PaLM -> Google 🤖 (#1316)
* feat: update PaLM icons * feat: add additional google models * POC: formatting inputs for Vertex AI streaming * refactor: move endpoints services outside of /routes dir to /services/Endpoints * refactor: shorten schemas import * refactor: rename PALM to GOOGLE * feat: make Google editable endpoint * feat: reusable Ask and Edit controllers based off Anthropic * chore: organize imports/logic * fix(parseConvo): include examples in googleSchema * fix: google only allows odd number of messages to be sent * fix: pass proxy to AnthropicClient * refactor: change `google` altName to `Google` * refactor: update getModelMaxTokens and related functions to handle maxTokensMap with nested endpoint model key/values * refactor: google Icon and response sender changes (Codey and Google logo instead of PaLM in all cases) * feat: google support for maxTokensMap * feat: google updated endpoints with Ask/Edit controllers, buildOptions, and initializeClient * feat(GoogleClient): now builds prompt for text models and supports real streaming from Vertex AI through langchain * chore(GoogleClient): remove comments, left before for reference in git history * docs: update google instructions (WIP) * docs(apis_and_tokens.md): add images to google instructions * docs: remove typo apis_and_tokens.md * Update apis_and_tokens.md * feat(Google): use default settings map, fully support context for both text and chat models, fully support examples for chat models * chore: update more PaLM references to Google * chore: move playwright out of workflows to avoid failing tests
This commit is contained in:
parent
8a1968b2f8
commit
583e978a82
90 changed files with 1613 additions and 784 deletions
132
api/server/controllers/AskController.js
Normal file
132
api/server/controllers/AskController.js
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
|
||||
const { getResponseSender } = require('~/server/services/Endpoints');
|
||||
const { createAbortController, handleAbortError } = require('~/server/middleware');
|
||||
|
||||
const AskController = async (req, res, next, initializeClient) => {
|
||||
let {
|
||||
text,
|
||||
endpointOption,
|
||||
conversationId,
|
||||
parentMessageId = null,
|
||||
overrideParentMessageId = null,
|
||||
} = req.body;
|
||||
console.log('ask log');
|
||||
console.dir({ text, conversationId, endpointOption }, { depth: null });
|
||||
let metadata;
|
||||
let userMessage;
|
||||
let promptTokens;
|
||||
let userMessageId;
|
||||
let responseMessageId;
|
||||
let lastSavedTimestamp = 0;
|
||||
let saveDelay = 100;
|
||||
const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
|
||||
const user = req.user.id;
|
||||
|
||||
const getReqData = (data = {}) => {
|
||||
for (let key in data) {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
promptTokens = data[key];
|
||||
} else if (!conversationId && key === 'conversationId') {
|
||||
conversationId = data[key];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const { onProgress: progressCallback, getPartialText } = createOnProgress({
|
||||
onProgress: ({ text: partialText }) => {
|
||||
const currentTimestamp = Date.now();
|
||||
|
||||
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
|
||||
lastSavedTimestamp = currentTimestamp;
|
||||
saveMessage({
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
conversationId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: partialText,
|
||||
unfinished: true,
|
||||
cancelled: false,
|
||||
error: false,
|
||||
user,
|
||||
});
|
||||
}
|
||||
|
||||
if (saveDelay < 500) {
|
||||
saveDelay = 500;
|
||||
}
|
||||
},
|
||||
});
|
||||
try {
|
||||
const addMetadata = (data) => (metadata = data);
|
||||
const getAbortData = () => ({
|
||||
conversationId,
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: getPartialText(),
|
||||
userMessage,
|
||||
promptTokens,
|
||||
});
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
// debug: true,
|
||||
user,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
overrideParentMessageId,
|
||||
...endpointOption,
|
||||
onProgress: progressCallback.call(null, {
|
||||
res,
|
||||
text,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
}),
|
||||
onStart,
|
||||
getReqData,
|
||||
addMetadata,
|
||||
abortController,
|
||||
});
|
||||
|
||||
if (metadata) {
|
||||
response = { ...response, ...metadata };
|
||||
}
|
||||
|
||||
if (overrideParentMessageId) {
|
||||
response.parentMessageId = overrideParentMessageId;
|
||||
}
|
||||
|
||||
sendMessage(res, {
|
||||
title: await getConvoTitle(user, conversationId),
|
||||
final: true,
|
||||
conversation: await getConvo(user, conversationId),
|
||||
requestMessage: userMessage,
|
||||
responseMessage: response,
|
||||
});
|
||||
res.end();
|
||||
|
||||
await saveMessage({ ...response, user });
|
||||
await saveMessage(userMessage);
|
||||
|
||||
// TODO: add title service
|
||||
} catch (error) {
|
||||
const partialText = getPartialText();
|
||||
handleAbortError(res, req, error, {
|
||||
partialText,
|
||||
conversationId,
|
||||
sender,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: userMessageId ?? parentMessageId,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = AskController;
|
||||
135
api/server/controllers/EditController.js
Normal file
135
api/server/controllers/EditController.js
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
|
||||
const { getResponseSender } = require('~/server/services/Endpoints');
|
||||
const { createAbortController, handleAbortError } = require('~/server/middleware');
|
||||
|
||||
const EditController = async (req, res, next, initializeClient) => {
|
||||
let {
|
||||
text,
|
||||
generation,
|
||||
endpointOption,
|
||||
conversationId,
|
||||
responseMessageId,
|
||||
isContinued = false,
|
||||
parentMessageId = null,
|
||||
overrideParentMessageId = null,
|
||||
} = req.body;
|
||||
console.log('edit log');
|
||||
console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null });
|
||||
let metadata;
|
||||
let userMessage;
|
||||
let promptTokens;
|
||||
let lastSavedTimestamp = 0;
|
||||
let saveDelay = 100;
|
||||
const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
|
||||
const userMessageId = parentMessageId;
|
||||
const user = req.user.id;
|
||||
|
||||
const addMetadata = (data) => (metadata = data);
|
||||
const getReqData = (data = {}) => {
|
||||
for (let key in data) {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
promptTokens = data[key];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const { onProgress: progressCallback, getPartialText } = createOnProgress({
|
||||
generation,
|
||||
onProgress: ({ text: partialText }) => {
|
||||
const currentTimestamp = Date.now();
|
||||
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
|
||||
lastSavedTimestamp = currentTimestamp;
|
||||
saveMessage({
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
conversationId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: partialText,
|
||||
unfinished: true,
|
||||
cancelled: false,
|
||||
isEdited: true,
|
||||
error: false,
|
||||
user,
|
||||
});
|
||||
}
|
||||
|
||||
if (saveDelay < 500) {
|
||||
saveDelay = 500;
|
||||
}
|
||||
},
|
||||
});
|
||||
try {
|
||||
const getAbortData = () => ({
|
||||
conversationId,
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: getPartialText(),
|
||||
userMessage,
|
||||
promptTokens,
|
||||
});
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
user,
|
||||
generation,
|
||||
isContinued,
|
||||
isEdited: true,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
responseMessageId,
|
||||
overrideParentMessageId,
|
||||
...endpointOption,
|
||||
onProgress: progressCallback.call(null, {
|
||||
res,
|
||||
text,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
}),
|
||||
getReqData,
|
||||
onStart,
|
||||
addMetadata,
|
||||
abortController,
|
||||
});
|
||||
|
||||
if (metadata) {
|
||||
response = { ...response, ...metadata };
|
||||
}
|
||||
|
||||
if (overrideParentMessageId) {
|
||||
response.parentMessageId = overrideParentMessageId;
|
||||
}
|
||||
|
||||
sendMessage(res, {
|
||||
title: await getConvoTitle(user, conversationId),
|
||||
final: true,
|
||||
conversation: await getConvo(user, conversationId),
|
||||
requestMessage: userMessage,
|
||||
responseMessage: response,
|
||||
});
|
||||
res.end();
|
||||
|
||||
await saveMessage({ ...response, user });
|
||||
await saveMessage(userMessage);
|
||||
|
||||
// TODO: add title service
|
||||
} catch (error) {
|
||||
const partialText = getPartialText();
|
||||
handleAbortError(res, req, error, {
|
||||
partialText,
|
||||
conversationId,
|
||||
sender,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: userMessageId ?? parentMessageId,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = EditController;
|
||||
|
|
@ -1,14 +1,16 @@
|
|||
const openAI = require('~/server/routes/endpoints/openAI');
|
||||
const gptPlugins = require('~/server/routes/endpoints/gptPlugins');
|
||||
const anthropic = require('~/server/routes/endpoints/anthropic');
|
||||
const { parseConvo, EModelEndpoint } = require('~/server/routes/endpoints/schemas');
|
||||
const { processFiles } = require('~/server/services/Files');
|
||||
const openAI = require('~/server/services/Endpoints/openAI');
|
||||
const google = require('~/server/services/Endpoints/google');
|
||||
const anthropic = require('~/server/services/Endpoints/anthropic');
|
||||
const gptPlugins = require('~/server/services/Endpoints/gptPlugins');
|
||||
const { parseConvo, EModelEndpoint } = require('~/server/services/Endpoints');
|
||||
|
||||
const buildFunction = {
|
||||
[EModelEndpoint.openAI]: openAI.buildOptions,
|
||||
[EModelEndpoint.google]: google.buildOptions,
|
||||
[EModelEndpoint.azureOpenAI]: openAI.buildOptions,
|
||||
[EModelEndpoint.gptPlugins]: gptPlugins.buildOptions,
|
||||
[EModelEndpoint.anthropic]: anthropic.buildOptions,
|
||||
[EModelEndpoint.gptPlugins]: gptPlugins.buildOptions,
|
||||
};
|
||||
|
||||
function buildEndpointOption(req, res, next) {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
const crypto = require('crypto');
|
||||
const { sendMessage, sendError } = require('../utils');
|
||||
const { getResponseSender } = require('../routes/endpoints/schemas');
|
||||
const { saveMessage } = require('../../models');
|
||||
const { saveMessage } = require('~/models');
|
||||
const { sendMessage, sendError } = require('~/server/utils');
|
||||
const { getResponseSender } = require('~/server/services/Endpoints');
|
||||
|
||||
/**
|
||||
* Denies a request by sending an error message and optionally saves the user's message.
|
||||
|
|
|
|||
|
|
@ -1,137 +1,19 @@
|
|||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const { getResponseSender } = require('../endpoints/schemas');
|
||||
const { initializeClient } = require('../endpoints/anthropic');
|
||||
const AskController = require('~/server/controllers/AskController');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/anthropic');
|
||||
const {
|
||||
handleAbort,
|
||||
createAbortController,
|
||||
handleAbortError,
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
|
||||
let {
|
||||
text,
|
||||
endpointOption,
|
||||
conversationId,
|
||||
parentMessageId = null,
|
||||
overrideParentMessageId = null,
|
||||
} = req.body;
|
||||
console.log('ask log');
|
||||
console.dir({ text, conversationId, endpointOption }, { depth: null });
|
||||
let userMessage;
|
||||
let promptTokens;
|
||||
let userMessageId;
|
||||
let responseMessageId;
|
||||
let lastSavedTimestamp = 0;
|
||||
let saveDelay = 100;
|
||||
const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
|
||||
const user = req.user.id;
|
||||
|
||||
const getReqData = (data = {}) => {
|
||||
for (let key in data) {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
promptTokens = data[key];
|
||||
} else if (!conversationId && key === 'conversationId') {
|
||||
conversationId = data[key];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const { onProgress: progressCallback, getPartialText } = createOnProgress({
|
||||
onProgress: ({ text: partialText }) => {
|
||||
const currentTimestamp = Date.now();
|
||||
|
||||
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
|
||||
lastSavedTimestamp = currentTimestamp;
|
||||
saveMessage({
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
conversationId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: partialText,
|
||||
unfinished: true,
|
||||
cancelled: false,
|
||||
error: false,
|
||||
user,
|
||||
});
|
||||
}
|
||||
|
||||
if (saveDelay < 500) {
|
||||
saveDelay = 500;
|
||||
}
|
||||
},
|
||||
});
|
||||
try {
|
||||
const getAbortData = () => ({
|
||||
conversationId,
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: getPartialText(),
|
||||
userMessage,
|
||||
promptTokens,
|
||||
});
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
getReqData,
|
||||
// debug: true,
|
||||
user,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
overrideParentMessageId,
|
||||
...endpointOption,
|
||||
onProgress: progressCallback.call(null, {
|
||||
res,
|
||||
text,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
}),
|
||||
onStart,
|
||||
abortController,
|
||||
});
|
||||
|
||||
if (overrideParentMessageId) {
|
||||
response.parentMessageId = overrideParentMessageId;
|
||||
}
|
||||
|
||||
sendMessage(res, {
|
||||
title: await getConvoTitle(user, conversationId),
|
||||
final: true,
|
||||
conversation: await getConvo(user, conversationId),
|
||||
requestMessage: userMessage,
|
||||
responseMessage: response,
|
||||
});
|
||||
res.end();
|
||||
|
||||
await saveMessage({ ...response, user });
|
||||
await saveMessage(userMessage);
|
||||
|
||||
// TODO: add anthropic titling
|
||||
} catch (error) {
|
||||
const partialText = getPartialText();
|
||||
handleAbortError(res, req, error, {
|
||||
partialText,
|
||||
conversationId,
|
||||
sender,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: userMessageId ?? parentMessageId,
|
||||
});
|
||||
}
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
|
||||
await AskController(req, res, next, initializeClient);
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -1,181 +1,19 @@
|
|||
const express = require('express');
|
||||
const AskController = require('~/server/controllers/AskController');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/google');
|
||||
const {
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
|
||||
const router = express.Router();
|
||||
const crypto = require('crypto');
|
||||
const { GoogleClient } = require('../../../app');
|
||||
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
|
||||
const { handleError, sendMessage, createOnProgress } = require('../../utils');
|
||||
const { getUserKey, checkUserKeyExpiry } = require('../../services/UserService');
|
||||
const { setHeaders } = require('../../middleware');
|
||||
|
||||
router.post('/', setHeaders, async (req, res) => {
|
||||
const { endpoint, text, parentMessageId, conversationId: oldConversationId } = req.body;
|
||||
if (text.length === 0) {
|
||||
return handleError(res, { text: 'Prompt empty or too short' });
|
||||
}
|
||||
if (endpoint !== 'google') {
|
||||
return handleError(res, { text: 'Illegal request' });
|
||||
}
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
// build endpoint option
|
||||
const endpointOption = {
|
||||
examples: req.body?.examples ?? [{ input: { content: '' }, output: { content: '' } }],
|
||||
promptPrefix: req.body?.promptPrefix ?? null,
|
||||
key: req.body?.key ?? null,
|
||||
modelOptions: {
|
||||
model: req.body?.model ?? 'chat-bison',
|
||||
modelLabel: req.body?.modelLabel ?? null,
|
||||
temperature: req.body?.temperature ?? 0.2,
|
||||
maxOutputTokens: req.body?.maxOutputTokens ?? 1024,
|
||||
topP: req.body?.topP ?? 0.95,
|
||||
topK: req.body?.topK ?? 40,
|
||||
},
|
||||
};
|
||||
|
||||
const availableModels = ['chat-bison', 'text-bison', 'codechat-bison'];
|
||||
if (availableModels.find((model) => model === endpointOption.modelOptions.model) === undefined) {
|
||||
return handleError(res, { text: 'Illegal request: model' });
|
||||
}
|
||||
|
||||
const conversationId = oldConversationId || crypto.randomUUID();
|
||||
|
||||
// eslint-disable-next-line no-use-before-define
|
||||
return await ask({
|
||||
text,
|
||||
endpointOption,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
req,
|
||||
res,
|
||||
});
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
|
||||
await AskController(req, res, next, initializeClient);
|
||||
});
|
||||
|
||||
const ask = async ({ text, endpointOption, parentMessageId = null, conversationId, req, res }) => {
|
||||
let userMessage;
|
||||
let userMessageId;
|
||||
// let promptTokens;
|
||||
let responseMessageId;
|
||||
let lastSavedTimestamp = 0;
|
||||
const { overrideParentMessageId = null } = req.body;
|
||||
const user = req.user.id;
|
||||
|
||||
try {
|
||||
const getReqData = (data = {}) => {
|
||||
for (let key in data) {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
// } else if (key === 'promptTokens') {
|
||||
// promptTokens = data[key];
|
||||
} else if (!conversationId && key === 'conversationId') {
|
||||
conversationId = data[key];
|
||||
}
|
||||
}
|
||||
|
||||
sendMessage(res, { message: userMessage, created: true });
|
||||
};
|
||||
|
||||
const { onProgress: progressCallback } = createOnProgress({
|
||||
onProgress: ({ text: partialText }) => {
|
||||
const currentTimestamp = Date.now();
|
||||
if (currentTimestamp - lastSavedTimestamp > 500) {
|
||||
lastSavedTimestamp = currentTimestamp;
|
||||
saveMessage({
|
||||
messageId: responseMessageId,
|
||||
sender: 'PaLM2',
|
||||
conversationId,
|
||||
parentMessageId: overrideParentMessageId || userMessageId,
|
||||
text: partialText,
|
||||
unfinished: true,
|
||||
cancelled: false,
|
||||
error: false,
|
||||
user,
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
const abortController = new AbortController();
|
||||
|
||||
const isUserProvided = process.env.PALM_KEY === 'user_provided';
|
||||
|
||||
let key;
|
||||
if (endpointOption.key && isUserProvided) {
|
||||
checkUserKeyExpiry(
|
||||
endpointOption.key,
|
||||
'Your GOOGLE_TOKEN has expired. Please provide your token again.',
|
||||
);
|
||||
key = await getUserKey({ userId: user, name: 'google' });
|
||||
key = JSON.parse(key);
|
||||
delete endpointOption.key;
|
||||
console.log('Using service account key provided by User for PaLM models');
|
||||
}
|
||||
|
||||
try {
|
||||
key = require('../../../data/auth.json');
|
||||
} catch (e) {
|
||||
console.log('No \'auth.json\' file (service account key) found in /api/data/ for PaLM models');
|
||||
}
|
||||
|
||||
const clientOptions = {
|
||||
// debug: true, // for testing
|
||||
reverseProxyUrl: process.env.GOOGLE_REVERSE_PROXY || null,
|
||||
proxy: process.env.PROXY || null,
|
||||
...endpointOption,
|
||||
};
|
||||
|
||||
const client = new GoogleClient(key, clientOptions);
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
getReqData,
|
||||
user,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
overrideParentMessageId,
|
||||
onProgress: progressCallback.call(null, {
|
||||
res,
|
||||
text,
|
||||
parentMessageId: overrideParentMessageId || userMessageId,
|
||||
}),
|
||||
abortController,
|
||||
});
|
||||
|
||||
if (overrideParentMessageId) {
|
||||
response.parentMessageId = overrideParentMessageId;
|
||||
}
|
||||
|
||||
await saveConvo(user, {
|
||||
...endpointOption,
|
||||
...endpointOption.modelOptions,
|
||||
conversationId,
|
||||
endpoint: 'google',
|
||||
});
|
||||
|
||||
await saveMessage({ ...response, user });
|
||||
sendMessage(res, {
|
||||
title: await getConvoTitle(user, conversationId),
|
||||
final: true,
|
||||
conversation: await getConvo(user, conversationId),
|
||||
requestMessage: userMessage,
|
||||
responseMessage: response,
|
||||
});
|
||||
res.end();
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
const errorMessage = {
|
||||
messageId: responseMessageId,
|
||||
sender: 'PaLM2',
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
unfinished: false,
|
||||
cancelled: false,
|
||||
error: true,
|
||||
text: error.message,
|
||||
};
|
||||
await saveMessage({ ...errorMessage, user });
|
||||
handleError(res, errorMessage);
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const { getResponseSender } = require('../endpoints/schemas');
|
||||
const { validateTools } = require('../../../app');
|
||||
const { addTitle } = require('../endpoints/openAI');
|
||||
const { initializeClient } = require('../endpoints/gptPlugins');
|
||||
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
|
||||
const { sendMessage, createOnProgress } = require('../../utils');
|
||||
const { getResponseSender } = require('~/server/services/Endpoints');
|
||||
const { validateTools } = require('~/app');
|
||||
const { addTitle } = require('~/server/services/Endpoints/openAI');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
|
||||
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const {
|
||||
handleAbort,
|
||||
createAbortController,
|
||||
|
|
@ -13,7 +13,7 @@ const {
|
|||
setHeaders,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('../../middleware');
|
||||
} = require('~/server/middleware');
|
||||
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const openAI = require('./openAI');
|
||||
const google = require('./google');
|
||||
const bingAI = require('./bingAI');
|
||||
const anthropic = require('./anthropic');
|
||||
const gptPlugins = require('./gptPlugins');
|
||||
const askChatGPTBrowser = require('./askChatGPTBrowser');
|
||||
const anthropic = require('./anthropic');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { EModelEndpoint } = require('~/server/services/Endpoints');
|
||||
const {
|
||||
uaParser,
|
||||
checkBan,
|
||||
|
|
@ -13,12 +14,12 @@ const {
|
|||
concurrentLimiter,
|
||||
messageIpLimiter,
|
||||
messageUserLimiter,
|
||||
} = require('../../middleware');
|
||||
const { isEnabled } = require('../../utils');
|
||||
const { EModelEndpoint } = require('../endpoints/schemas');
|
||||
} = require('~/server/middleware');
|
||||
|
||||
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.use(requireJwtAuth);
|
||||
router.use(checkBan);
|
||||
router.use(uaParser);
|
||||
|
|
@ -36,10 +37,10 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
|
|||
}
|
||||
|
||||
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
|
||||
router.use(`/${EModelEndpoint.google}`, google);
|
||||
router.use(`/${EModelEndpoint.bingAI}`, bingAI);
|
||||
router.use(`/${EModelEndpoint.chatGPTBrowser}`, askChatGPTBrowser);
|
||||
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
|
||||
router.use(`/${EModelEndpoint.anthropic}`, anthropic);
|
||||
router.use(`/${EModelEndpoint.google}`, google);
|
||||
router.use(`/${EModelEndpoint.bingAI}`, bingAI);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ const express = require('express');
|
|||
const router = express.Router();
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
|
||||
const { getResponseSender } = require('~/server/routes/endpoints/schemas');
|
||||
const { addTitle, initializeClient } = require('~/server/routes/endpoints/openAI');
|
||||
const { getResponseSender } = require('~/server/services/Endpoints');
|
||||
const { addTitle, initializeClient } = require('~/server/services/Endpoints/openAI');
|
||||
const {
|
||||
handleAbort,
|
||||
createAbortController,
|
||||
|
|
|
|||
|
|
@ -1,147 +1,19 @@
|
|||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const { getResponseSender } = require('../endpoints/schemas');
|
||||
const { initializeClient } = require('../endpoints/anthropic');
|
||||
const EditController = require('~/server/controllers/EditController');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/anthropic');
|
||||
const {
|
||||
handleAbort,
|
||||
createAbortController,
|
||||
handleAbortError,
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
|
||||
let {
|
||||
text,
|
||||
generation,
|
||||
endpointOption,
|
||||
conversationId,
|
||||
responseMessageId,
|
||||
isContinued = false,
|
||||
parentMessageId = null,
|
||||
overrideParentMessageId = null,
|
||||
} = req.body;
|
||||
console.log('edit log');
|
||||
console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null });
|
||||
let metadata;
|
||||
let userMessage;
|
||||
let promptTokens;
|
||||
let lastSavedTimestamp = 0;
|
||||
let saveDelay = 100;
|
||||
const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
|
||||
const userMessageId = parentMessageId;
|
||||
const user = req.user.id;
|
||||
|
||||
const addMetadata = (data) => (metadata = data);
|
||||
const getReqData = (data = {}) => {
|
||||
for (let key in data) {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
promptTokens = data[key];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const { onProgress: progressCallback, getPartialText } = createOnProgress({
|
||||
generation,
|
||||
onProgress: ({ text: partialText }) => {
|
||||
const currentTimestamp = Date.now();
|
||||
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
|
||||
lastSavedTimestamp = currentTimestamp;
|
||||
saveMessage({
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
conversationId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: partialText,
|
||||
unfinished: true,
|
||||
cancelled: false,
|
||||
isEdited: true,
|
||||
error: false,
|
||||
user,
|
||||
});
|
||||
}
|
||||
|
||||
if (saveDelay < 500) {
|
||||
saveDelay = 500;
|
||||
}
|
||||
},
|
||||
});
|
||||
try {
|
||||
const getAbortData = () => ({
|
||||
conversationId,
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: getPartialText(),
|
||||
userMessage,
|
||||
promptTokens,
|
||||
});
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
user,
|
||||
generation,
|
||||
isContinued,
|
||||
isEdited: true,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
responseMessageId,
|
||||
overrideParentMessageId,
|
||||
...endpointOption,
|
||||
onProgress: progressCallback.call(null, {
|
||||
res,
|
||||
text,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
}),
|
||||
getReqData,
|
||||
onStart,
|
||||
addMetadata,
|
||||
abortController,
|
||||
});
|
||||
|
||||
if (metadata) {
|
||||
response = { ...response, ...metadata };
|
||||
}
|
||||
|
||||
if (overrideParentMessageId) {
|
||||
response.parentMessageId = overrideParentMessageId;
|
||||
}
|
||||
|
||||
sendMessage(res, {
|
||||
title: await getConvoTitle(user, conversationId),
|
||||
final: true,
|
||||
conversation: await getConvo(user, conversationId),
|
||||
requestMessage: userMessage,
|
||||
responseMessage: response,
|
||||
});
|
||||
res.end();
|
||||
|
||||
await saveMessage({ ...response, user });
|
||||
await saveMessage(userMessage);
|
||||
|
||||
// TODO: add anthropic titling
|
||||
} catch (error) {
|
||||
const partialText = getPartialText();
|
||||
handleAbortError(res, req, error, {
|
||||
partialText,
|
||||
conversationId,
|
||||
sender,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: userMessageId ?? parentMessageId,
|
||||
});
|
||||
}
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
|
||||
await EditController(req, res, next, initializeClient);
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
19
api/server/routes/edit/google.js
Normal file
19
api/server/routes/edit/google.js
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
const express = require('express');
|
||||
const EditController = require('~/server/controllers/EditController');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/google');
|
||||
const {
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
|
||||
await EditController(req, res, next, initializeClient);
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
|
|
@ -1,10 +1,10 @@
|
|||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const { getResponseSender } = require('../endpoints/schemas');
|
||||
const { validateTools } = require('../../../app');
|
||||
const { initializeClient } = require('../endpoints/gptPlugins');
|
||||
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
|
||||
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('../../utils');
|
||||
const { validateTools } = require('~/app');
|
||||
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
|
||||
const { getResponseSender } = require('~/server/services/Endpoints');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
|
||||
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
|
||||
const {
|
||||
handleAbort,
|
||||
createAbortController,
|
||||
|
|
@ -12,7 +12,7 @@ const {
|
|||
setHeaders,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('../../middleware');
|
||||
} = require('~/server/middleware');
|
||||
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
|
|
|
|||
|
|
@ -1,20 +1,23 @@
|
|||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const openAI = require('./openAI');
|
||||
const gptPlugins = require('./gptPlugins');
|
||||
const google = require('./google');
|
||||
const anthropic = require('./anthropic');
|
||||
const gptPlugins = require('./gptPlugins');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { EModelEndpoint } = require('~/server/services/Endpoints');
|
||||
const {
|
||||
checkBan,
|
||||
uaParser,
|
||||
requireJwtAuth,
|
||||
concurrentLimiter,
|
||||
messageIpLimiter,
|
||||
concurrentLimiter,
|
||||
messageUserLimiter,
|
||||
} = require('../../middleware');
|
||||
const { isEnabled } = require('../../utils');
|
||||
} = require('~/server/middleware');
|
||||
|
||||
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.use(requireJwtAuth);
|
||||
router.use(checkBan);
|
||||
router.use(uaParser);
|
||||
|
|
@ -31,8 +34,9 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
|
|||
router.use(messageUserLimiter);
|
||||
}
|
||||
|
||||
router.use(['/azureOpenAI', '/openAI'], openAI);
|
||||
router.use('/gptPlugins', gptPlugins);
|
||||
router.use('/anthropic', anthropic);
|
||||
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
|
||||
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
|
||||
router.use(`/${EModelEndpoint.anthropic}`, anthropic);
|
||||
router.use(`/${EModelEndpoint.google}`, google);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const { getResponseSender } = require('../endpoints/schemas');
|
||||
const { initializeClient } = require('../endpoints/openAI');
|
||||
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
|
||||
const { sendMessage, createOnProgress } = require('../../utils');
|
||||
const { getResponseSender } = require('~/server/services/Endpoints');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/openAI');
|
||||
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const {
|
||||
handleAbort,
|
||||
createAbortController,
|
||||
|
|
@ -11,7 +11,7 @@ const {
|
|||
setHeaders,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('../../middleware');
|
||||
} = require('~/server/middleware');
|
||||
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
const { EModelEndpoint } = require('~/server/routes/endpoints/schemas');
|
||||
const { EModelEndpoint } = require('~/server/services/Endpoints');
|
||||
|
||||
const {
|
||||
OPENAI_API_KEY: openAIApiKey,
|
||||
|
|
@ -7,7 +7,7 @@ const {
|
|||
CHATGPT_TOKEN: chatGPTToken,
|
||||
BINGAI_TOKEN: bingToken,
|
||||
PLUGINS_USE_AZURE,
|
||||
PALM_KEY: palmKey,
|
||||
GOOGLE_KEY: googleKey,
|
||||
} = process.env ?? {};
|
||||
|
||||
const useAzurePlugins = !!PLUGINS_USE_AZURE;
|
||||
|
|
@ -26,7 +26,7 @@ module.exports = {
|
|||
azureOpenAIApiKey,
|
||||
useAzurePlugins,
|
||||
userProvidedOpenAI,
|
||||
palmKey,
|
||||
googleKey,
|
||||
[EModelEndpoint.openAI]: isUserProvided(openAIApiKey),
|
||||
[EModelEndpoint.assistant]: isUserProvided(openAIApiKey),
|
||||
[EModelEndpoint.azureOpenAI]: isUserProvided(azureOpenAIApiKey),
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const { availableTools } = require('~/app/clients/tools');
|
||||
const { addOpenAPISpecs } = require('~/app/clients/tools/util/addOpenAPISpecs');
|
||||
const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, palmKey } =
|
||||
const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, googleKey } =
|
||||
require('./EndpointService').config;
|
||||
|
||||
/**
|
||||
|
|
@ -8,7 +8,7 @@ const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, pa
|
|||
*/
|
||||
async function loadAsyncEndpoints() {
|
||||
let i = 0;
|
||||
let key, palmUser;
|
||||
let key, googleUserProvides;
|
||||
try {
|
||||
key = require('~/data/auth.json');
|
||||
} catch (e) {
|
||||
|
|
@ -17,8 +17,8 @@ async function loadAsyncEndpoints() {
|
|||
}
|
||||
}
|
||||
|
||||
if (palmKey === 'user_provided') {
|
||||
palmUser = true;
|
||||
if (googleKey === 'user_provided') {
|
||||
googleUserProvides = true;
|
||||
if (i <= 1) {
|
||||
i++;
|
||||
}
|
||||
|
|
@ -33,7 +33,7 @@ async function loadAsyncEndpoints() {
|
|||
}
|
||||
const plugins = transformToolsToMap(tools);
|
||||
|
||||
const google = key || palmUser ? { userProvide: palmUser } : false;
|
||||
const google = key || googleUserProvides ? { userProvide: googleUserProvides } : false;
|
||||
|
||||
const gptPlugins =
|
||||
openAIApiKey || azureOpenAIApiKey
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
const { EModelEndpoint } = require('~/server/routes/endpoints/schemas');
|
||||
const { EModelEndpoint } = require('~/server/services/Endpoints');
|
||||
const loadAsyncEndpoints = require('./loadAsyncEndpoints');
|
||||
const { config } = require('./EndpointService');
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ const {
|
|||
getChatGPTBrowserModels,
|
||||
getAnthropicModels,
|
||||
} = require('~/server/services/ModelService');
|
||||
const { EModelEndpoint } = require('~/server/routes/endpoints/schemas');
|
||||
const { EModelEndpoint } = require('~/server/services/Endpoints');
|
||||
const { useAzurePlugins } = require('~/server/services/Config/EndpointService').config;
|
||||
|
||||
const fitlerAssistantModels = (str) => {
|
||||
|
|
@ -21,7 +21,18 @@ async function loadDefaultModels() {
|
|||
[EModelEndpoint.openAI]: openAI,
|
||||
[EModelEndpoint.azureOpenAI]: azureOpenAI,
|
||||
[EModelEndpoint.assistant]: openAI.filter(fitlerAssistantModels),
|
||||
[EModelEndpoint.google]: ['chat-bison', 'text-bison', 'codechat-bison'],
|
||||
[EModelEndpoint.google]: [
|
||||
'chat-bison',
|
||||
'chat-bison-32k',
|
||||
'codechat-bison',
|
||||
'codechat-bison-32k',
|
||||
'text-bison',
|
||||
'text-bison-32k',
|
||||
'text-unicorn',
|
||||
'code-gecko',
|
||||
'code-bison',
|
||||
'code-bison-32k',
|
||||
],
|
||||
[EModelEndpoint.bingAI]: ['BingAI', 'Sydney'],
|
||||
[EModelEndpoint.chatGPTBrowser]: chatGPTBrowser,
|
||||
[EModelEndpoint.gptPlugins]: gptPlugins,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ const { AnthropicClient } = require('~/app');
|
|||
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
const { ANTHROPIC_API_KEY, ANTHROPIC_REVERSE_PROXY } = process.env;
|
||||
const { ANTHROPIC_API_KEY, ANTHROPIC_REVERSE_PROXY, PROXY } = process.env;
|
||||
const expiresAt = req.body.key;
|
||||
const isUserProvided = ANTHROPIC_API_KEY === 'user_provided';
|
||||
|
||||
|
|
@ -21,6 +21,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
req,
|
||||
res,
|
||||
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
|
||||
proxy: PROXY ?? null,
|
||||
...endpointOption,
|
||||
});
|
||||
|
||||
16
api/server/services/Endpoints/google/buildOptions.js
Normal file
16
api/server/services/Endpoints/google/buildOptions.js
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
const buildOptions = (endpoint, parsedBody) => {
|
||||
const { examples, modelLabel, promptPrefix, ...rest } = parsedBody;
|
||||
const endpointOption = {
|
||||
examples,
|
||||
endpoint,
|
||||
modelLabel,
|
||||
promptPrefix,
|
||||
modelOptions: {
|
||||
...rest,
|
||||
},
|
||||
};
|
||||
|
||||
return endpointOption;
|
||||
};
|
||||
|
||||
module.exports = buildOptions;
|
||||
8
api/server/services/Endpoints/google/index.js
Normal file
8
api/server/services/Endpoints/google/index.js
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
const buildOptions = require('./buildOptions');
|
||||
const initializeClient = require('./initializeClient');
|
||||
|
||||
module.exports = {
|
||||
// addTitle, // todo
|
||||
buildOptions,
|
||||
initializeClient,
|
||||
};
|
||||
35
api/server/services/Endpoints/google/initializeClient.js
Normal file
35
api/server/services/Endpoints/google/initializeClient.js
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
const { GoogleClient } = require('~/app');
|
||||
const { EModelEndpoint } = require('~/server/services/Endpoints');
|
||||
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
const { GOOGLE_KEY, GOOGLE_REVERSE_PROXY, PROXY } = process.env;
|
||||
const isUserProvided = GOOGLE_KEY === 'user_provided';
|
||||
const { key: expiresAt } = req.body;
|
||||
|
||||
let userKey = null;
|
||||
if (expiresAt && isUserProvided) {
|
||||
checkUserKeyExpiry(
|
||||
expiresAt,
|
||||
'Your Google key has expired. Please provide your JSON credentials again.',
|
||||
);
|
||||
userKey = await getUserKey({ userId: req.user.id, name: EModelEndpoint.google });
|
||||
}
|
||||
|
||||
const apiKey = isUserProvided ? userKey : require('~/data/auth.json');
|
||||
|
||||
const client = new GoogleClient(apiKey, {
|
||||
req,
|
||||
res,
|
||||
reverseProxyUrl: GOOGLE_REVERSE_PROXY ?? null,
|
||||
proxy: PROXY ?? null,
|
||||
...endpointOption,
|
||||
});
|
||||
|
||||
return {
|
||||
client,
|
||||
apiKey,
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = initializeClient;
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
const { PluginsClient } = require('../../../../app');
|
||||
const { isEnabled } = require('../../../utils');
|
||||
const { getAzureCredentials } = require('../../../../utils');
|
||||
const { getUserKey, checkUserKeyExpiry } = require('../../../services/UserService');
|
||||
const { PluginsClient } = require('~/app');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { getAzureCredentials } = require('~/utils');
|
||||
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
const {
|
||||
|
|
@ -1,12 +1,12 @@
|
|||
// gptPlugins/initializeClient.spec.js
|
||||
const { PluginsClient } = require('~/app');
|
||||
const initializeClient = require('./initializeClient');
|
||||
const { PluginsClient } = require('../../../../app');
|
||||
const { getUserKey } = require('../../../services/UserService');
|
||||
const { getUserKey } = require('../../UserService');
|
||||
|
||||
// Mock getUserKey since it's the only function we want to mock
|
||||
jest.mock('../../../services/UserService', () => ({
|
||||
jest.mock('~/server/services/UserService', () => ({
|
||||
getUserKey: jest.fn(),
|
||||
checkUserKeyExpiry: jest.requireActual('../../../services/UserService').checkUserKeyExpiry,
|
||||
checkUserKeyExpiry: jest.requireActual('~/server/services/UserService').checkUserKeyExpiry,
|
||||
}));
|
||||
|
||||
describe('gptPlugins/initializeClient', () => {
|
||||
5
api/server/services/Endpoints/index.js
Normal file
5
api/server/services/Endpoints/index.js
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
const schemas = require('./schemas');
|
||||
|
||||
module.exports = {
|
||||
...schemas,
|
||||
};
|
||||
|
|
@ -1,11 +1,11 @@
|
|||
const { OpenAIClient } = require('~/app');
|
||||
const initializeClient = require('./initializeClient');
|
||||
const { OpenAIClient } = require('../../../../app');
|
||||
const { getUserKey } = require('../../../services/UserService');
|
||||
const { getUserKey } = require('~/server/services/UserService');
|
||||
|
||||
// Mock getUserKey since it's the only function we want to mock
|
||||
jest.mock('../../../services/UserService', () => ({
|
||||
jest.mock('~/server/services/UserService', () => ({
|
||||
getUserKey: jest.fn(),
|
||||
checkUserKeyExpiry: jest.requireActual('../../../services/UserService').checkUserKeyExpiry,
|
||||
checkUserKeyExpiry: jest.requireActual('~/server/services/UserService').checkUserKeyExpiry,
|
||||
}));
|
||||
|
||||
describe('initializeClient', () => {
|
||||
|
|
@ -18,10 +18,44 @@ const alternateName = {
|
|||
[EModelEndpoint.bingAI]: 'Bing',
|
||||
[EModelEndpoint.chatGPTBrowser]: 'ChatGPT',
|
||||
[EModelEndpoint.gptPlugins]: 'Plugins',
|
||||
[EModelEndpoint.google]: 'PaLM',
|
||||
[EModelEndpoint.google]: 'Google',
|
||||
[EModelEndpoint.anthropic]: 'Anthropic',
|
||||
};
|
||||
|
||||
const endpointSettings = {
|
||||
[EModelEndpoint.google]: {
|
||||
model: {
|
||||
default: 'chat-bison',
|
||||
},
|
||||
maxOutputTokens: {
|
||||
min: 1,
|
||||
max: 2048,
|
||||
step: 1,
|
||||
default: 1024,
|
||||
},
|
||||
temperature: {
|
||||
min: 0,
|
||||
max: 1,
|
||||
step: 0.01,
|
||||
default: 0.2,
|
||||
},
|
||||
topP: {
|
||||
min: 0,
|
||||
max: 1,
|
||||
step: 0.01,
|
||||
default: 0.8,
|
||||
},
|
||||
topK: {
|
||||
min: 1,
|
||||
max: 40,
|
||||
step: 0.01,
|
||||
default: 40,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const google = endpointSettings[EModelEndpoint.google];
|
||||
|
||||
const supportsFiles = {
|
||||
[EModelEndpoint.openAI]: true,
|
||||
[EModelEndpoint.assistant]: true,
|
||||
|
|
@ -158,22 +192,24 @@ const googleSchema = tConversationSchema
|
|||
})
|
||||
.transform((obj) => ({
|
||||
...obj,
|
||||
model: obj.model ?? 'chat-bison',
|
||||
model: obj.model ?? google.model.default,
|
||||
modelLabel: obj.modelLabel ?? null,
|
||||
promptPrefix: obj.promptPrefix ?? null,
|
||||
temperature: obj.temperature ?? 0.2,
|
||||
maxOutputTokens: obj.maxOutputTokens ?? 1024,
|
||||
topP: obj.topP ?? 0.95,
|
||||
topK: obj.topK ?? 40,
|
||||
examples: obj.examples ?? [{ input: { content: '' }, output: { content: '' } }],
|
||||
temperature: obj.temperature ?? google.temperature.default,
|
||||
maxOutputTokens: obj.maxOutputTokens ?? google.maxOutputTokens.default,
|
||||
topP: obj.topP ?? google.topP.default,
|
||||
topK: obj.topK ?? google.topK.default,
|
||||
}))
|
||||
.catch(() => ({
|
||||
model: 'chat-bison',
|
||||
model: google.model.default,
|
||||
modelLabel: null,
|
||||
promptPrefix: null,
|
||||
temperature: 0.2,
|
||||
maxOutputTokens: 1024,
|
||||
topP: 0.95,
|
||||
topK: 40,
|
||||
examples: [{ input: { content: '' }, output: { content: '' } }],
|
||||
temperature: google.temperature.default,
|
||||
maxOutputTokens: google.maxOutputTokens.default,
|
||||
topP: google.topP.default,
|
||||
topK: google.topK.default,
|
||||
}));
|
||||
|
||||
const bingAISchema = tConversationSchema
|
||||
|
|
@ -385,7 +421,13 @@ const getResponseSender = (endpointOption) => {
|
|||
}
|
||||
|
||||
if (endpoint === EModelEndpoint.google) {
|
||||
return modelLabel ?? 'PaLM2';
|
||||
if (modelLabel) {
|
||||
return modelLabel;
|
||||
} else if (model && model.includes('code')) {
|
||||
return 'Codey';
|
||||
}
|
||||
|
||||
return 'PaLM2';
|
||||
}
|
||||
|
||||
return '';
|
||||
|
|
@ -399,4 +441,5 @@ module.exports = {
|
|||
openAIModels,
|
||||
visionModels,
|
||||
alternateName,
|
||||
endpointSettings,
|
||||
};
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
const { visionModels } = require('~/server/routes/endpoints/schemas');
|
||||
const { visionModels } = require('~/server/services/Endpoints');
|
||||
|
||||
function validateVisionModel(model) {
|
||||
if (!model) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue