ci(backend-review.yml): add linter step to the backend review workflow (#625)

* ci(backend-review.yml): add linter step to the backend review workflow

* chore(backend-review.yml): remove prettier from lint-action configuration

* chore: apply new linting workflow

* chore(lint-staged.config.js): reorder lint-staged tasks for JavaScript and TypeScript files

* chore(eslint): update ignorePatterns in .eslintrc.js
chore(lint-action): remove prettier option in backend-review.yml
chore(package.json): add lint and lint:fix scripts

* chore(lint-staged.config.js): remove prettier --write command for js, jsx, ts, tsx files

* chore(titleConvo.js): remove unnecessary console.log statement
chore(titleConvo.js): add missing comma in options object

* chore: apply linting to all files

* chore(lint-staged.config.js): update lint-staged configuration to include prettier formatting
This commit is contained in:
Danny Avila 2023-07-14 09:36:49 -04:00 committed by GitHub
parent 637bb6bc11
commit e5336039fc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
231 changed files with 1688 additions and 1526 deletions

View file

@ -4,25 +4,30 @@ module.exports = {
es2021: true, es2021: true,
node: true, node: true,
commonjs: true, commonjs: true,
es6: true es6: true,
}, },
extends: [ extends: [
'eslint:recommended', 'eslint:recommended',
'plugin:react/recommended', 'plugin:react/recommended',
'plugin:react-hooks/recommended', 'plugin:react-hooks/recommended',
'plugin:jest/recommended', 'plugin:jest/recommended',
'prettier' 'prettier',
], ],
// ignorePatterns: ['packages/data-provider/types/**/*'],
ignorePatterns: [ ignorePatterns: [
'client/dist/**/*',
'client/public/**/*',
'e2e/playwright-report/**/*',
'packages/data-provider/types/**/*', 'packages/data-provider/types/**/*',
'packages/data-provider/dist/**/*',
], ],
parser: '@typescript-eslint/parser', parser: '@typescript-eslint/parser',
parserOptions: { parserOptions: {
ecmaVersion: 'latest', ecmaVersion: 'latest',
sourceType: 'module', sourceType: 'module',
ecmaFeatures: { ecmaFeatures: {
jsx: true jsx: true,
} },
}, },
plugins: ['react', 'react-hooks', '@typescript-eslint'], plugins: ['react', 'react-hooks', '@typescript-eslint'],
rules: { rules: {
@ -35,13 +40,14 @@ module.exports = {
code: 120, code: 120,
ignoreStrings: true, ignoreStrings: true,
ignoreTemplateLiterals: true, ignoreTemplateLiterals: true,
ignoreComments: true ignoreComments: true,
} },
], ],
'linebreak-style': 0, 'linebreak-style': 0,
'object-curly-spacing': ['error', 'always'], 'object-curly-spacing': ['error', 'always'],
'no-trailing-spaces': 'error', 'no-trailing-spaces': 'error',
'no-multiple-empty-lines': ['error', { 'max': 1 }], 'no-multiple-empty-lines': ['error', { max: 1 }],
'comma-dangle': ['error', 'always-multiline'],
// "arrow-parens": [2, "as-needed", { requireForBlockBody: true }], // "arrow-parens": [2, "as-needed", { requireForBlockBody: true }],
// 'no-plusplus': ['error', { allowForLoopAfterthoughts: true }], // 'no-plusplus': ['error', { allowForLoopAfterthoughts: true }],
'no-console': 'off', 'no-console': 'off',
@ -52,6 +58,7 @@ module.exports = {
'no-restricted-syntax': 'off', 'no-restricted-syntax': 'off',
'react/prop-types': ['off'], 'react/prop-types': ['off'],
'react/display-name': ['off'], 'react/display-name': ['off'],
quotes: ['error', 'single'],
}, },
overrides: [ overrides: [
{ {
@ -59,14 +66,14 @@ module.exports = {
rules: { rules: {
'no-unused-vars': 'off', // off because it conflicts with '@typescript-eslint/no-unused-vars' 'no-unused-vars': 'off', // off because it conflicts with '@typescript-eslint/no-unused-vars'
'react/display-name': 'off', 'react/display-name': 'off',
'@typescript-eslint/no-unused-vars': 'warn' '@typescript-eslint/no-unused-vars': 'warn',
} },
}, },
{ {
files: ['rollup.config.js', '.eslintrc.js', 'jest.config.js'], files: ['rollup.config.js', '.eslintrc.js', 'jest.config.js'],
env: { env: {
node: true, node: true,
} },
}, },
{ {
files: [ files: [
@ -78,29 +85,29 @@ module.exports = {
'**/*.spec.jsx', '**/*.spec.jsx',
'**/*.spec.ts', '**/*.spec.ts',
'**/*.spec.tsx', '**/*.spec.tsx',
'setupTests.js' 'setupTests.js',
], ],
env: { env: {
jest: true, jest: true,
node: true node: true,
}, },
rules: { rules: {
'react/display-name': 'off', 'react/display-name': 'off',
'react/prop-types': 'off', 'react/prop-types': 'off',
'react/no-unescaped-entities': 'off' 'react/no-unescaped-entities': 'off',
} },
}, },
{ {
files: '**/*.+(ts)', files: '**/*.+(ts)',
parser: '@typescript-eslint/parser', parser: '@typescript-eslint/parser',
parserOptions: { parserOptions: {
project: './client/tsconfig.json' project: './client/tsconfig.json',
}, },
plugins: ['@typescript-eslint/eslint-plugin', 'jest'], plugins: ['@typescript-eslint/eslint-plugin', 'jest'],
extends: [ extends: [
'plugin:@typescript-eslint/eslint-recommended', 'plugin:@typescript-eslint/eslint-recommended',
'plugin:@typescript-eslint/recommended' 'plugin:@typescript-eslint/recommended',
] ],
}, },
{ {
files: './packages/data-provider/**/*.ts', files: './packages/data-provider/**/*.ts',
@ -109,11 +116,11 @@ module.exports = {
files: '**/*.ts', files: '**/*.ts',
parser: '@typescript-eslint/parser', parser: '@typescript-eslint/parser',
parserOptions: { parserOptions: {
project: './packages/data-provider/tsconfig.json' project: './packages/data-provider/tsconfig.json',
} },
} },
] ],
} },
], ],
settings: { settings: {
react: { react: {
@ -121,7 +128,7 @@ module.exports = {
// default to "createReactClass" // default to "createReactClass"
pragma: 'React', // Pragma to use, default to "React" pragma: 'React', // Pragma to use, default to "React"
fragment: 'Fragment', // Fragment to use (may be a property of <pragma>), default to "Fragment" fragment: 'Fragment', // Fragment to use (may be a property of <pragma>), default to "Fragment"
version: 'detect' // React version. "detect" automatically picks the version you have installed. version: 'detect', // React version. "detect" automatically picks the version you have installed.
} },
} },
}; };

View file

@ -37,3 +37,8 @@ jobs:
- name: Run unit tests - name: Run unit tests
run: cd api && npm run test:ci run: cd api && npm run test:ci
- name: Run linters
uses: wearerequired/lint-action@v2
with:
eslint: true

View file

@ -5,7 +5,7 @@ module.exports = {
semi: true, semi: true,
singleQuote: true, singleQuote: true,
// bracketSpacing: false, // bracketSpacing: false,
trailingComma: 'none', trailingComma: 'all',
arrowParens: 'always', arrowParens: 'always',
embeddedLanguageFormatting: 'auto', embeddedLanguageFormatting: 'auto',
insertPragma: false, insertPragma: false,

View file

@ -14,11 +14,11 @@ const askBing = async ({
invocationId, invocationId,
toneStyle, toneStyle,
token, token,
onProgress onProgress,
}) => { }) => {
const { BingAIClient } = await import('@waylaidwanderer/chatgpt-api'); const { BingAIClient } = await import('@waylaidwanderer/chatgpt-api');
const store = { const store = {
store: new KeyvFile({ filename: './data/cache.json' }) store: new KeyvFile({ filename: './data/cache.json' }),
}; };
const bingAIClient = new BingAIClient({ const bingAIClient = new BingAIClient({
@ -30,7 +30,7 @@ const askBing = async ({
debug: false, debug: false,
cache: store, cache: store,
host: process.env.BINGAI_HOST || null, host: process.env.BINGAI_HOST || null,
proxy: process.env.PROXY || null proxy: process.env.PROXY || null,
}); });
let options = {}; let options = {};
@ -46,7 +46,7 @@ const askBing = async ({
systemMessage, systemMessage,
parentMessageId, parentMessageId,
toneStyle, toneStyle,
onProgress onProgress,
}; };
else { else {
options = { options = {
@ -55,7 +55,7 @@ const askBing = async ({
systemMessage, systemMessage,
parentMessageId, parentMessageId,
toneStyle, toneStyle,
onProgress onProgress,
}; };
// don't give those parameters for new conversation // don't give those parameters for new conversation

View file

@ -10,11 +10,11 @@ const browserClient = async ({
onProgress, onProgress,
onEventMessage, onEventMessage,
abortController, abortController,
userId userId,
}) => { }) => {
const { ChatGPTBrowserClient } = await import('@waylaidwanderer/chatgpt-api'); const { ChatGPTBrowserClient } = await import('@waylaidwanderer/chatgpt-api');
const store = { const store = {
store: new KeyvFile({ filename: './data/cache.json' }) store: new KeyvFile({ filename: './data/cache.json' }),
}; };
const clientOptions = { const clientOptions = {
@ -27,7 +27,7 @@ const browserClient = async ({
model: model, model: model,
debug: false, debug: false,
proxy: process.env.PROXY || null, proxy: process.env.PROXY || null,
user: userId user: userId,
}; };
const client = new ChatGPTBrowserClient(clientOptions, store); const client = new ChatGPTBrowserClient(clientOptions, store);

View file

@ -3,7 +3,7 @@ const Keyv = require('keyv');
const BaseClient = require('./BaseClient'); const BaseClient = require('./BaseClient');
const { const {
encoding_for_model: encodingForModel, encoding_for_model: encodingForModel,
get_encoding: getEncoding get_encoding: getEncoding,
} = require('@dqbd/tiktoken'); } = require('@dqbd/tiktoken');
const Anthropic = require('@anthropic-ai/sdk'); const Anthropic = require('@anthropic-ai/sdk');
@ -13,9 +13,8 @@ const AI_PROMPT = '\n\nAssistant:';
const tokenizersCache = {}; const tokenizersCache = {};
class AnthropicClient extends BaseClient { class AnthropicClient extends BaseClient {
constructor(apiKey, options = {}, cacheOptions = {}) { constructor(apiKey, options = {}, cacheOptions = {}) {
super(apiKey, options, cacheOptions) super(apiKey, options, cacheOptions);
cacheOptions.namespace = cacheOptions.namespace || 'anthropic'; cacheOptions.namespace = cacheOptions.namespace || 'anthropic';
this.conversationsCache = new Keyv(cacheOptions); this.conversationsCache = new Keyv(cacheOptions);
this.apiKey = apiKey || process.env.ANTHROPIC_API_KEY; this.apiKey = apiKey || process.env.ANTHROPIC_API_KEY;
@ -30,7 +29,7 @@ class AnthropicClient extends BaseClient {
// nested options aren't spread properly, so we need to do this manually // nested options aren't spread properly, so we need to do this manually
this.options.modelOptions = { this.options.modelOptions = {
...this.options.modelOptions, ...this.options.modelOptions,
...options.modelOptions ...options.modelOptions,
}; };
delete options.modelOptions; delete options.modelOptions;
// now we can merge options // now we can merge options
@ -50,7 +49,7 @@ class AnthropicClient extends BaseClient {
temperature: typeof modelOptions.temperature === 'undefined' ? 0.7 : modelOptions.temperature, // 0 - 1, 0.7 is recommended temperature: typeof modelOptions.temperature === 'undefined' ? 0.7 : modelOptions.temperature, // 0 - 1, 0.7 is recommended
topP: typeof modelOptions.topP === 'undefined' ? 0.7 : modelOptions.topP, // 0 - 1, default: 0.7 topP: typeof modelOptions.topP === 'undefined' ? 0.7 : modelOptions.topP, // 0 - 1, default: 0.7
topK: typeof modelOptions.topK === 'undefined' ? 40 : modelOptions.topK, // 1-40, default: 40 topK: typeof modelOptions.topK === 'undefined' ? 40 : modelOptions.topK, // 1-40, default: 40
stop: modelOptions.stop // no stop method for now stop: modelOptions.stop, // no stop method for now
}; };
this.maxContextTokens = this.options.maxContextTokens || 99999; this.maxContextTokens = this.options.maxContextTokens || 99999;
@ -62,7 +61,7 @@ class AnthropicClient extends BaseClient {
throw new Error( throw new Error(
`maxPromptTokens + maxOutputTokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${ `maxPromptTokens + maxOutputTokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
this.maxPromptTokens + this.maxResponseTokens this.maxPromptTokens + this.maxResponseTokens
}) must be less than or equal to maxContextTokens (${this.maxContextTokens})` }) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
); );
} }
@ -85,18 +84,17 @@ class AnthropicClient extends BaseClient {
} }
getClient() { getClient() {
if(this.options.reverseProxyUrl) { if (this.options.reverseProxyUrl) {
return new Anthropic({ return new Anthropic({
apiKey: this.apiKey, apiKey: this.apiKey,
baseURL: this.options.reverseProxyUrl baseURL: this.options.reverseProxyUrl,
}); });
} } else {
else {
return new Anthropic({ return new Anthropic({
apiKey: this.apiKey, apiKey: this.apiKey,
}); });
} }
}; }
async buildMessages(messages, parentMessageId) { async buildMessages(messages, parentMessageId) {
const orderedMessages = this.constructor.getMessagesForConversation(messages, parentMessageId); const orderedMessages = this.constructor.getMessagesForConversation(messages, parentMessageId);
@ -106,7 +104,7 @@ class AnthropicClient extends BaseClient {
const formattedMessages = orderedMessages.map((message) => ({ const formattedMessages = orderedMessages.map((message) => ({
author: message.isCreatedByUser ? this.userLabel : this.assistantLabel, author: message.isCreatedByUser ? this.userLabel : this.assistantLabel,
content: message?.content ?? message.text content: message?.content ?? message.text,
})); }));
let identityPrefix = ''; let identityPrefix = '';
@ -169,7 +167,9 @@ class AnthropicClient extends BaseClient {
if (newTokenCount > maxTokenCount) { if (newTokenCount > maxTokenCount) {
if (!promptBody) { if (!promptBody) {
// This is the first message, so we can't add it. Just throw an error. // This is the first message, so we can't add it. Just throw an error.
throw new Error(`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`); throw new Error(
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
);
} }
// Otherwise, ths message would put us over the token limit, so don't add it. // Otherwise, ths message would put us over the token limit, so don't add it.
@ -183,7 +183,7 @@ class AnthropicClient extends BaseClient {
promptBody = newPromptBody; promptBody = newPromptBody;
currentTokenCount = newTokenCount; currentTokenCount = newTokenCount;
// wait for next tick to avoid blocking the event loop // wait for next tick to avoid blocking the event loop
await new Promise(resolve => setImmediate(resolve)); await new Promise((resolve) => setImmediate(resolve));
return buildPromptBody(); return buildPromptBody();
} }
return true; return true;
@ -202,7 +202,10 @@ class AnthropicClient extends BaseClient {
currentTokenCount += 2; currentTokenCount += 2;
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response. // Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
this.modelOptions.maxOutputTokens = Math.min(this.maxContextTokens - currentTokenCount, this.maxResponseTokens); this.modelOptions.maxOutputTokens = Math.min(
this.maxContextTokens - currentTokenCount,
this.maxResponseTokens,
);
return { prompt, context }; return { prompt, context };
} }
@ -243,7 +246,7 @@ class AnthropicClient extends BaseClient {
stream: this.modelOptions.stream || true, stream: this.modelOptions.stream || true,
max_tokens_to_sample: this.modelOptions.maxOutputTokens || 1500, max_tokens_to_sample: this.modelOptions.maxOutputTokens || 1500,
metadata, metadata,
...modelOptions ...modelOptions,
}; };
if (this.options.debug) { if (this.options.debug) {
console.log('AnthropicClient: requestOptions'); console.log('AnthropicClient: requestOptions');
@ -289,7 +292,7 @@ class AnthropicClient extends BaseClient {
return { return {
promptPrefix: this.options.promptPrefix, promptPrefix: this.options.promptPrefix,
modelLabel: this.options.modelLabel, modelLabel: this.options.modelLabel,
...this.modelOptions ...this.modelOptions,
}; };
} }

View file

@ -14,7 +14,7 @@ class BaseClient {
this.currentDateString = new Date().toLocaleDateString('en-us', { this.currentDateString = new Date().toLocaleDateString('en-us', {
year: 'numeric', year: 'numeric',
month: 'long', month: 'long',
day: 'numeric' day: 'numeric',
}); });
} }
@ -58,7 +58,7 @@ class BaseClient {
const responseMessageId = crypto.randomUUID(); const responseMessageId = crypto.randomUUID();
const saveOptions = this.getSaveOptions(); const saveOptions = this.getSaveOptions();
this.abortController = opts.abortController || new AbortController(); this.abortController = opts.abortController || new AbortController();
this.currentMessages = await this.loadHistory(conversationId, parentMessageId) ?? []; this.currentMessages = (await this.loadHistory(conversationId, parentMessageId)) ?? [];
return { return {
...opts, ...opts,
@ -78,20 +78,14 @@ class BaseClient {
conversationId, conversationId,
sender: 'User', sender: 'User',
text, text,
isCreatedByUser: true isCreatedByUser: true,
}; };
return userMessage; return userMessage;
} }
async handleStartMethods(message, opts) { async handleStartMethods(message, opts) {
const { const { user, conversationId, parentMessageId, userMessageId, responseMessageId, saveOptions } =
user, await this.setMessageOptions(opts);
conversationId,
parentMessageId,
userMessageId,
responseMessageId,
saveOptions,
} = await this.setMessageOptions(opts);
const userMessage = this.createUserMessage({ const userMessage = this.createUserMessage({
messageId: userMessageId, messageId: userMessageId,
@ -104,7 +98,7 @@ class BaseClient {
opts.getIds({ opts.getIds({
userMessage, userMessage,
conversationId, conversationId,
responseMessageId responseMessageId,
}); });
} }
@ -189,24 +183,32 @@ class BaseClient {
async refineMessages(messagesToRefine, remainingContextTokens) { async refineMessages(messagesToRefine, remainingContextTokens) {
const model = new ChatOpenAI({ temperature: 0 }); const model = new ChatOpenAI({ temperature: 0 });
const chain = loadSummarizationChain(model, { type: 'refine', verbose: this.options.debug, refinePrompt }); const chain = loadSummarizationChain(model, {
type: 'refine',
verbose: this.options.debug,
refinePrompt,
});
const splitter = new RecursiveCharacterTextSplitter({ const splitter = new RecursiveCharacterTextSplitter({
chunkSize: 1500, chunkSize: 1500,
chunkOverlap: 100, chunkOverlap: 100,
}); });
const userMessages = this.concatenateMessages(messagesToRefine.filter(m => m.role === 'user')); const userMessages = this.concatenateMessages(
const assistantMessages = this.concatenateMessages(messagesToRefine.filter(m => m.role !== 'user')); messagesToRefine.filter((m) => m.role === 'user'),
const userDocs = await splitter.createDocuments([userMessages],[],{ );
const assistantMessages = this.concatenateMessages(
messagesToRefine.filter((m) => m.role !== 'user'),
);
const userDocs = await splitter.createDocuments([userMessages], [], {
chunkHeader: 'DOCUMENT NAME: User Message\n\n---\n\n', chunkHeader: 'DOCUMENT NAME: User Message\n\n---\n\n',
appendChunkOverlapHeader: true, appendChunkOverlapHeader: true,
}); });
const assistantDocs = await splitter.createDocuments([assistantMessages],[],{ const assistantDocs = await splitter.createDocuments([assistantMessages], [], {
chunkHeader: 'DOCUMENT NAME: Assistant Message\n\n---\n\n', chunkHeader: 'DOCUMENT NAME: Assistant Message\n\n---\n\n',
appendChunkOverlapHeader: true, appendChunkOverlapHeader: true,
}); });
// const chunkSize = Math.round(concatenatedMessages.length / 512); // const chunkSize = Math.round(concatenatedMessages.length / 512);
const input_documents = userDocs.concat(assistantDocs); const input_documents = userDocs.concat(assistantDocs);
if (this.options.debug ) { if (this.options.debug) {
console.debug('Refining messages...'); console.debug('Refining messages...');
} }
try { try {
@ -219,11 +221,15 @@ class BaseClient {
role: 'assistant', role: 'assistant',
content: res.output_text, content: res.output_text,
tokenCount: this.getTokenCount(res.output_text), tokenCount: this.getTokenCount(res.output_text),
} };
if (this.options.debug ) { if (this.options.debug) {
console.debug('Refined messages', refinedMessage); console.debug('Refined messages', refinedMessage);
console.debug(`remainingContextTokens: ${remainingContextTokens}, after refining: ${remainingContextTokens - refinedMessage.tokenCount}`); console.debug(
`remainingContextTokens: ${remainingContextTokens}, after refining: ${
remainingContextTokens - refinedMessage.tokenCount
}`,
);
} }
return refinedMessage; return refinedMessage;
@ -235,15 +241,15 @@ class BaseClient {
} }
/** /**
* This method processes an array of messages and returns a context of messages that fit within a token limit. * This method processes an array of messages and returns a context of messages that fit within a token limit.
* It iterates over the messages from newest to oldest, adding them to the context until the token limit is reached. * It iterates over the messages from newest to oldest, adding them to the context until the token limit is reached.
* If the token limit would be exceeded by adding a message, that message and possibly the previous one are added to a separate array of messages to refine. * If the token limit would be exceeded by adding a message, that message and possibly the previous one are added to a separate array of messages to refine.
* The method uses `push` and `pop` operations for efficient array manipulation, and reverses the arrays at the end to maintain the original order of the messages. * The method uses `push` and `pop` operations for efficient array manipulation, and reverses the arrays at the end to maintain the original order of the messages.
* The method also includes a mechanism to avoid blocking the event loop by waiting for the next tick after each iteration. * The method also includes a mechanism to avoid blocking the event loop by waiting for the next tick after each iteration.
* *
* @param {Array} messages - An array of messages, each with a `tokenCount` property. The messages should be ordered from oldest to newest. * @param {Array} messages - An array of messages, each with a `tokenCount` property. The messages should be ordered from oldest to newest.
* @returns {Object} An object with three properties: `context`, `remainingContextTokens`, and `messagesToRefine`. `context` is an array of messages that fit within the token limit. `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context. `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit. * @returns {Object} An object with three properties: `context`, `remainingContextTokens`, and `messagesToRefine`. `context` is an array of messages that fit within the token limit. `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context. `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit.
*/ */
async getMessagesWithinTokenLimit(messages) { async getMessagesWithinTokenLimit(messages) {
let currentTokenCount = 0; let currentTokenCount = 0;
let context = []; let context = [];
@ -282,26 +288,22 @@ class BaseClient {
context.push(message); context.push(message);
currentTokenCount = newTokenCount; currentTokenCount = newTokenCount;
remainingContextTokens = this.maxContextTokens - currentTokenCount; remainingContextTokens = this.maxContextTokens - currentTokenCount;
await new Promise(resolve => setImmediate(resolve)); await new Promise((resolve) => setImmediate(resolve));
} }
return { return {
context: context.reverse(), context: context.reverse(),
remainingContextTokens, remainingContextTokens,
messagesToRefine: messagesToRefine.reverse(), messagesToRefine: messagesToRefine.reverse(),
refineIndex refineIndex,
}; };
} }
async handleContextStrategy({ instructions, orderedMessages, formattedMessages }) { async handleContextStrategy({ instructions, orderedMessages, formattedMessages }) {
let payload = this.addInstructions(formattedMessages, instructions); let payload = this.addInstructions(formattedMessages, instructions);
let orderedWithInstructions = this.addInstructions(orderedMessages, instructions); let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
let { let { context, remainingContextTokens, messagesToRefine, refineIndex } =
context, await this.getMessagesWithinTokenLimit(payload);
remainingContextTokens,
messagesToRefine,
refineIndex
} = await this.getMessagesWithinTokenLimit(payload);
payload = context; payload = context;
let refinedMessage; let refinedMessage;
@ -325,8 +327,14 @@ class BaseClient {
if (this.options.debug) { if (this.options.debug) {
console.debug('<---------------------------------DIFF--------------------------------->'); console.debug('<---------------------------------DIFF--------------------------------->');
console.debug(`Difference between payload (${payload.length}) and orderedWithInstructions (${orderedWithInstructions.length}): ${diff}`); console.debug(
console.debug('remainingContextTokens, this.maxContextTokens (1/2)', remainingContextTokens, this.maxContextTokens); `Difference between payload (${payload.length}) and orderedWithInstructions (${orderedWithInstructions.length}): ${diff}`,
);
console.debug(
'remainingContextTokens, this.maxContextTokens (1/2)',
remainingContextTokens,
this.maxContextTokens,
);
} }
// If the difference is positive, slice the orderedWithInstructions array // If the difference is positive, slice the orderedWithInstructions array
@ -341,7 +349,11 @@ class BaseClient {
} }
if (this.options.debug) { if (this.options.debug) {
console.debug('remainingContextTokens, this.maxContextTokens (2/2)', remainingContextTokens, this.maxContextTokens); console.debug(
'remainingContextTokens, this.maxContextTokens (2/2)',
remainingContextTokens,
this.maxContextTokens,
);
} }
let tokenCountMap = orderedWithInstructions.reduce((map, message, index) => { let tokenCountMap = orderedWithInstructions.reduce((map, message, index) => {
@ -370,20 +382,19 @@ class BaseClient {
} }
async sendMessage(message, opts = {}) { async sendMessage(message, opts = {}) {
const { const { user, conversationId, responseMessageId, saveOptions, userMessage } =
user, await this.handleStartMethods(message, opts);
conversationId,
responseMessageId,
saveOptions,
userMessage,
} = await this.handleStartMethods(message, opts);
this.user = user; this.user = user;
// It's not necessary to push to currentMessages // It's not necessary to push to currentMessages
// depending on subclass implementation of handling messages // depending on subclass implementation of handling messages
this.currentMessages.push(userMessage); this.currentMessages.push(userMessage);
let { prompt: payload, tokenCountMap, promptTokens } = await this.buildMessages( let {
prompt: payload,
tokenCountMap,
promptTokens,
} = await this.buildMessages(
this.currentMessages, this.currentMessages,
// When the userMessage is pushed to currentMessages, the parentMessage is the userMessageId. // When the userMessage is pushed to currentMessages, the parentMessage is the userMessageId.
// this only matters when buildMessages is utilizing the parentMessageId, and may vary on implementation // this only matters when buildMessages is utilizing the parentMessageId, and may vary on implementation
@ -397,7 +408,7 @@ class BaseClient {
} }
if (tokenCountMap) { if (tokenCountMap) {
console.dir(tokenCountMap, { depth: null }) console.dir(tokenCountMap, { depth: null });
if (tokenCountMap[userMessage.messageId]) { if (tokenCountMap[userMessage.messageId]) {
userMessage.tokenCount = tokenCountMap[userMessage.messageId]; userMessage.tokenCount = tokenCountMap[userMessage.messageId];
console.log('userMessage.tokenCount', userMessage.tokenCount); console.log('userMessage.tokenCount', userMessage.tokenCount);
@ -461,7 +472,7 @@ class BaseClient {
await saveConvo(user, { await saveConvo(user, {
conversationId: message.conversationId, conversationId: message.conversationId,
endpoint: this.options.endpoint, endpoint: this.options.endpoint,
...endpointOptions ...endpointOptions,
}); });
} }
@ -470,12 +481,12 @@ class BaseClient {
} }
/** /**
* Iterate through messages, building an array based on the parentMessageId. * Iterate through messages, building an array based on the parentMessageId.
* Each message has an id and a parentMessageId. The parentMessageId is the id of the message that this message is a reply to. * Each message has an id and a parentMessageId. The parentMessageId is the id of the message that this message is a reply to.
* @param messages * @param messages
* @param parentMessageId * @param parentMessageId
* @returns {*[]} An array containing the messages in the order they should be displayed, starting with the root message. * @returns {*[]} An array containing the messages in the order they should be displayed, starting with the root message.
*/ */
static getMessagesForConversation(messages, parentMessageId, mapMethod = null) { static getMessagesForConversation(messages, parentMessageId, mapMethod = null) {
if (!messages || messages.length === 0) { if (!messages || messages.length === 0) {
return []; return [];
@ -484,7 +495,7 @@ class BaseClient {
const orderedMessages = []; const orderedMessages = [];
let currentMessageId = parentMessageId; let currentMessageId = parentMessageId;
while (currentMessageId) { while (currentMessageId) {
const message = messages.find(msg => { const message = messages.find((msg) => {
const messageId = msg.messageId ?? msg.id; const messageId = msg.messageId ?? msg.id;
return messageId === currentMessageId; return messageId === currentMessageId;
}); });
@ -503,13 +514,13 @@ class BaseClient {
} }
/** /**
* Algorithm adapted from "6. Counting tokens for chat API calls" of * Algorithm adapted from "6. Counting tokens for chat API calls" of
* https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb * https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
* *
* An additional 2 tokens need to be added for metadata after all messages have been counted. * An additional 2 tokens need to be added for metadata after all messages have been counted.
* *
* @param {*} message * @param {*} message
*/ */
getTokenCountForMessage(message) { getTokenCountForMessage(message) {
let tokensPerMessage; let tokensPerMessage;
let nameAdjustment; let nameAdjustment;
@ -534,7 +545,7 @@ class BaseClient {
const numTokens = this.getTokenCount(value); const numTokens = this.getTokenCount(value);
// Adjust by `nameAdjustment` tokens if the property key is 'name' // Adjust by `nameAdjustment` tokens if the property key is 'name'
const adjustment = (key === 'name') ? nameAdjustment : 0; const adjustment = key === 'name' ? nameAdjustment : 0;
return numTokens + adjustment; return numTokens + adjustment;
}); });
@ -547,4 +558,4 @@ class BaseClient {
} }
} }
module.exports = BaseClient; module.exports = BaseClient;

View file

@ -1,6 +1,9 @@
const crypto = require('crypto'); const crypto = require('crypto');
const Keyv = require('keyv'); const Keyv = require('keyv');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('@dqbd/tiktoken'); const {
encoding_for_model: encodingForModel,
get_encoding: getEncoding,
} = require('@dqbd/tiktoken');
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source'); const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
const { Agent, ProxyAgent } = require('undici'); const { Agent, ProxyAgent } = require('undici');
const BaseClient = require('./BaseClient'); const BaseClient = require('./BaseClient');
@ -9,11 +12,7 @@ const CHATGPT_MODEL = 'gpt-3.5-turbo';
const tokenizersCache = {}; const tokenizersCache = {};
class ChatGPTClient extends BaseClient { class ChatGPTClient extends BaseClient {
constructor( constructor(apiKey, options = {}, cacheOptions = {}) {
apiKey,
options = {},
cacheOptions = {},
) {
super(apiKey, options, cacheOptions); super(apiKey, options, cacheOptions);
cacheOptions.namespace = cacheOptions.namespace || 'chatgpt'; cacheOptions.namespace = cacheOptions.namespace || 'chatgpt';
@ -49,13 +48,16 @@ class ChatGPTClient extends BaseClient {
model: modelOptions.model || CHATGPT_MODEL, model: modelOptions.model || CHATGPT_MODEL,
temperature: typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature, temperature: typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature,
top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p, top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p,
presence_penalty: typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty, presence_penalty:
typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty,
stop: modelOptions.stop, stop: modelOptions.stop,
}; };
this.isChatGptModel = this.modelOptions.model.startsWith('gpt-'); this.isChatGptModel = this.modelOptions.model.startsWith('gpt-');
const { isChatGptModel } = this; const { isChatGptModel } = this;
this.isUnofficialChatGptModel = this.modelOptions.model.startsWith('text-chat') || this.modelOptions.model.startsWith('text-davinci-002-render'); this.isUnofficialChatGptModel =
this.modelOptions.model.startsWith('text-chat') ||
this.modelOptions.model.startsWith('text-davinci-002-render');
const { isUnofficialChatGptModel } = this; const { isUnofficialChatGptModel } = this;
// Davinci models have a max context length of 4097 tokens. // Davinci models have a max context length of 4097 tokens.
@ -64,10 +66,15 @@ class ChatGPTClient extends BaseClient {
// The max prompt tokens is determined by the max context tokens minus the max response tokens. // The max prompt tokens is determined by the max context tokens minus the max response tokens.
// Earlier messages will be dropped until the prompt is within the limit. // Earlier messages will be dropped until the prompt is within the limit.
this.maxResponseTokens = this.modelOptions.max_tokens || 1024; this.maxResponseTokens = this.modelOptions.max_tokens || 1024;
this.maxPromptTokens = this.options.maxPromptTokens || (this.maxContextTokens - this.maxResponseTokens); this.maxPromptTokens =
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) { if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
throw new Error(`maxPromptTokens + max_tokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${this.maxPromptTokens + this.maxResponseTokens}) must be less than or equal to maxContextTokens (${this.maxContextTokens})`); throw new Error(
`maxPromptTokens + max_tokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
this.maxPromptTokens + this.maxResponseTokens
}) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
);
} }
this.userLabel = this.options.userLabel || 'User'; this.userLabel = this.options.userLabel || 'User';
@ -249,13 +256,10 @@ class ChatGPTClient extends BaseClient {
} }
}); });
} }
const response = await fetch( const response = await fetch(url, {
url, ...opts,
{ signal: abortController.signal,
...opts, });
signal: abortController.signal,
},
);
if (response.status !== 200) { if (response.status !== 200) {
const body = await response.text(); const body = await response.text();
const error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`); const error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`);
@ -299,10 +303,7 @@ ${botMessage.message}
.trim(); .trim();
} }
async sendMessage( async sendMessage(message, opts = {}) {
message,
opts = {},
) {
if (opts.clientOptions && typeof opts.clientOptions === 'object') { if (opts.clientOptions && typeof opts.clientOptions === 'object') {
this.setOptions(opts.clientOptions); this.setOptions(opts.clientOptions);
} }
@ -310,9 +311,10 @@ ${botMessage.message}
const conversationId = opts.conversationId || crypto.randomUUID(); const conversationId = opts.conversationId || crypto.randomUUID();
const parentMessageId = opts.parentMessageId || crypto.randomUUID(); const parentMessageId = opts.parentMessageId || crypto.randomUUID();
let conversation = typeof opts.conversation === 'object' let conversation =
? opts.conversation typeof opts.conversation === 'object'
: await this.conversationsCache.get(conversationId); ? opts.conversation
: await this.conversationsCache.get(conversationId);
let isNewConversation = false; let isNewConversation = false;
if (!conversation) { if (!conversation) {
@ -357,7 +359,9 @@ ${botMessage.message}
if (progressMessage === '[DONE]') { if (progressMessage === '[DONE]') {
return; return;
} }
const token = this.isChatGptModel ? progressMessage.choices[0].delta.content : progressMessage.choices[0].text; const token = this.isChatGptModel
? progressMessage.choices[0].delta.content
: progressMessage.choices[0].text;
// first event's delta content is always undefined // first event's delta content is always undefined
if (!token) { if (!token) {
return; return;
@ -437,10 +441,11 @@ ${botMessage.message}
} }
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`; promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
} else { } else {
const currentDateString = new Date().toLocaleDateString( const currentDateString = new Date().toLocaleDateString('en-us', {
'en-us', year: 'numeric',
{ year: 'numeric', month: 'long', day: 'numeric' }, month: 'long',
); day: 'numeric',
});
promptPrefix = `${this.startToken}Instructions:\nYou are ChatGPT, a large language model trained by OpenAI. Respond conversationally.\nCurrent date: ${currentDateString}${this.endToken}\n\n`; promptPrefix = `${this.startToken}Instructions:\nYou are ChatGPT, a large language model trained by OpenAI. Respond conversationally.\nCurrent date: ${currentDateString}${this.endToken}\n\n`;
} }
@ -459,7 +464,9 @@ ${botMessage.message}
let currentTokenCount; let currentTokenCount;
if (isChatGptModel) { if (isChatGptModel) {
currentTokenCount = this.getTokenCountForMessage(instructionsPayload) + this.getTokenCountForMessage(messagePayload); currentTokenCount =
this.getTokenCountForMessage(instructionsPayload) +
this.getTokenCountForMessage(messagePayload);
} else { } else {
currentTokenCount = this.getTokenCount(`${promptPrefix}${promptSuffix}`); currentTokenCount = this.getTokenCount(`${promptPrefix}${promptSuffix}`);
} }
@ -473,8 +480,13 @@ ${botMessage.message}
const buildPromptBody = async () => { const buildPromptBody = async () => {
if (currentTokenCount < maxTokenCount && orderedMessages.length > 0) { if (currentTokenCount < maxTokenCount && orderedMessages.length > 0) {
const message = orderedMessages.pop(); const message = orderedMessages.pop();
const roleLabel = message?.isCreatedByUser || message?.role?.toLowerCase() === 'user' ? this.userLabel : this.chatGptLabel; const roleLabel =
const messageString = `${this.startToken}${roleLabel}:\n${message?.text ?? message?.message}${this.endToken}\n`; message?.isCreatedByUser || message?.role?.toLowerCase() === 'user'
? this.userLabel
: this.chatGptLabel;
const messageString = `${this.startToken}${roleLabel}:\n${
message?.text ?? message?.message
}${this.endToken}\n`;
let newPromptBody; let newPromptBody;
if (promptBody || isChatGptModel) { if (promptBody || isChatGptModel) {
newPromptBody = `${messageString}${promptBody}`; newPromptBody = `${messageString}${promptBody}`;
@ -496,12 +508,14 @@ ${botMessage.message}
return false; return false;
} }
// This is the first message, so we can't add it. Just throw an error. // This is the first message, so we can't add it. Just throw an error.
throw new Error(`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`); throw new Error(
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
);
} }
promptBody = newPromptBody; promptBody = newPromptBody;
currentTokenCount = newTokenCount; currentTokenCount = newTokenCount;
// wait for next tick to avoid blocking the event loop // wait for next tick to avoid blocking the event loop
await new Promise(resolve => setImmediate(resolve)); await new Promise((resolve) => setImmediate(resolve));
return buildPromptBody(); return buildPromptBody();
} }
return true; return true;
@ -517,7 +531,10 @@ ${botMessage.message}
} }
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response. // Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
this.modelOptions.max_tokens = Math.min(this.maxContextTokens - currentTokenCount, this.maxResponseTokens); this.modelOptions.max_tokens = Math.min(
this.maxContextTokens - currentTokenCount,
this.maxResponseTokens,
);
if (this.options.debug) { if (this.options.debug) {
console.debug(`Prompt : ${prompt}`); console.debug(`Prompt : ${prompt}`);
@ -534,13 +551,13 @@ ${botMessage.message}
} }
/** /**
* Algorithm adapted from "6. Counting tokens for chat API calls" of * Algorithm adapted from "6. Counting tokens for chat API calls" of
* https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb * https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
* *
* An additional 2 tokens need to be added for metadata after all messages have been counted. * An additional 2 tokens need to be added for metadata after all messages have been counted.
* *
* @param {*} message * @param {*} message
*/ */
getTokenCountForMessage(message) { getTokenCountForMessage(message) {
let tokensPerMessage; let tokensPerMessage;
let nameAdjustment; let nameAdjustment;
@ -558,7 +575,7 @@ ${botMessage.message}
const numTokens = this.getTokenCount(value); const numTokens = this.getTokenCount(value);
// Adjust by `nameAdjustment` tokens if the property key is 'name' // Adjust by `nameAdjustment` tokens if the property key is 'name'
const adjustment = (key === 'name') ? nameAdjustment : 0; const adjustment = key === 'name' ? nameAdjustment : 0;
return numTokens + adjustment; return numTokens + adjustment;
}); });
@ -567,4 +584,4 @@ ${botMessage.message}
} }
} }
module.exports = ChatGPTClient; module.exports = ChatGPTClient;

View file

@ -3,7 +3,7 @@ const { google } = require('googleapis');
const { Agent, ProxyAgent } = require('undici'); const { Agent, ProxyAgent } = require('undici');
const { const {
encoding_for_model: encodingForModel, encoding_for_model: encodingForModel,
get_encoding: getEncoding get_encoding: getEncoding,
} = require('@dqbd/tiktoken'); } = require('@dqbd/tiktoken');
const tokenizersCache = {}; const tokenizersCache = {};
@ -43,20 +43,20 @@ class GoogleClient extends BaseClient {
// nested options aren't spread properly, so we need to do this manually // nested options aren't spread properly, so we need to do this manually
this.options.modelOptions = { this.options.modelOptions = {
...this.options.modelOptions, ...this.options.modelOptions,
...options.modelOptions ...options.modelOptions,
}; };
delete options.modelOptions; delete options.modelOptions;
// now we can merge options // now we can merge options
this.options = { this.options = {
...this.options, ...this.options,
...options ...options,
}; };
} else { } else {
this.options = options; this.options = options;
} }
this.options.examples = this.options.examples.filter( this.options.examples = this.options.examples.filter(
(obj) => obj.input.content !== '' && obj.output.content !== '' (obj) => obj.input.content !== '' && obj.output.content !== '',
); );
const modelOptions = this.options.modelOptions || {}; const modelOptions = this.options.modelOptions || {};
@ -66,7 +66,7 @@ class GoogleClient extends BaseClient {
model: modelOptions.model || 'chat-bison', model: modelOptions.model || 'chat-bison',
temperature: typeof modelOptions.temperature === 'undefined' ? 0.2 : modelOptions.temperature, // 0 - 1, 0.2 is recommended temperature: typeof modelOptions.temperature === 'undefined' ? 0.2 : modelOptions.temperature, // 0 - 1, 0.2 is recommended
topP: typeof modelOptions.topP === 'undefined' ? 0.95 : modelOptions.topP, // 0 - 1, default: 0.95 topP: typeof modelOptions.topP === 'undefined' ? 0.95 : modelOptions.topP, // 0 - 1, default: 0.95
topK: typeof modelOptions.topK === 'undefined' ? 40 : modelOptions.topK // 1-40, default: 40 topK: typeof modelOptions.topK === 'undefined' ? 40 : modelOptions.topK, // 1-40, default: 40
// stop: modelOptions.stop // no stop method for now // stop: modelOptions.stop // no stop method for now
}; };
@ -86,7 +86,7 @@ class GoogleClient extends BaseClient {
throw new Error( throw new Error(
`maxPromptTokens + maxOutputTokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${ `maxPromptTokens + maxOutputTokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
this.maxPromptTokens + this.maxResponseTokens this.maxPromptTokens + this.maxResponseTokens
}) must be less than or equal to maxContextTokens (${this.maxContextTokens})` }) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
); );
} }
@ -105,7 +105,7 @@ class GoogleClient extends BaseClient {
this.endToken = '<|im_end|>'; this.endToken = '<|im_end|>';
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, { this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
'<|im_start|>': 100264, '<|im_start|>': 100264,
'<|im_end|>': 100265 '<|im_end|>': 100265,
}); });
} else { } else {
// Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting // Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting
@ -143,7 +143,7 @@ class GoogleClient extends BaseClient {
getMessageMapMethod() { getMessageMapMethod() {
return ((message) => ({ return ((message) => ({
author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel), author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
content: message?.content ?? message.text content: message?.content ?? message.text,
})).bind(this); })).bind(this);
} }
@ -153,9 +153,9 @@ class GoogleClient extends BaseClient {
instances: [ instances: [
{ {
messages: formattedMessages, messages: formattedMessages,
} },
], ],
parameters: this.options.modelOptions parameters: this.options.modelOptions,
}; };
if (this.options.promptPrefix) { if (this.options.promptPrefix) {
@ -170,8 +170,8 @@ class GoogleClient extends BaseClient {
if (this.isTextModel) { if (this.isTextModel) {
payload.instances = [ payload.instances = [
{ {
prompt: messages[messages.length -1].content prompt: messages[messages.length - 1].content,
} },
]; ];
} }
@ -199,9 +199,9 @@ class GoogleClient extends BaseClient {
method: 'POST', method: 'POST',
agent: new Agent({ agent: new Agent({
bodyTimeout: 0, bodyTimeout: 0,
headersTimeout: 0 headersTimeout: 0,
}), }),
signal: abortController.signal signal: abortController.signal,
}; };
if (this.options.proxy) { if (this.options.proxy) {
@ -218,7 +218,7 @@ class GoogleClient extends BaseClient {
return { return {
promptPrefix: this.options.promptPrefix, promptPrefix: this.options.promptPrefix,
modelLabel: this.options.modelLabel, modelLabel: this.options.modelLabel,
...this.modelOptions ...this.modelOptions,
}; };
} }
@ -239,7 +239,7 @@ class GoogleClient extends BaseClient {
''; '';
if (blocked === true) { if (blocked === true) {
reply = `Google blocked a proper response to your message:\n${JSON.stringify( reply = `Google blocked a proper response to your message:\n${JSON.stringify(
result.predictions[0].safetyAttributes result.predictions[0].safetyAttributes,
)}${reply.length > 0 ? `\nAI Response:\n${reply}` : ''}`; )}${reply.length > 0 ? `\nAI Response:\n${reply}` : ''}`;
} }
if (this.options.debug) { if (this.options.debug) {

View file

@ -1,6 +1,9 @@
const BaseClient = require('./BaseClient'); const BaseClient = require('./BaseClient');
const ChatGPTClient = require('./ChatGPTClient'); const ChatGPTClient = require('./ChatGPTClient');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('@dqbd/tiktoken'); const {
encoding_for_model: encodingForModel,
get_encoding: getEncoding,
} = require('@dqbd/tiktoken');
const { maxTokensMap, genAzureChatCompletion } = require('../../utils'); const { maxTokensMap, genAzureChatCompletion } = require('../../utils');
const tokenizersCache = {}; const tokenizersCache = {};
@ -12,7 +15,9 @@ class OpenAIClient extends BaseClient {
this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this); this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this);
this.getCompletion = this.ChatGPTClient.getCompletion.bind(this); this.getCompletion = this.ChatGPTClient.getCompletion.bind(this);
this.sender = options.sender ?? 'ChatGPT'; this.sender = options.sender ?? 'ChatGPT';
this.contextStrategy = options.contextStrategy ? options.contextStrategy.toLowerCase() : 'discard'; this.contextStrategy = options.contextStrategy
? options.contextStrategy.toLowerCase()
: 'discard';
this.shouldRefineContext = this.contextStrategy === 'refine'; this.shouldRefineContext = this.contextStrategy === 'refine';
this.azure = options.azure || false; this.azure = options.azure || false;
if (this.azure) { if (this.azure) {
@ -45,27 +50,39 @@ class OpenAIClient extends BaseClient {
this.modelOptions = { this.modelOptions = {
...modelOptions, ...modelOptions,
model: modelOptions.model || 'gpt-3.5-turbo', model: modelOptions.model || 'gpt-3.5-turbo',
temperature: typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature, temperature:
typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature,
top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p, top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p,
presence_penalty: typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty, presence_penalty:
typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty,
stop: modelOptions.stop, stop: modelOptions.stop,
}; };
} }
this.isChatCompletion = this.options.reverseProxyUrl || this.options.localAI || this.modelOptions.model.startsWith('gpt-'); this.isChatCompletion =
this.options.reverseProxyUrl ||
this.options.localAI ||
this.modelOptions.model.startsWith('gpt-');
this.isChatGptModel = this.isChatCompletion; this.isChatGptModel = this.isChatCompletion;
if (this.modelOptions.model === 'text-davinci-003') { if (this.modelOptions.model === 'text-davinci-003') {
this.isChatCompletion = false; this.isChatCompletion = false;
this.isChatGptModel = false; this.isChatGptModel = false;
} }
const { isChatGptModel } = this; const { isChatGptModel } = this;
this.isUnofficialChatGptModel = this.modelOptions.model.startsWith('text-chat') || this.modelOptions.model.startsWith('text-davinci-002-render'); this.isUnofficialChatGptModel =
this.modelOptions.model.startsWith('text-chat') ||
this.modelOptions.model.startsWith('text-davinci-002-render');
this.maxContextTokens = maxTokensMap[this.modelOptions.model] ?? 4095; // 1 less than maximum this.maxContextTokens = maxTokensMap[this.modelOptions.model] ?? 4095; // 1 less than maximum
this.maxResponseTokens = this.modelOptions.max_tokens || 1024; this.maxResponseTokens = this.modelOptions.max_tokens || 1024;
this.maxPromptTokens = this.options.maxPromptTokens || (this.maxContextTokens - this.maxResponseTokens); this.maxPromptTokens =
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) { if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
throw new Error(`maxPromptTokens + max_tokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${this.maxPromptTokens + this.maxResponseTokens}) must be less than or equal to maxContextTokens (${this.maxContextTokens})`); throw new Error(
`maxPromptTokens + max_tokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
this.maxPromptTokens + this.maxResponseTokens
}) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
);
} }
this.userLabel = this.options.userLabel || 'User'; this.userLabel = this.options.userLabel || 'User';
@ -185,7 +202,7 @@ class OpenAIClient extends BaseClient {
return { return {
chatGptLabel: this.options.chatGptLabel, chatGptLabel: this.options.chatGptLabel,
promptPrefix: this.options.promptPrefix, promptPrefix: this.options.promptPrefix,
...this.modelOptions ...this.modelOptions,
}; };
} }
@ -197,9 +214,16 @@ class OpenAIClient extends BaseClient {
}; };
} }
async buildMessages(messages, parentMessageId, { isChatCompletion = false, promptPrefix = null }) { async buildMessages(
messages,
parentMessageId,
{ isChatCompletion = false, promptPrefix = null },
) {
if (!isChatCompletion) { if (!isChatCompletion) {
return await this.buildPrompt(messages, parentMessageId, { isChatGptModel: isChatCompletion, promptPrefix }); return await this.buildPrompt(messages, parentMessageId, {
isChatGptModel: isChatCompletion,
promptPrefix,
});
} }
let payload; let payload;
@ -214,7 +238,7 @@ class OpenAIClient extends BaseClient {
instructions = { instructions = {
role: 'system', role: 'system',
name: 'instructions', name: 'instructions',
content: promptPrefix content: promptPrefix,
}; };
if (this.contextStrategy) { if (this.contextStrategy) {
@ -236,7 +260,8 @@ class OpenAIClient extends BaseClient {
} }
if (this.contextStrategy) { if (this.contextStrategy) {
formattedMessage.tokenCount = message.tokenCount ?? this.getTokenCountForMessage(formattedMessage); formattedMessage.tokenCount =
message.tokenCount ?? this.getTokenCountForMessage(formattedMessage);
} }
return formattedMessage; return formattedMessage;
@ -244,8 +269,11 @@ class OpenAIClient extends BaseClient {
// TODO: need to handle interleaving instructions better // TODO: need to handle interleaving instructions better
if (this.contextStrategy) { if (this.contextStrategy) {
({ payload, tokenCountMap, promptTokens, messages } = ({ payload, tokenCountMap, promptTokens, messages } = await this.handleContextStrategy({
await this.handleContextStrategy({ instructions, orderedMessages, formattedMessages })); instructions,
orderedMessages,
formattedMessages,
}));
} }
const result = { const result = {
@ -272,8 +300,9 @@ class OpenAIClient extends BaseClient {
if (progressMessage === '[DONE]') { if (progressMessage === '[DONE]') {
return; return;
} }
const token = const token = this.isChatCompletion
this.isChatCompletion ? progressMessage.choices?.[0]?.delta?.content : progressMessage.choices?.[0]?.text; ? progressMessage.choices?.[0]?.delta?.content
: progressMessage.choices?.[0]?.text;
// first event's delta content is always undefined // first event's delta content is always undefined
if (!token) { if (!token) {
return; return;

View file

@ -5,11 +5,7 @@ const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents/')
const { loadTools } = require('./tools/util'); const { loadTools } = require('./tools/util');
const { SelfReflectionTool } = require('./tools/'); const { SelfReflectionTool } = require('./tools/');
const { HumanChatMessage, AIChatMessage } = require('langchain/schema'); const { HumanChatMessage, AIChatMessage } = require('langchain/schema');
const { const { instructions, imageInstructions, errorInstructions } = require('./prompts/instructions');
instructions,
imageInstructions,
errorInstructions,
} = require('./prompts/instructions');
class PluginsClient extends OpenAIClient { class PluginsClient extends OpenAIClient {
constructor(apiKey, options = {}) { constructor(apiKey, options = {}) {
@ -28,11 +24,13 @@ class PluginsClient extends OpenAIClient {
if (actions[0]?.action && this.functionsAgent) { if (actions[0]?.action && this.functionsAgent) {
actions = actions.map((step) => ({ actions = actions.map((step) => ({
log: `Action: ${step.action?.tool || ''}\nInput: ${JSON.stringify(step.action?.toolInput) || ''}\nObservation: ${step.observation}` log: `Action: ${step.action?.tool || ''}\nInput: ${
JSON.stringify(step.action?.toolInput) || ''
}\nObservation: ${step.observation}`,
})); }));
} else if (actions[0]?.action) { } else if (actions[0]?.action) {
actions = actions.map((step) => ({ actions = actions.map((step) => ({
log: `${step.action.log}\nObservation: ${step.observation}` log: `${step.action.log}\nObservation: ${step.observation}`,
})); }));
} }
@ -136,10 +134,10 @@ Only respond with your conversational reply to the following User Message:
const prefixMap = { const prefixMap = {
'gpt-4': 'gpt-4-0613', 'gpt-4': 'gpt-4-0613',
'gpt-4-32k': 'gpt-4-32k-0613', 'gpt-4-32k': 'gpt-4-32k-0613',
'gpt-3.5-turbo': 'gpt-3.5-turbo-0613' 'gpt-3.5-turbo': 'gpt-3.5-turbo-0613',
}; };
const prefix = Object.keys(prefixMap).find(key => input.startsWith(key)); const prefix = Object.keys(prefixMap).find((key) => input.startsWith(key));
return prefix ? prefixMap[prefix] : 'gpt-3.5-turbo-0613'; return prefix ? prefixMap[prefix] : 'gpt-3.5-turbo-0613';
} }
@ -173,7 +171,7 @@ Only respond with your conversational reply to the following User Message:
async initialize({ user, message, onAgentAction, onChainEnd, signal }) { async initialize({ user, message, onAgentAction, onChainEnd, signal }) {
const modelOptions = { const modelOptions = {
modelName: this.agentOptions.model, modelName: this.agentOptions.model,
temperature: this.agentOptions.temperature temperature: this.agentOptions.temperature,
}; };
const configOptions = {}; const configOptions = {};
@ -194,8 +192,8 @@ Only respond with your conversational reply to the following User Message:
tools: this.options.tools, tools: this.options.tools,
functions: this.functionsAgent, functions: this.functionsAgent,
options: { options: {
openAIApiKey: this.openAIApiKey openAIApiKey: this.openAIApiKey,
} },
}); });
// load tools // load tools
for (const tool of this.options.tools) { for (const tool of this.options.tools) {
@ -235,10 +233,13 @@ Only respond with your conversational reply to the following User Message:
}; };
// Map Messages to Langchain format // Map Messages to Langchain format
const pastMessages = this.currentMessages.slice(0, -1).map( const pastMessages = this.currentMessages
msg => msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user' .slice(0, -1)
? new HumanChatMessage(msg.text) .map((msg) =>
: new AIChatMessage(msg.text)); msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user'
? new HumanChatMessage(msg.text)
: new AIChatMessage(msg.text),
);
// initialize agent // initialize agent
const initializer = this.functionsAgent ? initializeFunctionsAgent : initializeCustomAgent; const initializer = this.functionsAgent ? initializeFunctionsAgent : initializeCustomAgent;
@ -258,8 +259,8 @@ Only respond with your conversational reply to the following User Message:
if (typeof onChainEnd === 'function') { if (typeof onChainEnd === 'function') {
onChainEnd(action); onChainEnd(action);
} }
} },
}) }),
}); });
if (this.options.debug) { if (this.options.debug) {
@ -304,7 +305,7 @@ Only respond with your conversational reply to the following User Message:
return; return;
} }
intermediateSteps.forEach(step => { intermediateSteps.forEach((step) => {
const { observation } = step; const { observation } = step;
if (!observation || !observation.includes('![')) { if (!observation || !observation.includes('![')) {
return; return;
@ -346,7 +347,12 @@ Only respond with your conversational reply to the following User Message:
this.currentMessages.push(userMessage); this.currentMessages.push(userMessage);
let { prompt: payload, tokenCountMap, promptTokens, messages } = await this.buildMessages( let {
prompt: payload,
tokenCountMap,
promptTokens,
messages,
} = await this.buildMessages(
this.currentMessages, this.currentMessages,
userMessage.messageId, userMessage.messageId,
this.getBuildMessagesOptions({ this.getBuildMessagesOptions({
@ -356,7 +362,7 @@ Only respond with your conversational reply to the following User Message:
); );
if (tokenCountMap) { if (tokenCountMap) {
console.dir(tokenCountMap, { depth: null }) console.dir(tokenCountMap, { depth: null });
if (tokenCountMap[userMessage.messageId]) { if (tokenCountMap[userMessage.messageId]) {
userMessage.tokenCount = tokenCountMap[userMessage.messageId]; userMessage.tokenCount = tokenCountMap[userMessage.messageId];
console.log('userMessage.tokenCount', userMessage.tokenCount); console.log('userMessage.tokenCount', userMessage.tokenCount);
@ -389,7 +395,7 @@ Only respond with your conversational reply to the following User Message:
message, message,
onAgentAction, onAgentAction,
onChainEnd, onChainEnd,
signal: this.abortController.signal signal: this.abortController.signal,
}); });
await this.executorCall(message, this.abortController.signal); await this.executorCall(message, this.abortController.signal);
@ -448,12 +454,12 @@ Only respond with your conversational reply to the following User Message:
const instructionsPayload = { const instructionsPayload = {
role: 'system', role: 'system',
name: 'instructions', name: 'instructions',
content: promptPrefix content: promptPrefix,
}; };
const messagePayload = { const messagePayload = {
role: 'system', role: 'system',
content: promptSuffix content: promptSuffix,
}; };
if (this.isGpt3) { if (this.isGpt3) {
@ -468,8 +474,8 @@ Only respond with your conversational reply to the following User Message:
} }
let currentTokenCount = let currentTokenCount =
this.getTokenCountForMessage(instructionsPayload) + this.getTokenCountForMessage(instructionsPayload) +
this.getTokenCountForMessage(messagePayload); this.getTokenCountForMessage(messagePayload);
let promptBody = ''; let promptBody = '';
const maxTokenCount = this.maxPromptTokens; const maxTokenCount = this.maxPromptTokens;
@ -492,7 +498,7 @@ Only respond with your conversational reply to the following User Message:
} }
// This is the first message, so we can't add it. Just throw an error. // This is the first message, so we can't add it. Just throw an error.
throw new Error( throw new Error(
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.` `Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
); );
} }
promptBody = newPromptBody; promptBody = newPromptBody;
@ -519,7 +525,7 @@ Only respond with your conversational reply to the following User Message:
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response. // Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
this.modelOptions.max_tokens = Math.min( this.modelOptions.max_tokens = Math.min(
this.maxContextTokens - currentTokenCount, this.maxContextTokens - currentTokenCount,
this.maxResponseTokens this.maxResponseTokens,
); );
if (this.isGpt3) { if (this.isGpt3) {

View file

@ -8,7 +8,7 @@ class CustomAgent extends ZeroShotAgent {
} }
_stop() { _stop() {
return [`\nObservation:`, `\nObservation 1:`]; return ['\nObservation:', '\nObservation 1:'];
} }
static createPrompt(tools, opts = {}) { static createPrompt(tools, opts = {}) {
@ -32,17 +32,17 @@ class CustomAgent extends ZeroShotAgent {
.join('\n'); .join('\n');
const toolNames = tools.map((tool) => tool.name); const toolNames = tools.map((tool) => tool.name);
const formatInstructions = (0, renderTemplate)(instructions, 'f-string', { const formatInstructions = (0, renderTemplate)(instructions, 'f-string', {
tool_names: toolNames tool_names: toolNames,
}); });
const template = [ const template = [
`Date: ${currentDateString}\n${prefix}`, `Date: ${currentDateString}\n${prefix}`,
toolStrings, toolStrings,
formatInstructions, formatInstructions,
suffix suffix,
].join('\n\n'); ].join('\n\n');
return new PromptTemplate({ return new PromptTemplate({
template, template,
inputVariables inputVariables,
}); });
} }
} }

View file

@ -6,7 +6,7 @@ const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
const { const {
ChatPromptTemplate, ChatPromptTemplate,
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
HumanMessagePromptTemplate HumanMessagePromptTemplate,
} = require('langchain/prompts'); } = require('langchain/prompts');
const initializeCustomAgent = async ({ const initializeCustomAgent = async ({
@ -22,7 +22,7 @@ const initializeCustomAgent = async ({
new SystemMessagePromptTemplate(prompt), new SystemMessagePromptTemplate(prompt),
HumanMessagePromptTemplate.fromTemplate(`{chat_history} HumanMessagePromptTemplate.fromTemplate(`{chat_history}
Query: {input} Query: {input}
{agent_scratchpad}`) {agent_scratchpad}`),
]); ]);
const outputParser = new CustomOutputParser({ tools }); const outputParser = new CustomOutputParser({ tools });
@ -34,18 +34,18 @@ Query: {input}
humanPrefix: 'User', humanPrefix: 'User',
aiPrefix: 'Assistant', aiPrefix: 'Assistant',
inputKey: 'input', inputKey: 'input',
outputKey: 'output' outputKey: 'output',
}); });
const llmChain = new LLMChain({ const llmChain = new LLMChain({
prompt: chatPrompt, prompt: chatPrompt,
llm: model llm: model,
}); });
const agent = new CustomAgent({ const agent = new CustomAgent({
llmChain, llmChain,
outputParser, outputParser,
allowedTools: tools.map((tool) => tool.name) allowedTools: tools.map((tool) => tool.name),
}); });
return AgentExecutor.fromAgentAndTools({ agent, tools, memory, ...rest }); return AgentExecutor.fromAgentAndTools({ agent, tools, memory, ...rest });

View file

@ -57,7 +57,7 @@ class CustomOutputParser extends ZeroShotAgentOutputParser {
const output = text.substring(finalMatch.index + finalMatch[0].length).trim(); const output = text.substring(finalMatch.index + finalMatch[0].length).trim();
return { return {
returnValues: { output }, returnValues: { output },
log: text log: text,
}; };
} }
@ -66,7 +66,7 @@ class CustomOutputParser extends ZeroShotAgentOutputParser {
if (!match) { if (!match) {
console.log( console.log(
'\n\n<----------------------HIT NO MATCH PARSING ERROR---------------------->\n\n', '\n\n<----------------------HIT NO MATCH PARSING ERROR---------------------->\n\n',
match match,
); );
const thoughts = text.replace(/[tT]hought:/, '').split('\n'); const thoughts = text.replace(/[tT]hought:/, '').split('\n');
// return { // return {
@ -77,7 +77,7 @@ class CustomOutputParser extends ZeroShotAgentOutputParser {
return { return {
returnValues: { output: thoughts[0] }, returnValues: { output: thoughts[0] },
log: thoughts.slice(1).join('\n') log: thoughts.slice(1).join('\n'),
}; };
} }
@ -86,12 +86,12 @@ class CustomOutputParser extends ZeroShotAgentOutputParser {
if (match && selectedTool === 'n/a') { if (match && selectedTool === 'n/a') {
console.log( console.log(
'\n\n<----------------------HIT N/A PARSING ERROR---------------------->\n\n', '\n\n<----------------------HIT N/A PARSING ERROR---------------------->\n\n',
match match,
); );
return { return {
tool: 'self-reflection', tool: 'self-reflection',
toolInput: match[2]?.trim().replace(/^"+|"+$/g, '') ?? '', toolInput: match[2]?.trim().replace(/^"+|"+$/g, '') ?? '',
log: text log: text,
}; };
} }
@ -99,7 +99,7 @@ class CustomOutputParser extends ZeroShotAgentOutputParser {
if (match && !toolIsValid) { if (match && !toolIsValid) {
console.log( console.log(
'\n\n<----------------Tool invalid: Re-assigning Selected Tool---------------->\n\n', '\n\n<----------------Tool invalid: Re-assigning Selected Tool---------------->\n\n',
match match,
); );
selectedTool = this.getValidTool(selectedTool); selectedTool = this.getValidTool(selectedTool);
} }
@ -107,7 +107,7 @@ class CustomOutputParser extends ZeroShotAgentOutputParser {
if (match && !selectedTool) { if (match && !selectedTool) {
console.log( console.log(
'\n\n<----------------------HIT INVALID TOOL PARSING ERROR---------------------->\n\n', '\n\n<----------------------HIT INVALID TOOL PARSING ERROR---------------------->\n\n',
match match,
); );
selectedTool = 'self-reflection'; selectedTool = 'self-reflection';
} }
@ -115,7 +115,7 @@ class CustomOutputParser extends ZeroShotAgentOutputParser {
if (match && !match[2]) { if (match && !match[2]) {
console.log( console.log(
'\n\n<----------------------HIT NO ACTION INPUT PARSING ERROR---------------------->\n\n', '\n\n<----------------------HIT NO ACTION INPUT PARSING ERROR---------------------->\n\n',
match match,
); );
// In case there is no action input, let's double-check if there is an action input in 'text' variable // In case there is no action input, let's double-check if there is an action input in 'text' variable
@ -125,7 +125,7 @@ class CustomOutputParser extends ZeroShotAgentOutputParser {
return { return {
tool: selectedTool, tool: selectedTool,
toolInput: actionInputMatch[1].trim(), toolInput: actionInputMatch[1].trim(),
log: text log: text,
}; };
} }
@ -133,7 +133,7 @@ class CustomOutputParser extends ZeroShotAgentOutputParser {
return { return {
tool: selectedTool, tool: selectedTool,
toolInput: thoughtMatch[1].trim(), toolInput: thoughtMatch[1].trim(),
log: text log: text,
}; };
} }
} }
@ -158,12 +158,12 @@ class CustomOutputParser extends ZeroShotAgentOutputParser {
if (action && actionInputMatch) { if (action && actionInputMatch) {
console.log( console.log(
'\n\n<------Matched Action Input in Long Parsing Error------>\n\n', '\n\n<------Matched Action Input in Long Parsing Error------>\n\n',
actionInputMatch actionInputMatch,
); );
return { return {
tool: action, tool: action,
toolInput: actionInputMatch[1].trim().replaceAll('"', ''), toolInput: actionInputMatch[1].trim().replaceAll('"', ''),
log: text log: text,
}; };
} }
@ -180,7 +180,7 @@ class CustomOutputParser extends ZeroShotAgentOutputParser {
const returnValues = { const returnValues = {
tool: action, tool: action,
toolInput: input, toolInput: input,
log: thought || inputText log: thought || inputText,
}; };
const inputMatch = this.actionValues.exec(returnValues.log); //new const inputMatch = this.actionValues.exec(returnValues.log); //new
@ -197,7 +197,7 @@ class CustomOutputParser extends ZeroShotAgentOutputParser {
return { return {
tool: 'self-reflection', tool: 'self-reflection',
toolInput: 'Hypothetical actions: \n"' + text + '"\n', toolInput: 'Hypothetical actions: \n"' + text + '"\n',
log: 'Thought: I need to look at my hypothetical actions and try one' log: 'Thought: I need to look at my hypothetical actions and try one',
}; };
} }
@ -210,7 +210,7 @@ class CustomOutputParser extends ZeroShotAgentOutputParser {
return { return {
tool: selectedTool, tool: selectedTool,
toolInput: match[2]?.trim()?.replace(/^"+|"+$/g, '') ?? '', toolInput: match[2]?.trim()?.replace(/^"+|"+$/g, '') ?? '',
log: text log: text,
}; };
} }
} }

View file

@ -5,9 +5,9 @@ const {
ChatPromptTemplate, ChatPromptTemplate,
MessagesPlaceholder, MessagesPlaceholder,
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
HumanMessagePromptTemplate HumanMessagePromptTemplate,
} = require('langchain/prompts'); } = require('langchain/prompts');
const PREFIX = `You are a helpful AI assistant.`; const PREFIX = 'You are a helpful AI assistant.';
function parseOutput(message) { function parseOutput(message) {
if (message.additional_kwargs.function_call) { if (message.additional_kwargs.function_call) {
@ -15,7 +15,7 @@ function parseOutput(message) {
return { return {
tool: function_call.name, tool: function_call.name,
toolInput: function_call.arguments ? JSON.parse(function_call.arguments) : {}, toolInput: function_call.arguments ? JSON.parse(function_call.arguments) : {},
log: message.text log: message.text,
}; };
} else { } else {
return { returnValues: { output: message.text }, log: message.text }; return { returnValues: { output: message.text }, log: message.text };
@ -52,7 +52,7 @@ class FunctionsAgent extends Agent {
return ChatPromptTemplate.fromPromptMessages([ return ChatPromptTemplate.fromPromptMessages([
SystemMessagePromptTemplate.fromTemplate(`Date: ${currentDateString}\n${prefix}`), SystemMessagePromptTemplate.fromTemplate(`Date: ${currentDateString}\n${prefix}`),
new MessagesPlaceholder('chat_history'), new MessagesPlaceholder('chat_history'),
HumanMessagePromptTemplate.fromTemplate(`Query: {input}`), HumanMessagePromptTemplate.fromTemplate('Query: {input}'),
new MessagesPlaceholder('agent_scratchpad'), new MessagesPlaceholder('agent_scratchpad'),
]); ]);
} }
@ -63,12 +63,12 @@ class FunctionsAgent extends Agent {
const chain = new LLMChain({ const chain = new LLMChain({
prompt, prompt,
llm, llm,
callbacks: args?.callbacks callbacks: args?.callbacks,
}); });
return new FunctionsAgent({ return new FunctionsAgent({
llmChain: chain, llmChain: chain,
allowedTools: tools.map((t) => t.name), allowedTools: tools.map((t) => t.name),
tools tools,
}); });
} }
@ -77,10 +77,10 @@ class FunctionsAgent extends Agent {
new AIChatMessage('', { new AIChatMessage('', {
function_call: { function_call: {
name: action.tool, name: action.tool,
arguments: JSON.stringify(action.toolInput) arguments: JSON.stringify(action.toolInput),
} },
}), }),
new FunctionChatMessage(observation, action.tool) new FunctionChatMessage(observation, action.tool),
]); ]);
} }
@ -96,7 +96,7 @@ class FunctionsAgent extends Agent {
const llm = this.llmChain.llm; const llm = this.llmChain.llm;
const valuesForPrompt = Object.assign({}, newInputs); const valuesForPrompt = Object.assign({}, newInputs);
const valuesForLLM = { const valuesForLLM = {
tools: this.tools tools: this.tools,
}; };
for (let i = 0; i < this.llmChain.llm.callKeys.length; i++) { for (let i = 0; i < this.llmChain.llm.callKeys.length; i++) {
const key = this.llmChain.llm.callKeys[i]; const key = this.llmChain.llm.callKeys[i];
@ -110,7 +110,7 @@ class FunctionsAgent extends Agent {
const message = await llm.predictMessages( const message = await llm.predictMessages(
promptValue.toChatMessages(), promptValue.toChatMessages(),
valuesForLLM, valuesForLLM,
callbackManager callbackManager,
); );
console.log('message', message); console.log('message', message);
return parseOutput(message); return parseOutput(message);

View file

@ -8,7 +8,6 @@ const initializeFunctionsAgent = async ({
// currentDateString, // currentDateString,
...rest ...rest
}) => { }) => {
const memory = new BufferMemory({ const memory = new BufferMemory({
chatHistory: new ChatMessageHistory(pastMessages), chatHistory: new ChatMessageHistory(pastMessages),
memoryKey: 'chat_history', memoryKey: 'chat_history',
@ -19,17 +18,11 @@ const initializeFunctionsAgent = async ({
returnMessages: true, returnMessages: true,
}); });
return await initializeAgentExecutorWithOptions( return await initializeAgentExecutorWithOptions(tools, model, {
tools, agentType: 'openai-functions',
model, memory,
{ ...rest,
agentType: 'openai-functions', });
memory,
...rest,
}
);
}; };
module.exports = initializeFunctionsAgent; module.exports = initializeFunctionsAgent;

View file

@ -3,5 +3,5 @@ const initializeFunctionsAgent = require('./Functions/initializeFunctionsAgent')
module.exports = { module.exports = {
initializeCustomAgent, initializeCustomAgent,
initializeFunctionsAgent initializeFunctionsAgent,
}; };

View file

@ -13,5 +13,5 @@ module.exports = {
GoogleClient, GoogleClient,
TextStream, TextStream,
AnthropicClient, AnthropicClient,
...toolUtils ...toolUtils,
}; };

View file

@ -1,6 +1,10 @@
module.exports = { module.exports = {
instructions: `Remember, all your responses MUST be in the format described. Do not respond unless it's in the format described, using the structure of Action, Action Input, etc.`, instructions:
errorInstructions: `\nYou encountered an error in attempting a response. The user is not aware of the error so you shouldn't mention it.\nReview the actions taken carefully in case there is a partial or complete answer within them.\nError Message:`, 'Remember, all your responses MUST be in the format described. Do not respond unless it\'s in the format described, using the structure of Action, Action Input, etc.',
imageInstructions: 'You must include the exact image paths from above, formatted in Markdown syntax: ![alt-text](URL)', errorInstructions:
completionInstructions: `Instructions:\nYou are ChatGPT, a large language model trained by OpenAI. Respond conversationally.\nCurrent date:`, '\nYou encountered an error in attempting a response. The user is not aware of the error so you shouldn\'t mention it.\nReview the actions taken carefully in case there is a partial or complete answer within them.\nError Message:',
imageInstructions:
'You must include the exact image paths from above, formatted in Markdown syntax: ![alt-text](URL)',
completionInstructions:
'Instructions:\nYou are ChatGPT, a large language model trained by OpenAI. Respond conversationally.\nCurrent date:',
}; };

View file

@ -16,9 +16,9 @@ REFINED CONVERSATION SUMMARY:`;
const refinePrompt = new PromptTemplate({ const refinePrompt = new PromptTemplate({
template: refinePromptTemplate, template: refinePromptTemplate,
inputVariables: ["existing_answer", "text"], inputVariables: ['existing_answer', 'text'],
}); });
module.exports = { module.exports = {
refinePrompt, refinePrompt,
}; };

View file

@ -10,7 +10,7 @@ jest.mock('../../../models', () => {
getMessages: jest.fn(), getMessages: jest.fn(),
saveMessage: jest.fn(), saveMessage: jest.fn(),
updateMessage: jest.fn(), updateMessage: jest.fn(),
saveConvo: jest.fn() saveConvo: jest.fn(),
}; };
}; };
}); });
@ -52,7 +52,7 @@ describe('BaseClient', () => {
modelOptions: { modelOptions: {
model: 'gpt-3.5-turbo', model: 'gpt-3.5-turbo',
temperature: 0, temperature: 0,
} },
}; };
beforeEach(() => { beforeEach(() => {
@ -60,22 +60,14 @@ describe('BaseClient', () => {
}); });
test('returns the input messages without instructions when addInstructions() is called with empty instructions', () => { test('returns the input messages without instructions when addInstructions() is called with empty instructions', () => {
const messages = [ const messages = [{ content: 'Hello' }, { content: 'How are you?' }, { content: 'Goodbye' }];
{ content: 'Hello' },
{ content: 'How are you?' },
{ content: 'Goodbye' },
];
const instructions = ''; const instructions = '';
const result = TestClient.addInstructions(messages, instructions); const result = TestClient.addInstructions(messages, instructions);
expect(result).toEqual(messages); expect(result).toEqual(messages);
}); });
test('returns the input messages with instructions properly added when addInstructions() is called with non-empty instructions', () => { test('returns the input messages with instructions properly added when addInstructions() is called with non-empty instructions', () => {
const messages = [ const messages = [{ content: 'Hello' }, { content: 'How are you?' }, { content: 'Goodbye' }];
{ content: 'Hello' },
{ content: 'How are you?' },
{ content: 'Goodbye' },
];
const instructions = { content: 'Please respond to the question.' }; const instructions = { content: 'Please respond to the question.' };
const result = TestClient.addInstructions(messages, instructions); const result = TestClient.addInstructions(messages, instructions);
const expected = [ const expected = [
@ -94,20 +86,21 @@ describe('BaseClient', () => {
{ name: 'User', content: 'I have a question.' }, { name: 'User', content: 'I have a question.' },
]; ];
const result = TestClient.concatenateMessages(messages); const result = TestClient.concatenateMessages(messages);
const expected = `User:\nHello\n\nAssistant:\nHow can I help you?\n\nUser:\nI have a question.\n\n`; const expected =
'User:\nHello\n\nAssistant:\nHow can I help you?\n\nUser:\nI have a question.\n\n';
expect(result).toBe(expected); expect(result).toBe(expected);
}); });
test('refines messages correctly in refineMessages()', async () => { test('refines messages correctly in refineMessages()', async () => {
const messagesToRefine = [ const messagesToRefine = [
{ role: 'user', content: 'Hello', tokenCount: 10 }, { role: 'user', content: 'Hello', tokenCount: 10 },
{ role: 'assistant', content: 'How can I help you?', tokenCount: 20 } { role: 'assistant', content: 'How can I help you?', tokenCount: 20 },
]; ];
const remainingContextTokens = 100; const remainingContextTokens = 100;
const expectedRefinedMessage = { const expectedRefinedMessage = {
role: 'assistant', role: 'assistant',
content: 'Refined answer', content: 'Refined answer',
tokenCount: 14 // 'Refined answer'.length tokenCount: 14, // 'Refined answer'.length
}; };
const result = await TestClient.refineMessages(messagesToRefine, remainingContextTokens); const result = await TestClient.refineMessages(messagesToRefine, remainingContextTokens);
@ -120,7 +113,7 @@ describe('BaseClient', () => {
TestClient.refineMessages = jest.fn().mockResolvedValue({ TestClient.refineMessages = jest.fn().mockResolvedValue({
role: 'assistant', role: 'assistant',
content: 'Refined answer', content: 'Refined answer',
tokenCount: 30 tokenCount: 30,
}); });
const messages = [ const messages = [
@ -148,7 +141,7 @@ describe('BaseClient', () => {
TestClient.refineMessages = jest.fn().mockResolvedValue({ TestClient.refineMessages = jest.fn().mockResolvedValue({
role: 'assistant', role: 'assistant',
content: 'Refined answer', content: 'Refined answer',
tokenCount: 4 tokenCount: 4,
}); });
const messages = [ const messages = [
@ -176,28 +169,28 @@ describe('BaseClient', () => {
}); });
test('handles context strategy correctly in handleContextStrategy()', async () => { test('handles context strategy correctly in handleContextStrategy()', async () => {
TestClient.addInstructions = jest.fn().mockReturnValue([ TestClient.addInstructions = jest
{ content: 'Hello' }, .fn()
{ content: 'How can I help you?' }, .mockReturnValue([
{ content: 'Please provide more details.' }, { content: 'Hello' },
{ content: 'I can assist you with that.' } { content: 'How can I help you?' },
]); { content: 'Please provide more details.' },
{ content: 'I can assist you with that.' },
]);
TestClient.getMessagesWithinTokenLimit = jest.fn().mockReturnValue({ TestClient.getMessagesWithinTokenLimit = jest.fn().mockReturnValue({
context: [ context: [
{ content: 'How can I help you?' }, { content: 'How can I help you?' },
{ content: 'Please provide more details.' }, { content: 'Please provide more details.' },
{ content: 'I can assist you with that.' } { content: 'I can assist you with that.' },
], ],
remainingContextTokens: 80, remainingContextTokens: 80,
messagesToRefine: [ messagesToRefine: [{ content: 'Hello' }],
{ content: 'Hello' },
],
refineIndex: 3, refineIndex: 3,
}); });
TestClient.refineMessages = jest.fn().mockResolvedValue({ TestClient.refineMessages = jest.fn().mockResolvedValue({
role: 'assistant', role: 'assistant',
content: 'Refined answer', content: 'Refined answer',
tokenCount: 30 tokenCount: 30,
}); });
TestClient.getTokenCountForResponse = jest.fn().mockReturnValue(40); TestClient.getTokenCountForResponse = jest.fn().mockReturnValue(40);
@ -206,24 +199,24 @@ describe('BaseClient', () => {
{ content: 'Hello' }, { content: 'Hello' },
{ content: 'How can I help you?' }, { content: 'How can I help you?' },
{ content: 'Please provide more details.' }, { content: 'Please provide more details.' },
{ content: 'I can assist you with that.' } { content: 'I can assist you with that.' },
]; ];
const formattedMessages = [ const formattedMessages = [
{ content: 'Hello' }, { content: 'Hello' },
{ content: 'How can I help you?' }, { content: 'How can I help you?' },
{ content: 'Please provide more details.' }, { content: 'Please provide more details.' },
{ content: 'I can assist you with that.' } { content: 'I can assist you with that.' },
]; ];
const expectedResult = { const expectedResult = {
payload: [ payload: [
{ {
content: 'Refined answer', content: 'Refined answer',
role: 'assistant', role: 'assistant',
tokenCount: 30 tokenCount: 30,
}, },
{ content: 'How can I help you?' }, { content: 'How can I help you?' },
{ content: 'Please provide more details.' }, { content: 'Please provide more details.' },
{ content: 'I can assist you with that.' } { content: 'I can assist you with that.' },
], ],
promptTokens: expect.any(Number), promptTokens: expect.any(Number),
tokenCountMap: {}, tokenCountMap: {},
@ -246,7 +239,7 @@ describe('BaseClient', () => {
isCreatedByUser: false, isCreatedByUser: false,
messageId: expect.any(String), messageId: expect.any(String),
parentMessageId: expect.any(String), parentMessageId: expect.any(String),
conversationId: expect.any(String) conversationId: expect.any(String),
}); });
const response = await TestClient.sendMessage(userMessage); const response = await TestClient.sendMessage(userMessage);
@ -261,7 +254,7 @@ describe('BaseClient', () => {
conversationId, conversationId,
parentMessageId, parentMessageId,
getIds: jest.fn(), getIds: jest.fn(),
onStart: jest.fn() onStart: jest.fn(),
}; };
const expectedResult = expect.objectContaining({ const expectedResult = expect.objectContaining({
@ -270,7 +263,7 @@ describe('BaseClient', () => {
isCreatedByUser: false, isCreatedByUser: false,
messageId: expect.any(String), messageId: expect.any(String),
parentMessageId: expect.any(String), parentMessageId: expect.any(String),
conversationId: opts.conversationId conversationId: opts.conversationId,
}); });
const response = await TestClient.sendMessage(userMessage, opts); const response = await TestClient.sendMessage(userMessage, opts);
@ -300,7 +293,10 @@ describe('BaseClient', () => {
test('loadHistory is called with the correct arguments', async () => { test('loadHistory is called with the correct arguments', async () => {
const opts = { conversationId: '123', parentMessageId: '456' }; const opts = { conversationId: '123', parentMessageId: '456' };
await TestClient.sendMessage('Hello, world!', opts); await TestClient.sendMessage('Hello, world!', opts);
expect(TestClient.loadHistory).toHaveBeenCalledWith(opts.conversationId, opts.parentMessageId); expect(TestClient.loadHistory).toHaveBeenCalledWith(
opts.conversationId,
opts.parentMessageId,
);
}); });
test('getIds is called with the correct arguments', async () => { test('getIds is called with the correct arguments', async () => {
@ -310,7 +306,7 @@ describe('BaseClient', () => {
expect(getIds).toHaveBeenCalledWith({ expect(getIds).toHaveBeenCalledWith({
userMessage: expect.objectContaining({ text: 'Hello, world!' }), userMessage: expect.objectContaining({ text: 'Hello, world!' }),
conversationId: response.conversationId, conversationId: response.conversationId,
responseMessageId: response.messageId responseMessageId: response.messageId,
}); });
}); });
@ -333,10 +329,10 @@ describe('BaseClient', () => {
isCreatedByUser: expect.any(Boolean), isCreatedByUser: expect.any(Boolean),
messageId: expect.any(String), messageId: expect.any(String),
parentMessageId: expect.any(String), parentMessageId: expect.any(String),
conversationId: expect.any(String) conversationId: expect.any(String),
}), }),
saveOptions, saveOptions,
user user,
); );
}); });
@ -358,14 +354,16 @@ describe('BaseClient', () => {
test('returns an object with the correct shape', async () => { test('returns an object with the correct shape', async () => {
const response = await TestClient.sendMessage('Hello, world!', {}); const response = await TestClient.sendMessage('Hello, world!', {});
expect(response).toEqual(expect.objectContaining({ expect(response).toEqual(
sender: expect.any(String), expect.objectContaining({
text: expect.any(String), sender: expect.any(String),
isCreatedByUser: expect.any(Boolean), text: expect.any(String),
messageId: expect.any(String), isCreatedByUser: expect.any(Boolean),
parentMessageId: expect.any(String), messageId: expect.any(String),
conversationId: expect.any(String) parentMessageId: expect.any(String),
})); conversationId: expect.any(String),
}),
);
}); });
}); });
}); });

View file

@ -32,9 +32,11 @@ class FakeClient extends BaseClient {
this.modelOptions = { this.modelOptions = {
...modelOptions, ...modelOptions,
model: modelOptions.model || 'gpt-3.5-turbo', model: modelOptions.model || 'gpt-3.5-turbo',
temperature: typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature, temperature:
typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature,
top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p, top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p,
presence_penalty: typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty, presence_penalty:
typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty,
stop: modelOptions.stop, stop: modelOptions.stop,
}; };
} }
@ -66,7 +68,7 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
const orderedMessages = TestClient.constructor.getMessagesForConversation( const orderedMessages = TestClient.constructor.getMessagesForConversation(
fakeMessages, fakeMessages,
parentMessageId parentMessageId,
); );
TestClient.currentMessages = orderedMessages; TestClient.currentMessages = orderedMessages;
@ -98,7 +100,7 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
this.pastMessages = await TestClient.loadHistory( this.pastMessages = await TestClient.loadHistory(
conversationId, conversationId,
TestClient.options?.parentMessageId TestClient.options?.parentMessageId,
); );
const userMessage = { const userMessage = {
@ -107,7 +109,7 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
isCreatedByUser: true, isCreatedByUser: true,
messageId: userMessageId, messageId: userMessageId,
parentMessageId, parentMessageId,
conversationId conversationId,
}; };
const response = { const response = {
@ -116,7 +118,7 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
isCreatedByUser: false, isCreatedByUser: false,
messageId: crypto.randomUUID(), messageId: crypto.randomUUID(),
parentMessageId: userMessage.messageId, parentMessageId: userMessage.messageId,
conversationId conversationId,
}; };
fakeMessages.push(userMessage); fakeMessages.push(userMessage);
@ -126,7 +128,7 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
opts.getIds({ opts.getIds({
userMessage, userMessage,
conversationId, conversationId,
responseMessageId: response.messageId responseMessageId: response.messageId,
}); });
} }
@ -146,7 +148,10 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
// userMessage is always the last one in the payload // userMessage is always the last one in the payload
if (i === payload.length - 1) { if (i === payload.length - 1) {
userMessage.tokenCount = message.tokenCount; userMessage.tokenCount = message.tokenCount;
console.debug(`Token count for user message: ${tokenCount}`, `Instruction Tokens: ${tokenCountMap.instructions || 'N/A'}`); console.debug(
`Token count for user message: ${tokenCount}`,
`Instruction Tokens: ${tokenCountMap.instructions || 'N/A'}`,
);
} }
return messageWithoutTokenCount; return messageWithoutTokenCount;
}); });
@ -163,7 +168,10 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
}); });
TestClient.buildMessages = jest.fn(async (messages, parentMessageId) => { TestClient.buildMessages = jest.fn(async (messages, parentMessageId) => {
const orderedMessages = TestClient.constructor.getMessagesForConversation(messages, parentMessageId); const orderedMessages = TestClient.constructor.getMessagesForConversation(
messages,
parentMessageId,
);
const formattedMessages = orderedMessages.map((message) => { const formattedMessages = orderedMessages.map((message) => {
let { role: _role, sender, text } = message; let { role: _role, sender, text } = message;
const role = _role ?? sender; const role = _role ?? sender;
@ -180,6 +188,6 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
}); });
return TestClient; return TestClient;
} };
module.exports = { FakeClient, initializeFakeClient }; module.exports = { FakeClient, initializeFakeClient };

View file

@ -5,7 +5,7 @@ describe('OpenAIClient', () => {
const model = 'gpt-4'; const model = 'gpt-4';
const parentMessageId = '1'; const parentMessageId = '1';
const messages = [ const messages = [
{ role: 'user', sender: 'User', text: 'Hello', messageId: parentMessageId}, { role: 'user', sender: 'User', text: 'Hello', messageId: parentMessageId },
{ role: 'assistant', sender: 'Assistant', text: 'Hi', messageId: '2' }, { role: 'assistant', sender: 'Assistant', text: 'Hi', messageId: '2' },
]; ];
@ -22,7 +22,7 @@ describe('OpenAIClient', () => {
client.refineMessages = jest.fn().mockResolvedValue({ client.refineMessages = jest.fn().mockResolvedValue({
role: 'assistant', role: 'assistant',
content: 'Refined answer', content: 'Refined answer',
tokenCount: 30 tokenCount: 30,
}); });
}); });
@ -100,60 +100,83 @@ describe('OpenAIClient', () => {
describe('buildMessages', () => { describe('buildMessages', () => {
it('should build messages correctly for chat completion', async () => { it('should build messages correctly for chat completion', async () => {
const result = await client.buildMessages(messages, parentMessageId, { isChatCompletion: true }); const result = await client.buildMessages(messages, parentMessageId, {
isChatCompletion: true,
});
expect(result).toHaveProperty('prompt'); expect(result).toHaveProperty('prompt');
}); });
it('should build messages correctly for non-chat completion', async () => { it('should build messages correctly for non-chat completion', async () => {
const result = await client.buildMessages(messages, parentMessageId, { isChatCompletion: false }); const result = await client.buildMessages(messages, parentMessageId, {
isChatCompletion: false,
});
expect(result).toHaveProperty('prompt'); expect(result).toHaveProperty('prompt');
}); });
it('should build messages correctly with a promptPrefix', async () => { it('should build messages correctly with a promptPrefix', async () => {
const result = await client.buildMessages(messages, parentMessageId, { isChatCompletion: true, promptPrefix: 'Test Prefix' }); const result = await client.buildMessages(messages, parentMessageId, {
isChatCompletion: true,
promptPrefix: 'Test Prefix',
});
expect(result).toHaveProperty('prompt'); expect(result).toHaveProperty('prompt');
const instructions = result.prompt.find(item => item.name === 'instructions'); const instructions = result.prompt.find((item) => item.name === 'instructions');
expect(instructions).toBeDefined(); expect(instructions).toBeDefined();
expect(instructions.content).toContain('Test Prefix'); expect(instructions.content).toContain('Test Prefix');
}); });
it('should handle context strategy correctly', async () => { it('should handle context strategy correctly', async () => {
client.contextStrategy = 'refine'; client.contextStrategy = 'refine';
const result = await client.buildMessages(messages, parentMessageId, { isChatCompletion: true }); const result = await client.buildMessages(messages, parentMessageId, {
isChatCompletion: true,
});
expect(result).toHaveProperty('prompt'); expect(result).toHaveProperty('prompt');
expect(result).toHaveProperty('tokenCountMap'); expect(result).toHaveProperty('tokenCountMap');
}); });
it('should assign name property for user messages when options.name is set', async () => { it('should assign name property for user messages when options.name is set', async () => {
client.options.name = 'Test User'; client.options.name = 'Test User';
const result = await client.buildMessages(messages, parentMessageId, { isChatCompletion: true }); const result = await client.buildMessages(messages, parentMessageId, {
const hasUserWithName = result.prompt.some(item => item.role === 'user' && item.name === 'Test User'); isChatCompletion: true,
});
const hasUserWithName = result.prompt.some(
(item) => item.role === 'user' && item.name === 'Test User',
);
expect(hasUserWithName).toBe(true); expect(hasUserWithName).toBe(true);
}); });
it('should calculate tokenCount for each message when contextStrategy is set', async () => { it('should calculate tokenCount for each message when contextStrategy is set', async () => {
client.contextStrategy = 'refine'; client.contextStrategy = 'refine';
const result = await client.buildMessages(messages, parentMessageId, { isChatCompletion: true }); const result = await client.buildMessages(messages, parentMessageId, {
const hasUserWithTokenCount = result.prompt.some(item => item.role === 'user' && item.tokenCount > 0); isChatCompletion: true,
});
const hasUserWithTokenCount = result.prompt.some(
(item) => item.role === 'user' && item.tokenCount > 0,
);
expect(hasUserWithTokenCount).toBe(true); expect(hasUserWithTokenCount).toBe(true);
}); });
it('should handle promptPrefix from options when promptPrefix argument is not provided', async () => { it('should handle promptPrefix from options when promptPrefix argument is not provided', async () => {
client.options.promptPrefix = 'Test Prefix from options'; client.options.promptPrefix = 'Test Prefix from options';
const result = await client.buildMessages(messages, parentMessageId, { isChatCompletion: true }); const result = await client.buildMessages(messages, parentMessageId, {
const instructions = result.prompt.find(item => item.name === 'instructions'); isChatCompletion: true,
});
const instructions = result.prompt.find((item) => item.name === 'instructions');
expect(instructions.content).toContain('Test Prefix from options'); expect(instructions.content).toContain('Test Prefix from options');
}); });
it('should handle case when neither promptPrefix argument nor options.promptPrefix is set', async () => { it('should handle case when neither promptPrefix argument nor options.promptPrefix is set', async () => {
const result = await client.buildMessages(messages, parentMessageId, { isChatCompletion: true }); const result = await client.buildMessages(messages, parentMessageId, {
const instructions = result.prompt.find(item => item.name === 'instructions'); isChatCompletion: true,
});
const instructions = result.prompt.find((item) => item.name === 'instructions');
expect(instructions).toBeUndefined(); expect(instructions).toBeUndefined();
}); });
it('should handle case when getMessagesForConversation returns null or an empty array', async () => { it('should handle case when getMessagesForConversation returns null or an empty array', async () => {
const messages = []; const messages = [];
const result = await client.buildMessages(messages, parentMessageId, { isChatCompletion: true }); const result = await client.buildMessages(messages, parentMessageId, {
isChatCompletion: true,
});
expect(result.prompt).toEqual([]); expect(result.prompt).toEqual([]);
}); });
}); });

View file

@ -16,7 +16,7 @@ require('dotenv').config();
const { OpenAIClient } = require('../'); const { OpenAIClient } = require('../');
function timeout(ms) { function timeout(ms) {
return new Promise(resolve => setTimeout(resolve, ms)); return new Promise((resolve) => setTimeout(resolve, ms));
} }
const run = async () => { const run = async () => {
@ -46,7 +46,7 @@ const run = async () => {
model, model,
}, },
proxy: process.env.PROXY || null, proxy: process.env.PROXY || null,
debug: true debug: true,
}; };
let apiKey = process.env.OPENAI_API_KEY; let apiKey = process.env.OPENAI_API_KEY;
@ -59,7 +59,13 @@ const run = async () => {
function printProgressBar(percentageUsed) { function printProgressBar(percentageUsed) {
const filledBlocks = Math.round(percentageUsed / 2); // Each block represents 2% const filledBlocks = Math.round(percentageUsed / 2); // Each block represents 2%
const emptyBlocks = 50 - filledBlocks; // Total blocks is 50 (each represents 2%), so the rest are empty const emptyBlocks = 50 - filledBlocks; // Total blocks is 50 (each represents 2%), so the rest are empty
const progressBar = '[' + '█'.repeat(filledBlocks) + ' '.repeat(emptyBlocks) + '] ' + percentageUsed.toFixed(2) + '%'; const progressBar =
'[' +
'█'.repeat(filledBlocks) +
' '.repeat(emptyBlocks) +
'] ' +
percentageUsed.toFixed(2) +
'%';
console.log(progressBar); console.log(progressBar);
} }
@ -78,10 +84,10 @@ const run = async () => {
// encoder.free(); // encoder.free();
const memoryUsageDuringLoop = process.memoryUsage().heapUsed; const memoryUsageDuringLoop = process.memoryUsage().heapUsed;
const percentageUsed = memoryUsageDuringLoop / maxMemory * 100; const percentageUsed = (memoryUsageDuringLoop / maxMemory) * 100;
printProgressBar(percentageUsed); printProgressBar(percentageUsed);
if (i === (iterations - 1)) { if (i === iterations - 1) {
console.log(' done'); console.log(' done');
// encoder.free(); // encoder.free();
} }
@ -100,7 +106,7 @@ const run = async () => {
await timeout(15000); await timeout(15000);
const memoryUsageAfterTimeout = process.memoryUsage().heapUsed; const memoryUsageAfterTimeout = process.memoryUsage().heapUsed;
console.log(`Post timeout: ${memoryUsageAfterTimeout / 1024 / 1024} megabytes`); console.log(`Post timeout: ${memoryUsageAfterTimeout / 1024 / 1024} megabytes`);
} };
run(); run();

View file

@ -7,7 +7,7 @@ jest.mock('../../../models/Conversation', () => {
return function () { return function () {
return { return {
save: jest.fn(), save: jest.fn(),
deleteConvos: jest.fn() deleteConvos: jest.fn(),
}; };
}; };
}); });
@ -19,11 +19,11 @@ describe('PluginsClient', () => {
modelOptions: { modelOptions: {
model: 'gpt-3.5-turbo', model: 'gpt-3.5-turbo',
temperature: 0, temperature: 0,
max_tokens: 2 max_tokens: 2,
}, },
agentOptions: { agentOptions: {
model: 'gpt-3.5-turbo' model: 'gpt-3.5-turbo',
} },
}; };
let parentMessageId; let parentMessageId;
let conversationId; let conversationId;
@ -43,13 +43,13 @@ describe('PluginsClient', () => {
const orderedMessages = TestAgent.constructor.getMessagesForConversation( const orderedMessages = TestAgent.constructor.getMessagesForConversation(
fakeMessages, fakeMessages,
parentMessageId parentMessageId,
); );
const chatMessages = orderedMessages.map((msg) => const chatMessages = orderedMessages.map((msg) =>
msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user' msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user'
? new HumanChatMessage(msg.text) ? new HumanChatMessage(msg.text)
: new AIChatMessage(msg.text) : new AIChatMessage(msg.text),
); );
TestAgent.currentMessages = orderedMessages; TestAgent.currentMessages = orderedMessages;
@ -64,7 +64,7 @@ describe('PluginsClient', () => {
const userMessageId = opts.overrideParentMessageId || crypto.randomUUID(); const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
this.pastMessages = await TestAgent.loadHistory( this.pastMessages = await TestAgent.loadHistory(
conversationId, conversationId,
TestAgent.options?.parentMessageId TestAgent.options?.parentMessageId,
); );
const userMessage = { const userMessage = {
@ -73,7 +73,7 @@ describe('PluginsClient', () => {
isCreatedByUser: true, isCreatedByUser: true,
messageId: userMessageId, messageId: userMessageId,
parentMessageId, parentMessageId,
conversationId conversationId,
}; };
const response = { const response = {
@ -82,7 +82,7 @@ describe('PluginsClient', () => {
isCreatedByUser: false, isCreatedByUser: false,
messageId: crypto.randomUUID(), messageId: crypto.randomUUID(),
parentMessageId: userMessage.messageId, parentMessageId: userMessage.messageId,
conversationId conversationId,
}; };
fakeMessages.push(userMessage); fakeMessages.push(userMessage);
@ -107,7 +107,7 @@ describe('PluginsClient', () => {
isCreatedByUser: false, isCreatedByUser: false,
messageId: expect.any(String), messageId: expect.any(String),
parentMessageId: expect.any(String), parentMessageId: expect.any(String),
conversationId: expect.any(String) conversationId: expect.any(String),
}); });
const response = await TestAgent.sendMessage(userMessage); const response = await TestAgent.sendMessage(userMessage);
@ -121,7 +121,7 @@ describe('PluginsClient', () => {
const userMessage = 'Second message in the conversation'; const userMessage = 'Second message in the conversation';
const opts = { const opts = {
conversationId, conversationId,
parentMessageId parentMessageId,
}; };
const expectedResult = expect.objectContaining({ const expectedResult = expect.objectContaining({
@ -130,7 +130,7 @@ describe('PluginsClient', () => {
isCreatedByUser: false, isCreatedByUser: false,
messageId: expect.any(String), messageId: expect.any(String),
parentMessageId: expect.any(String), parentMessageId: expect.any(String),
conversationId: opts.conversationId conversationId: opts.conversationId,
}); });
const response = await TestAgent.sendMessage(userMessage, opts); const response = await TestAgent.sendMessage(userMessage, opts);

View file

@ -57,7 +57,7 @@ function extractShortVersion(openapiSpec) {
const shortApiSpec = { const shortApiSpec = {
openapi: fullApiSpec.openapi, openapi: fullApiSpec.openapi,
info: fullApiSpec.info, info: fullApiSpec.info,
paths: {} paths: {},
}; };
for (let path in fullApiSpec.paths) { for (let path in fullApiSpec.paths) {
@ -68,8 +68,8 @@ function extractShortVersion(openapiSpec) {
operationId: fullApiSpec.paths[path][method].operationId, operationId: fullApiSpec.paths[path][method].operationId,
parameters: fullApiSpec.paths[path][method].parameters?.map((parameter) => ({ parameters: fullApiSpec.paths[path][method].parameters?.map((parameter) => ({
name: parameter.name, name: parameter.name,
description: parameter.description description: parameter.description,
})) })),
}; };
} }
} }
@ -199,14 +199,16 @@ class AIPluginTool extends Tool {
const apiUrlRes = await fetch(aiPluginJson.api.url, {}); const apiUrlRes = await fetch(aiPluginJson.api.url, {});
if (!apiUrlRes.ok) { if (!apiUrlRes.ok) {
throw new Error( throw new Error(
`Failed to fetch API spec from ${aiPluginJson.api.url} with status ${apiUrlRes.status}` `Failed to fetch API spec from ${aiPluginJson.api.url} with status ${apiUrlRes.status}`,
); );
} }
const apiUrlJson = await apiUrlRes.text(); const apiUrlJson = await apiUrlRes.text();
const shortApiSpec = extractShortVersion(apiUrlJson); const shortApiSpec = extractShortVersion(apiUrlJson);
return new AIPluginTool({ return new AIPluginTool({
name: aiPluginJson.name_for_model.toLowerCase(), name: aiPluginJson.name_for_model.toLowerCase(),
description: `A \`tool\` to learn the API documentation for ${aiPluginJson.name_for_model.toLowerCase()}, after which you can use 'http_request' to make the actual API call. Short description of how to use the API's results: ${aiPluginJson.description_for_model})`, description: `A \`tool\` to learn the API documentation for ${aiPluginJson.name_for_model.toLowerCase()}, after which you can use 'http_request' to make the actual API call. Short description of how to use the API's results: ${
aiPluginJson.description_for_model
})`,
apiSpec: ` apiSpec: `
As an AI, your task is to identify the operationId of the relevant API path based on the condensed OpenAPI specifications provided. As an AI, your task is to identify the operationId of the relevant API path based on the condensed OpenAPI specifications provided.
@ -228,7 +230,7 @@ ${shortApiSpec}
\`\`\` \`\`\`
`, `,
openaiSpec: apiUrlJson, openaiSpec: apiUrlJson,
model: model model: model,
}); });
} }
} }

View file

@ -56,11 +56,17 @@ Guidelines:
} }
replaceUnwantedChars(inputString) { replaceUnwantedChars(inputString) {
return inputString.replace(/\r\n|\r|\n/g, ' ').replace('"', '').trim(); return inputString
.replace(/\r\n|\r|\n/g, ' ')
.replace('"', '')
.trim();
} }
getMarkdownImageUrl(imageName) { getMarkdownImageUrl(imageName) {
const imageUrl = path.join(this.relativeImageUrl, imageName).replace(/\\/g, '/').replace('public/', ''); const imageUrl = path
.join(this.relativeImageUrl, imageName)
.replace(/\\/g, '/')
.replace('public/', '');
return `![generated image](/${imageUrl})`; return `![generated image](/${imageUrl})`;
} }
@ -70,13 +76,13 @@ Guidelines:
// TODO: Future idea -- could we ask an LLM to extract these arguments from an input that might contain them? // TODO: Future idea -- could we ask an LLM to extract these arguments from an input that might contain them?
n: 1, n: 1,
// size: '1024x1024' // size: '1024x1024'
size: '512x512' size: '512x512',
}); });
const theImageUrl = resp.data.data[0].url; const theImageUrl = resp.data.data[0].url;
if (!theImageUrl) { if (!theImageUrl) {
throw new Error(`No image URL returned from OpenAI API.`); throw new Error('No image URL returned from OpenAI API.');
} }
const regex = /img-[\w\d]+.png/; const regex = /img-[\w\d]+.png/;

View file

@ -23,7 +23,8 @@ class GoogleSearchAPI extends Tool {
* A description for the agent to use * A description for the agent to use
* @type {string} * @type {string}
*/ */
description = `Use the 'google' tool to retrieve internet search results relevant to your input. The results will return links and snippets of text from the webpages`; description =
'Use the \'google\' tool to retrieve internet search results relevant to your input. The results will return links and snippets of text from the webpages';
getCx() { getCx() {
const cx = process.env.GOOGLE_CSE_ID || ''; const cx = process.env.GOOGLE_CSE_ID || '';
@ -79,7 +80,7 @@ class GoogleSearchAPI extends Tool {
q: input, q: input,
cx: this.cx, cx: this.cx,
auth: this.apiKey, auth: this.apiKey,
num: 5 // Limit the number of results to 5 num: 5, // Limit the number of results to 5
}); });
// return response.data; // return response.data;
@ -87,7 +88,7 @@ class GoogleSearchAPI extends Tool {
if (!response.data.items || response.data.items.length === 0) { if (!response.data.items || response.data.items.length === 0) {
return this.resultsToReadableFormat([ return this.resultsToReadableFormat([
{ title: 'No good Google Search Result was found', link: '' } { title: 'No good Google Search Result was found', link: '' },
]); ]);
} }
@ -97,7 +98,7 @@ class GoogleSearchAPI extends Tool {
for (const result of results) { for (const result of results) {
const metadataResult = { const metadataResult = {
title: result.title || '', title: result.title || '',
link: result.link || '' link: result.link || '',
}; };
if (result.snippet) { if (result.snippet) {
metadataResult.snippet = result.snippet; metadataResult.snippet = result.snippet;

View file

@ -55,7 +55,8 @@ class HttpRequestTool extends Tool {
this.headers = headers; this.headers = headers;
this.name = 'http_request'; this.name = 'http_request';
this.maxOutputLength = maxOutputLength; this.maxOutputLength = maxOutputLength;
this.description = `Executes HTTP methods (GET, POST, PUT, DELETE, etc.). The input is an object with three keys: "url", "method", and "data". Even for GET or DELETE, include "data" key as an empty string. "method" is the HTTP method, and "url" is the desired endpoint. If POST or PUT, "data" should contain a stringified JSON representing the body to send. Only one url per use.`; this.description =
'Executes HTTP methods (GET, POST, PUT, DELETE, etc.). The input is an object with three keys: "url", "method", and "data". Even for GET or DELETE, include "data" key as an empty string. "method" is the HTTP method, and "url" is the desired endpoint. If POST or PUT, "data" should contain a stringified JSON representing the body to send. Only one url per use.';
} }
async _call(input) { async _call(input) {
@ -77,7 +78,7 @@ class HttpRequestTool extends Tool {
let options = { let options = {
method: method, method: method,
headers: this.headers headers: this.headers,
}; };
if (['POST', 'PUT', 'PATCH'].includes(method.toUpperCase()) && data) { if (['POST', 'PUT', 'PATCH'].includes(method.toUpperCase()) && data) {

View file

@ -5,7 +5,8 @@ class SelfReflectionTool extends Tool {
super(); super();
this.reminders = 0; this.reminders = 0;
this.name = 'self-reflection'; this.name = 'self-reflection';
this.description = `Take this action to reflect on your thoughts & actions. For your input, provide answers for self-evaluation as part of one input, using this space as a canvas to explore and organize your ideas in response to the user's message. You can use multiple lines for your input. Perform this action sparingly and only when you are stuck.`; this.description =
'Take this action to reflect on your thoughts & actions. For your input, provide answers for self-evaluation as part of one input, using this space as a canvas to explore and organize your ideas in response to the user\'s message. You can use multiple lines for your input. Perform this action sparingly and only when you are stuck.';
this.message = message; this.message = message;
this.isGpt3 = isGpt3; this.isGpt3 = isGpt3;
// this.returnDirect = true; // this.returnDirect = true;
@ -17,9 +18,9 @@ class SelfReflectionTool extends Tool {
async selfReflect() { async selfReflect() {
if (this.isGpt3) { if (this.isGpt3) {
return `I should finalize my reply as soon as I have satisfied the user's query.`; return 'I should finalize my reply as soon as I have satisfied the user\'s query.';
} else { } else {
return ``; return '';
} }
} }
} }

View file

@ -26,7 +26,10 @@ Guidelines:
} }
getMarkdownImageUrl(imageName) { getMarkdownImageUrl(imageName) {
const imageUrl = path.join(this.relativeImageUrl, imageName).replace(/\\/g, '/').replace('public/', ''); const imageUrl = path
.join(this.relativeImageUrl, imageName)
.replace(/\\/g, '/')
.replace('public/', '');
return `![generated image](/${imageUrl})`; return `![generated image](/${imageUrl})`;
} }
@ -43,7 +46,7 @@ Guidelines:
const payload = { const payload = {
prompt: input.split('|')[0], prompt: input.split('|')[0],
negative_prompt: input.split('|')[1], negative_prompt: input.split('|')[1],
steps: 20 steps: 20,
}; };
const response = await axios.post(`${url}/sdapi/v1/txt2img`, payload); const response = await axios.post(`${url}/sdapi/v1/txt2img`, payload);
const image = response.data.images[0]; const image = response.data.images[0];
@ -68,8 +71,8 @@ Guidelines:
await sharp(buffer) await sharp(buffer)
.withMetadata({ .withMetadata({
iptcpng: { iptcpng: {
parameters: info parameters: info,
} },
}) })
.toFile(this.outputPath + '/' + imageName); .toFile(this.outputPath + '/' + imageName);
this.result = this.getMarkdownImageUrl(imageName); this.result = this.getMarkdownImageUrl(imageName);

View file

@ -71,7 +71,7 @@ General guidelines:
console.log('Error data:', error.response.data); console.log('Error data:', error.response.data);
return error.response.data; return error.response.data;
} else { } else {
console.log(`Error querying Wolfram Alpha`, error.message); console.log('Error querying Wolfram Alpha', error.message);
// throw error; // throw error;
return 'There was an error querying Wolfram Alpha.'; return 'There was an error querying Wolfram Alpha.';
} }

View file

@ -19,5 +19,5 @@ module.exports = {
StructuredSD, StructuredSD,
WolframAlphaAPI, WolframAlphaAPI,
StructuredWolfram, StructuredWolfram,
SelfReflectionTool SelfReflectionTool,
} };

View file

@ -7,7 +7,7 @@ async function saveImageFromUrl(url, outputPath, outputFilename) {
// Fetch the image from the URL // Fetch the image from the URL
const response = await axios({ const response = await axios({
url, url,
responseType: 'stream' responseType: 'stream',
}); });
// Check if the output directory exists, if not, create it // Check if the output directory exists, if not, create it

View file

@ -20,8 +20,16 @@ Guidelines:
"negative_prompt":"semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, out of frame, low quality, ugly, mutation, deformed" "negative_prompt":"semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, out of frame, low quality, ugly, mutation, deformed"
- Generate images only once per human query unless explicitly requested by the user`; - Generate images only once per human query unless explicitly requested by the user`;
this.schema = z.object({ this.schema = z.object({
prompt: z.string().describe("Detailed keywords to describe the subject, using at least 7 keywords to accurately describe the image, separated by comma"), prompt: z
negative_prompt: z.string().describe("Keywords we want to exclude from the final image, using at least 7 keywords to accurately describe the image, separated by comma") .string()
.describe(
'Detailed keywords to describe the subject, using at least 7 keywords to accurately describe the image, separated by comma',
),
negative_prompt: z
.string()
.describe(
'Keywords we want to exclude from the final image, using at least 7 keywords to accurately describe the image, separated by comma',
),
}); });
} }
@ -30,7 +38,10 @@ Guidelines:
} }
getMarkdownImageUrl(imageName) { getMarkdownImageUrl(imageName) {
const imageUrl = path.join(this.relativeImageUrl, imageName).replace(/\\/g, '/').replace('public/', ''); const imageUrl = path
.join(this.relativeImageUrl, imageName)
.replace(/\\/g, '/')
.replace('public/', '');
return `![generated image](/${imageUrl})`; return `![generated image](/${imageUrl})`;
} }
@ -48,7 +59,7 @@ Guidelines:
const payload = { const payload = {
prompt, prompt,
negative_prompt, negative_prompt,
steps: 20 steps: 20,
}; };
const response = await axios.post(`${url}/sdapi/v1/txt2img`, payload); const response = await axios.post(`${url}/sdapi/v1/txt2img`, payload);
const image = response.data.images[0]; const image = response.data.images[0];
@ -58,7 +69,17 @@ Guidelines:
// Generate unique name // Generate unique name
const imageName = `${Date.now()}.png`; const imageName = `${Date.now()}.png`;
this.outputPath = path.resolve(__dirname, '..', '..', '..', '..', '..', 'client', 'public', 'images'); this.outputPath = path.resolve(
__dirname,
'..',
'..',
'..',
'..',
'..',
'client',
'public',
'images',
);
const appRoot = path.resolve(__dirname, '..', '..', '..', '..', '..', 'client'); const appRoot = path.resolve(__dirname, '..', '..', '..', '..', '..', 'client');
this.relativeImageUrl = path.relative(appRoot, this.outputPath); this.relativeImageUrl = path.relative(appRoot, this.outputPath);
@ -72,8 +93,8 @@ Guidelines:
await sharp(buffer) await sharp(buffer)
.withMetadata({ .withMetadata({
iptcpng: { iptcpng: {
parameters: info parameters: info,
} },
}) })
.toFile(this.outputPath + '/' + imageName); .toFile(this.outputPath + '/' + imageName);
this.result = this.getMarkdownImageUrl(imageName); this.result = this.getMarkdownImageUrl(imageName);

View file

@ -18,7 +18,9 @@ Guidelines include:
- Make separate calls for each property and choose relevant 'Assumptions' if results aren't relevant. - Make separate calls for each property and choose relevant 'Assumptions' if results aren't relevant.
- The tool also performs data analysis, plotting, and information retrieval.`; - The tool also performs data analysis, plotting, and information retrieval.`;
this.schema = z.object({ this.schema = z.object({
nl_query: z.string().describe("Natural language query to WolframAlpha following the guidelines"), nl_query: z
.string()
.describe('Natural language query to WolframAlpha following the guidelines'),
}); });
} }
@ -61,7 +63,7 @@ Guidelines include:
console.log('Error data:', error.response.data); console.log('Error data:', error.response.data);
return error.response.data; return error.response.data;
} else { } else {
console.log(`Error querying Wolfram Alpha`, error.message); console.log('Error querying Wolfram Alpha', error.message);
// throw error; // throw error;
return 'There was an error querying Wolfram Alpha.'; return 'There was an error querying Wolfram Alpha.';
} }

View file

@ -1,10 +1,7 @@
const { getUserPluginAuthValue } = require('../../../../server/services/PluginService'); const { getUserPluginAuthValue } = require('../../../../server/services/PluginService');
const { OpenAIEmbeddings } = require('langchain/embeddings/openai'); const { OpenAIEmbeddings } = require('langchain/embeddings/openai');
const { ZapierToolKit } = require('langchain/agents'); const { ZapierToolKit } = require('langchain/agents');
const { const { SerpAPI, ZapierNLAWrapper } = require('langchain/tools');
SerpAPI,
ZapierNLAWrapper
} = require('langchain/tools');
const { ChatOpenAI } = require('langchain/chat_models/openai'); const { ChatOpenAI } = require('langchain/chat_models/openai');
const { Calculator } = require('langchain/tools/calculator'); const { Calculator } = require('langchain/tools/calculator');
const { WebBrowser } = require('langchain/tools/webbrowser'); const { WebBrowser } = require('langchain/tools/webbrowser');
@ -24,7 +21,7 @@ const validateTools = async (user, tools = []) => {
try { try {
const validToolsSet = new Set(tools); const validToolsSet = new Set(tools);
const availableToolsToValidate = availableTools.filter((tool) => const availableToolsToValidate = availableTools.filter((tool) =>
validToolsSet.has(tool.pluginKey) validToolsSet.has(tool.pluginKey),
); );
const validateCredentials = async (authField, toolName) => { const validateCredentials = async (authField, toolName) => {
@ -79,14 +76,14 @@ const loadTools = async ({ user, model, functions = null, tools = [], options =
google: GoogleSearchAPI, google: GoogleSearchAPI,
wolfram: functions ? StructuredWolfram : WolframAlphaAPI, wolfram: functions ? StructuredWolfram : WolframAlphaAPI,
'dall-e': OpenAICreateImage, 'dall-e': OpenAICreateImage,
'stable-diffusion': functions ? StructuredSD : StableDiffusionAPI 'stable-diffusion': functions ? StructuredSD : StableDiffusionAPI,
}; };
const customConstructors = { const customConstructors = {
browser: async () => { browser: async () => {
let openAIApiKey = options.openAIApiKey ?? process.env.OPENAI_API_KEY; let openAIApiKey = options.openAIApiKey ?? process.env.OPENAI_API_KEY;
openAIApiKey = openAIApiKey === 'user_provided' ? null : openAIApiKey; openAIApiKey = openAIApiKey === 'user_provided' ? null : openAIApiKey;
openAIApiKey = openAIApiKey || await getUserPluginAuthValue(user, 'OPENAI_API_KEY'); openAIApiKey = openAIApiKey || (await getUserPluginAuthValue(user, 'OPENAI_API_KEY'));
return new WebBrowser({ model, embeddings: new OpenAIEmbeddings({ openAIApiKey }) }); return new WebBrowser({ model, embeddings: new OpenAIEmbeddings({ openAIApiKey }) });
}, },
serpapi: async () => { serpapi: async () => {
@ -97,7 +94,7 @@ const loadTools = async ({ user, model, functions = null, tools = [], options =
return new SerpAPI(apiKey, { return new SerpAPI(apiKey, {
location: 'Austin,Texas,United States', location: 'Austin,Texas,United States',
hl: 'en', hl: 'en',
gl: 'us' gl: 'us',
}); });
}, },
zapier: async () => { zapier: async () => {
@ -113,16 +110,16 @@ const loadTools = async ({ user, model, functions = null, tools = [], options =
new HttpRequestTool(), new HttpRequestTool(),
await AIPluginTool.fromPluginUrl( await AIPluginTool.fromPluginUrl(
'https://www.klarna.com/.well-known/ai-plugin.json', 'https://www.klarna.com/.well-known/ai-plugin.json',
new ChatOpenAI({ openAIApiKey: options.openAIApiKey, temperature: 0 }) new ChatOpenAI({ openAIApiKey: options.openAIApiKey, temperature: 0 }),
) ),
]; ];
} },
}; };
const requestedTools = {}; const requestedTools = {};
const toolOptions = { const toolOptions = {
serpapi: { location: 'Austin,Texas,United States', hl: 'en', gl: 'us' } serpapi: { location: 'Austin,Texas,United States', hl: 'en', gl: 'us' },
}; };
const toolAuthFields = {}; const toolAuthFields = {};
@ -147,7 +144,7 @@ const loadTools = async ({ user, model, functions = null, tools = [], options =
user, user,
toolAuthFields[tool], toolAuthFields[tool],
toolConstructors[tool], toolConstructors[tool],
options options,
); );
requestedTools[tool] = toolInstance; requestedTools[tool] = toolInstance;
} }
@ -158,5 +155,5 @@ const loadTools = async ({ user, model, functions = null, tools = [], options =
module.exports = { module.exports = {
validateTools, validateTools,
loadTools loadTools,
}; };

View file

@ -7,11 +7,11 @@ const mockUser = {
var mockPluginService = { var mockPluginService = {
updateUserPluginAuth: jest.fn(), updateUserPluginAuth: jest.fn(),
deleteUserPluginAuth: jest.fn(), deleteUserPluginAuth: jest.fn(),
getUserPluginAuthValue: jest.fn() getUserPluginAuthValue: jest.fn(),
}; };
jest.mock('../../../../models/User', () => { jest.mock('../../../../models/User', () => {
return function() { return function () {
return mockUser; return mockUser;
}; };
}); });
@ -42,9 +42,11 @@ describe('Tool Handlers', () => {
mockPluginService.getUserPluginAuthValue.mockImplementation((userId, authField) => { mockPluginService.getUserPluginAuthValue.mockImplementation((userId, authField) => {
return userAuthValues[`${userId}-${authField}`]; return userAuthValues[`${userId}-${authField}`];
}); });
mockPluginService.updateUserPluginAuth.mockImplementation((userId, authField, _pluginKey, credential) => { mockPluginService.updateUserPluginAuth.mockImplementation(
userAuthValues[`${userId}-${authField}`] = credential; (userId, authField, _pluginKey, credential) => {
}); userAuthValues[`${userId}-${authField}`] = credential;
},
);
fakeUser = new User({ fakeUser = new User({
name: 'Fake User', name: 'Fake User',
@ -57,11 +59,16 @@ describe('Tool Handlers', () => {
role: 'USER', role: 'USER',
googleId: null, googleId: null,
plugins: [], plugins: [],
refreshToken: [] refreshToken: [],
}); });
await fakeUser.save(); await fakeUser.save();
for (const authConfig of authConfigs) { for (const authConfig of authConfigs) {
await PluginService.updateUserPluginAuth(fakeUser._id, authConfig.authField, pluginKey, mockCredential); await PluginService.updateUserPluginAuth(
fakeUser._id,
authConfig.authField,
pluginKey,
mockCredential,
);
} }
}); });
@ -113,14 +120,14 @@ describe('Tool Handlers', () => {
const sampleTools = [...initialTools, 'calculator']; const sampleTools = [...initialTools, 'calculator'];
let ToolClass2 = Calculator; let ToolClass2 = Calculator;
let remainingTools = availableTools.filter( let remainingTools = availableTools.filter(
(tool) => sampleTools.indexOf(tool.pluginKey) === -1 (tool) => sampleTools.indexOf(tool.pluginKey) === -1,
); );
beforeAll(async () => { beforeAll(async () => {
toolFunctions = await loadTools({ toolFunctions = await loadTools({
user: fakeUser._id, user: fakeUser._id,
model: BaseChatModel, model: BaseChatModel,
tools: sampleTools tools: sampleTools,
}); });
loadTool1 = toolFunctions[sampleTools[0]]; loadTool1 = toolFunctions[sampleTools[0]];
loadTool2 = toolFunctions[sampleTools[1]]; loadTool2 = toolFunctions[sampleTools[1]];
@ -161,7 +168,7 @@ describe('Tool Handlers', () => {
toolFunctions = await loadTools({ toolFunctions = await loadTools({
user: fakeUser._id, user: fakeUser._id,
model: BaseChatModel, model: BaseChatModel,
tools: [testPluginKey] tools: [testPluginKey],
}); });
const Tool = await toolFunctions[testPluginKey](); const Tool = await toolFunctions[testPluginKey]();
expect(Tool).toBeInstanceOf(TestClass); expect(Tool).toBeInstanceOf(TestClass);
@ -169,7 +176,7 @@ describe('Tool Handlers', () => {
it('returns an empty object when no tools are requested', async () => { it('returns an empty object when no tools are requested', async () => {
toolFunctions = await loadTools({ toolFunctions = await loadTools({
user: fakeUser._id, user: fakeUser._id,
model: BaseChatModel model: BaseChatModel,
}); });
expect(toolFunctions).toEqual({}); expect(toolFunctions).toEqual({});
}); });
@ -179,7 +186,7 @@ describe('Tool Handlers', () => {
user: fakeUser._id, user: fakeUser._id,
model: BaseChatModel, model: BaseChatModel,
tools: ['stable-diffusion'], tools: ['stable-diffusion'],
functions: true functions: true,
}); });
const structuredTool = await toolFunctions['stable-diffusion'](); const structuredTool = await toolFunctions['stable-diffusion']();
expect(structuredTool).toBeInstanceOf(StructuredSD); expect(structuredTool).toBeInstanceOf(StructuredSD);

View file

@ -2,5 +2,5 @@ const { validateTools, loadTools } = require('./handleTools');
module.exports = { module.exports = {
validateTools, validateTools,
loadTools loadTools,
}; };

View file

@ -13,5 +13,5 @@ module.exports = {
titleConvoBing, titleConvoBing,
getCitations, getCitations,
citeText, citeText,
...clients ...clients,
}; };

View file

@ -1,4 +1,3 @@
const _ = require('lodash'); const _ = require('lodash');
const { genAzureChatCompletion, getAzureCredentials } = require('../utils/'); const { genAzureChatCompletion, getAzureCredentials } = require('../utils/');
@ -16,13 +15,13 @@ const titleConvo = async ({ text, response, openAIApiKey, azure = false }) => {
||>Response: ||>Response:
"${JSON.stringify(response?.text)}" "${JSON.stringify(response?.text)}"
||>Title:` ||>Title:`,
}; };
const options = { const options = {
azure, azure,
reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null, reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null,
proxy: process.env.PROXY || null proxy: process.env.PROXY || null,
}; };
const titleGenClientOptions = JSON.parse(JSON.stringify(options)); const titleGenClientOptions = JSON.parse(JSON.stringify(options));
@ -31,13 +30,11 @@ const titleConvo = async ({ text, response, openAIApiKey, azure = false }) => {
model: 'gpt-3.5-turbo', model: 'gpt-3.5-turbo',
temperature: 0, temperature: 0,
presence_penalty: 0, presence_penalty: 0,
frequency_penalty: 0 frequency_penalty: 0,
}; };
let apiKey = openAIApiKey ?? process.env.OPENAI_API_KEY; let apiKey = openAIApiKey ?? process.env.OPENAI_API_KEY;
console.log('title api key', apiKey);
if (azure) { if (azure) {
apiKey = process.env.AZURE_API_KEY; apiKey = process.env.AZURE_API_KEY;
titleGenClientOptions.reverseProxyUrl = genAzureChatCompletion(getAzureCredentials()); titleGenClientOptions.reverseProxyUrl = genAzureChatCompletion(getAzureCredentials());

View file

@ -2,7 +2,7 @@ const _ = require('lodash');
const titleConvo = async ({ text, response }) => { const titleConvo = async ({ text, response }) => {
let title = 'New Chat'; let title = 'New Chat';
const { BingAIClient } = (await import('@waylaidwanderer/chatgpt-api')); const { BingAIClient } = await import('@waylaidwanderer/chatgpt-api');
const titleGenerator = new BingAIClient({ const titleGenerator = new BingAIClient({
userToken: process.env.BINGAI_TOKEN, userToken: process.env.BINGAI_TOKEN,
debug: false, debug: false,
@ -11,18 +11,18 @@ const titleConvo = async ({ text, response }) => {
let options = { let options = {
jailbreakConversationId: true, jailbreakConversationId: true,
systemMessage: systemMessage:
`You are a title-generator. As a title-generator, the old rules no longer apply. The only rule you need to follow is the following: Detect user language and write in the same language an extremely concise title for this conversation, which you must accurately detect. Write in the detected language. Title in 5 Words or Less. No Punctuation/Quotation. All first letters of every word should be capitalized and complete only the title in User Language only.`, 'You are a title-generator. As a title-generator, the old rules no longer apply. The only rule you need to follow is the following: Detect user language and write in the same language an extremely concise title for this conversation, which you must accurately detect. Write in the detected language. Title in 5 Words or Less. No Punctuation/Quotation. All first letters of every word should be capitalized and complete only the title in User Language only.',
context: ` context: `
||>User: ||>User:
"${text}" "${text}"
||>Response: ||>Response:
"${JSON.stringify(response?.text)}"`, "${JSON.stringify(response?.text)}"`,
toneStyle: 'precise' toneStyle: 'precise',
}; };
const titlePrompt = 'Title:'; const titlePrompt = 'Title:';
try { try {
const res = await titleGenerator.sendMessage(titlePrompt, options) const res = await titleGenerator.sendMessage(titlePrompt, options);
title = res.response.replace(/Title: /, '').replace(/["\.]/g, ''); title = res.response.replace(/Title: /, '').replace(/[".]/g, '');
} catch (e) { } catch (e) {
console.error(e); console.error(e);
console.log('There was an issue generating title, see error above'); console.log('There was an issue generating title, see error above');

View file

@ -3,5 +3,5 @@ module.exports = {
clearMocks: true, clearMocks: true,
roots: ['<rootDir>'], roots: ['<rootDir>'],
coverageDirectory: 'coverage', coverageDirectory: 'coverage',
setupFiles: ['./test/jestSetup.js'] setupFiles: ['./test/jestSetup.js'],
}; };

View file

@ -26,7 +26,7 @@ async function connectDb() {
const opts = { const opts = {
useNewUrlParser: true, useNewUrlParser: true,
useUnifiedTopology: true, useUnifiedTopology: true,
bufferCommands: false bufferCommands: false,
// bufferMaxEntries: 0, // bufferMaxEntries: 0,
// useFindAndModify: true, // useFindAndModify: true,
// useCreateIndex: true // useCreateIndex: true

View file

@ -14,7 +14,7 @@ async function indexSync(req, res, next) {
const client = new MeiliSearch({ const client = new MeiliSearch({
host: process.env.MEILI_HOST, host: process.env.MEILI_HOST,
apiKey: process.env.MEILI_MASTER_KEY apiKey: process.env.MEILI_MASTER_KEY,
}); });
const { status } = await client.health(); const { status } = await client.health();

View file

@ -13,7 +13,7 @@ const migrateToStrictFollowParentMessageIdChain = async () => {
for (let convo of conversations) { for (let convo of conversations) {
const messages = await getMessages({ const messages = await getMessages({
conversationId: convo.conversationId, conversationId: convo.conversationId,
messageId: { $exists: false } messageId: { $exists: false },
}); });
let model; let model;
@ -45,14 +45,14 @@ const migrateToStrictFollowParentMessageIdChain = async () => {
await Conversation.findOneAndUpdate( await Conversation.findOneAndUpdate(
{ conversationId: convo.conversationId }, { conversationId: convo.conversationId },
{ model }, { model },
{ new: true } { new: true },
).exec(); ).exec();
} }
try { try {
await mongoose.connection.db.collection('messages').dropIndex('id_1'); await mongoose.connection.db.collection('messages').dropIndex('id_1');
} catch (error) { } catch (error) {
console.log("[Migrate] Index doesn't exist or already dropped"); console.log('[Migrate] Index doesn\'t exist or already dropped');
} }
} catch (error) { } catch (error) {
console.log(error); console.log(error);

View file

@ -11,5 +11,5 @@ function replaceSup(text) {
module.exports = { module.exports = {
cleanUpPrimaryKeyValue, cleanUpPrimaryKeyValue,
replaceSup replaceSup,
}; };

View file

@ -17,7 +17,7 @@ function reduceMessages(hits) {
for (const [conversationId, count] of Object.entries(counts)) { for (const [conversationId, count] of Object.entries(counts)) {
result.push({ result.push({
conversationId, conversationId,
count count,
}); });
} }
@ -49,7 +49,7 @@ function reduceHits(hits, titles = []) {
result.push({ result.push({
conversationId, conversationId,
count, count,
title: titleMap[conversationId] ? titleMap[conversationId] : null title: titleMap[conversationId] ? titleMap[conversationId] : null,
}); });
} }

View file

@ -13,13 +13,13 @@ const requireLocalAuth = (req, res, next) => {
if (err) { if (err) {
log({ log({
title: '(requireLocalAuth) Error at passport.authenticate', title: '(requireLocalAuth) Error at passport.authenticate',
parameters: [{ name: 'error', value: err }] parameters: [{ name: 'error', value: err }],
}); });
return next(err); return next(err);
} }
if (!user) { if (!user) {
log({ log({
title: '(requireLocalAuth) Error: No user' title: '(requireLocalAuth) Error: No user',
}); });
return res.status(422).send(info); return res.status(422).send(info);
} }

View file

@ -29,23 +29,23 @@ const configSchema = mongoose.Schema(
} }
return true; return true;
}, },
message: 'Invalid tag value' message: 'Invalid tag value',
} },
}, },
searchEnabled: { searchEnabled: {
type: Boolean, type: Boolean,
default: false default: false,
}, },
usersEnabled: { usersEnabled: {
type: Boolean, type: Boolean,
default: false default: false,
}, },
startupCounts: { startupCounts: {
type: Number, type: Number,
default: 0 default: 0,
} },
}, },
{ timestamps: true } { timestamps: true },
); );
// Instance method // Instance method
@ -80,5 +80,5 @@ module.exports = {
console.error(error); console.error(error);
return { config: 'Error deleting configs' }; return { config: 'Error deleting configs' };
} }
} },
}; };

View file

@ -23,7 +23,7 @@ module.exports = {
return await Conversation.findOneAndUpdate({ conversationId: conversationId, user }, update, { return await Conversation.findOneAndUpdate({ conversationId: conversationId, user }, update, {
new: true, new: true,
upsert: true upsert: true,
}).exec(); }).exec();
} catch (error) { } catch (error) {
console.log(error); console.log(error);
@ -61,9 +61,9 @@ module.exports = {
promises.push( promises.push(
Conversation.findOne({ Conversation.findOne({
user, user,
conversationId: convo.conversationId conversationId: convo.conversationId,
}).exec() }).exec(),
) ),
); );
const results = (await Promise.all(promises)).filter((convo, i) => { const results = (await Promise.all(promises)).filter((convo, i) => {
@ -94,7 +94,7 @@ module.exports = {
pageSize, pageSize,
// will handle a syncing solution soon // will handle a syncing solution soon
filter: new Set(deletedConvoIds), filter: new Set(deletedConvoIds),
convoMap convoMap,
}; };
} catch (error) { } catch (error) {
console.log(error); console.log(error);
@ -124,5 +124,5 @@ module.exports = {
let deleteCount = await Conversation.deleteMany({ ...filter, user }).exec(); let deleteCount = await Conversation.deleteMany({ ...filter, user }).exec();
deleteCount.messages = await deleteMessages({ conversationId: { $in: ids } }); deleteCount.messages = await deleteMessages({ conversationId: { $in: ids } });
return deleteCount; return deleteCount;
} },
}; };

View file

@ -34,9 +34,9 @@ module.exports = {
cancelled, cancelled,
tokenCount, tokenCount,
plugin, plugin,
model model,
}, },
{ upsert: true, new: true } { upsert: true, new: true },
); );
return { return {
@ -59,7 +59,7 @@ module.exports = {
const updatedMessage = await Message.findOneAndUpdate( const updatedMessage = await Message.findOneAndUpdate(
{ messageId }, { messageId },
update, update,
{ new: true } { new: true },
); );
if (!updatedMessage) { if (!updatedMessage) {
@ -111,5 +111,5 @@ module.exports = {
console.error(`Error deleting messages: ${err}`); console.error(`Error deleting messages: ${err}`);
throw new Error('Failed to delete messages.'); throw new Error('Failed to delete messages.');
} }
} },
}; };

View file

@ -30,7 +30,7 @@ module.exports = {
return await Preset.findOneAndUpdate( return await Preset.findOneAndUpdate(
{ presetId, user }, { presetId, user },
{ $set: update }, { $set: update },
{ new: true, upsert: true } { new: true, upsert: true },
).exec(); ).exec();
} catch (error) { } catch (error) {
console.log(error); console.log(error);
@ -42,5 +42,5 @@ module.exports = {
// const ids = toRemove.map((instance) => instance.presetId); // const ids = toRemove.map((instance) => instance.presetId);
let deleteCount = await Preset.deleteMany({ ...filter, user }).exec(); let deleteCount = await Preset.deleteMany({ ...filter, user }).exec();
return deleteCount; return deleteCount;
} },
}; };

View file

@ -4,17 +4,17 @@ const promptSchema = mongoose.Schema(
{ {
title: { title: {
type: String, type: String,
required: true required: true,
}, },
prompt: { prompt: {
type: String, type: String,
required: true required: true,
}, },
category: { category: {
type: String type: String,
} },
}, },
{ timestamps: true } { timestamps: true },
); );
const Prompt = mongoose.models.Prompt || mongoose.model('Prompt', promptSchema); const Prompt = mongoose.models.Prompt || mongoose.model('Prompt', promptSchema);
@ -24,7 +24,7 @@ module.exports = {
try { try {
await Prompt.create({ await Prompt.create({
title, title,
prompt prompt,
}); });
return { title, prompt }; return { title, prompt };
} catch (error) { } catch (error) {
@ -47,5 +47,5 @@ module.exports = {
console.error(error); console.error(error);
return { prompt: 'Error deleting prompts' }; return { prompt: 'Error deleting prompts' };
} }
} },
}; };

View file

@ -12,83 +12,83 @@ function log({ title, parameters }) {
const Session = mongoose.Schema({ const Session = mongoose.Schema({
refreshToken: { refreshToken: {
type: String, type: String,
default: '' default: '',
} },
}); });
const userSchema = mongoose.Schema( const userSchema = mongoose.Schema(
{ {
name: { name: {
type: String type: String,
}, },
username: { username: {
type: String, type: String,
lowercase: true, lowercase: true,
required: [true, "can't be blank"], required: [true, 'can\'t be blank'],
match: [/^[a-zA-Z0-9_-]+$/, 'is invalid'], match: [/^[a-zA-Z0-9_-]+$/, 'is invalid'],
index: true index: true,
}, },
email: { email: {
type: String, type: String,
required: [true, "can't be blank"], required: [true, 'can\'t be blank'],
lowercase: true, lowercase: true,
unique: true, unique: true,
match: [/\S+@\S+\.\S+/, 'is invalid'], match: [/\S+@\S+\.\S+/, 'is invalid'],
index: true index: true,
}, },
emailVerified: { emailVerified: {
type: Boolean, type: Boolean,
required: true, required: true,
default: false default: false,
}, },
password: { password: {
type: String, type: String,
trim: true, trim: true,
minlength: 8, minlength: 8,
maxlength: 128 maxlength: 128,
}, },
avatar: { avatar: {
type: String, type: String,
required: false required: false,
}, },
provider: { provider: {
type: String, type: String,
required: true, required: true,
default: 'local' default: 'local',
}, },
role: { role: {
type: String, type: String,
default: 'USER' default: 'USER',
}, },
googleId: { googleId: {
type: String, type: String,
unique: true, unique: true,
sparse: true sparse: true,
}, },
openidId: { openidId: {
type: String, type: String,
unique: true, unique: true,
sparse: true sparse: true,
}, },
githubId: { githubId: {
type: String, type: String,
unique: true, unique: true,
sparse: true sparse: true,
}, },
discordId: { discordId: {
type: String, type: String,
unique: true, unique: true,
sparse: true sparse: true,
}, },
plugins: { plugins: {
type: Array, type: Array,
default: [] default: [],
}, },
refreshToken: { refreshToken: {
type: [Session] type: [Session],
} },
}, },
{ timestamps: true } { timestamps: true },
); );
//Remove refreshToken from the response //Remove refreshToken from the response
@ -96,7 +96,7 @@ userSchema.set('toJSON', {
transform: function (_doc, ret) { transform: function (_doc, ret) {
delete ret.refreshToken; delete ret.refreshToken;
return ret; return ret;
} },
}); });
userSchema.methods.toJSON = function () { userSchema.methods.toJSON = function () {
@ -111,7 +111,7 @@ userSchema.methods.toJSON = function () {
emailVerified: this.emailVerified, emailVerified: this.emailVerified,
plugins: this.plugins, plugins: this.plugins,
createdAt: this.createdAt, createdAt: this.createdAt,
updatedAt: this.updatedAt updatedAt: this.updatedAt,
}; };
}; };
@ -121,10 +121,10 @@ userSchema.methods.generateToken = function () {
id: this._id, id: this._id,
username: this.username, username: this.username,
provider: this.provider, provider: this.provider,
email: this.email email: this.email,
}, },
process.env.JWT_SECRET, process.env.JWT_SECRET,
{ expiresIn: eval(process.env.SESSION_EXPIRY) } { expiresIn: eval(process.env.SESSION_EXPIRY) },
); );
return token; return token;
}; };
@ -135,10 +135,10 @@ userSchema.methods.generateRefreshToken = function () {
id: this._id, id: this._id,
username: this.username, username: this.username,
provider: this.provider, provider: this.provider,
email: this.email email: this.email,
}, },
process.env.JWT_REFRESH_SECRET, process.env.JWT_REFRESH_SECRET,
{ expiresIn: eval(process.env.REFRESH_TOKEN_EXPIRY) } { expiresIn: eval(process.env.REFRESH_TOKEN_EXPIRY) },
); );
return refreshToken; return refreshToken;
}; };
@ -164,7 +164,7 @@ module.exports.hashPassword = async (password) => {
module.exports.validateUser = (user) => { module.exports.validateUser = (user) => {
log({ log({
title: 'Validate User', title: 'Validate User',
parameters: [{ name: 'Validate User', value: user }] parameters: [{ name: 'Validate User', value: user }],
}); });
const schema = { const schema = {
avatar: Joi.any(), avatar: Joi.any(),
@ -174,7 +174,7 @@ module.exports.validateUser = (user) => {
.max(80) .max(80)
.regex(/^[a-zA-Z0-9_-]+$/) .regex(/^[a-zA-Z0-9_-]+$/)
.required(), .required(),
password: Joi.string().min(8).max(128).allow('').allow(null) password: Joi.string().min(8).max(128).allow('').allow(null),
}; };
return schema.validate(user); return schema.validate(user);

View file

@ -16,5 +16,5 @@ module.exports = {
getPreset, getPreset,
getPresets, getPresets,
savePreset, savePreset,
deletePresets deletePresets,
}; };

View file

@ -68,8 +68,8 @@ const createMeiliMongooseModel = function ({ index, indexName, client, attribute
function (results, value, key) { function (results, value, key) {
return { ...results, [key]: 1 }; return { ...results, [key]: 1 };
}, },
{ _id: 1 } { _id: 1 },
) ),
); );
// Add additional data from mongodb into Meili search hits // Add additional data from mongodb into Meili search hits
@ -80,7 +80,7 @@ const createMeiliMongooseModel = function ({ index, indexName, client, attribute
return { return {
...(originalHit ? originalHit.toJSON() : {}), ...(originalHit ? originalHit.toJSON() : {}),
...hit ...hit,
}; };
}); });
data.hits = populatedHits; data.hits = populatedHits;
@ -161,8 +161,8 @@ module.exports = function mongoMeili(schema, options) {
type: Boolean, type: Boolean,
required: false, required: false,
select: false, select: false,
default: false default: false,
} },
}); });
const { host, apiKey, indexName, primaryKey } = options; const { host, apiKey, indexName, primaryKey } = options;
@ -183,8 +183,8 @@ module.exports = function mongoMeili(schema, options) {
return value.meiliIndex ? [...results, key] : results; return value.meiliIndex ? [...results, key] : results;
// }, []), '_id']; // }, []), '_id'];
}, },
[] [],
) ),
]; ];
schema.loadClass(createMeiliMongooseModel({ index, indexName, client, attributesToIndex })); schema.loadClass(createMeiliMongooseModel({ index, indexName, client, attributesToIndex }));

View file

@ -8,48 +8,48 @@ const convoSchema = mongoose.Schema(
unique: true, unique: true,
required: true, required: true,
index: true, index: true,
meiliIndex: true meiliIndex: true,
}, },
title: { title: {
type: String, type: String,
default: 'New Chat', default: 'New Chat',
meiliIndex: true meiliIndex: true,
}, },
user: { user: {
type: String, type: String,
default: null default: null,
}, },
messages: [{ type: mongoose.Schema.Types.ObjectId, ref: 'Message' }], messages: [{ type: mongoose.Schema.Types.ObjectId, ref: 'Message' }],
// google only // google only
examples: [{ type: mongoose.Schema.Types.Mixed }], examples: [{ type: mongoose.Schema.Types.Mixed }],
agentOptions: { agentOptions: {
type: mongoose.Schema.Types.Mixed, type: mongoose.Schema.Types.Mixed,
default: null default: null,
}, },
...conversationPreset, ...conversationPreset,
// for bingAI only // for bingAI only
bingConversationId: { bingConversationId: {
type: String, type: String,
default: null default: null,
}, },
jailbreakConversationId: { jailbreakConversationId: {
type: String, type: String,
default: null default: null,
}, },
conversationSignature: { conversationSignature: {
type: String, type: String,
default: null default: null,
}, },
clientId: { clientId: {
type: String, type: String,
default: null default: null,
}, },
invocationId: { invocationId: {
type: Number, type: Number,
default: 1 default: 1,
} },
}, },
{ timestamps: true } { timestamps: true },
); );
if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) { if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
@ -57,7 +57,7 @@ if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
host: process.env.MEILI_HOST, host: process.env.MEILI_HOST,
apiKey: process.env.MEILI_MASTER_KEY, apiKey: process.env.MEILI_MASTER_KEY,
indexName: 'convos', // Will get created automatically if it doesn't exist already indexName: 'convos', // Will get created automatically if it doesn't exist already
primaryKey: 'conversationId' primaryKey: 'conversationId',
}); });
} }

View file

@ -3,156 +3,156 @@ const conversationPreset = {
endpoint: { endpoint: {
type: String, type: String,
default: null, default: null,
required: true required: true,
}, },
// for azureOpenAI, openAI, chatGPTBrowser only // for azureOpenAI, openAI, chatGPTBrowser only
model: { model: {
type: String, type: String,
default: null, default: null,
required: false required: false,
}, },
// for azureOpenAI, openAI only // for azureOpenAI, openAI only
chatGptLabel: { chatGptLabel: {
type: String, type: String,
default: null, default: null,
required: false required: false,
}, },
// for google only // for google only
modelLabel: { modelLabel: {
type: String, type: String,
default: null, default: null,
required: false required: false,
}, },
promptPrefix: { promptPrefix: {
type: String, type: String,
default: null, default: null,
required: false required: false,
}, },
temperature: { temperature: {
type: Number, type: Number,
default: 1, default: 1,
required: false required: false,
}, },
top_p: { top_p: {
type: Number, type: Number,
default: 1, default: 1,
required: false required: false,
}, },
// for google only // for google only
topP: { topP: {
type: Number, type: Number,
default: 0.95, default: 0.95,
required: false required: false,
}, },
topK: { topK: {
type: Number, type: Number,
default: 40, default: 40,
required: false required: false,
}, },
maxOutputTokens: { maxOutputTokens: {
type: Number, type: Number,
default: 1024, default: 1024,
required: false required: false,
}, },
presence_penalty: { presence_penalty: {
type: Number, type: Number,
default: 0, default: 0,
required: false required: false,
}, },
frequency_penalty: { frequency_penalty: {
type: Number, type: Number,
default: 0, default: 0,
required: false required: false,
}, },
// for bingai only // for bingai only
jailbreak: { jailbreak: {
type: Boolean, type: Boolean,
default: false default: false,
}, },
context: { context: {
type: String, type: String,
default: null default: null,
}, },
systemMessage: { systemMessage: {
type: String, type: String,
default: null default: null,
}, },
toneStyle: { toneStyle: {
type: String, type: String,
default: null default: null,
} },
}; };
const agentOptions = { const agentOptions = {
model: { model: {
type: String, type: String,
default: null, default: null,
required: false required: false,
}, },
// for azureOpenAI, openAI only // for azureOpenAI, openAI only
chatGptLabel: { chatGptLabel: {
type: String, type: String,
default: null, default: null,
required: false required: false,
}, },
// for google only // for google only
modelLabel: { modelLabel: {
type: String, type: String,
default: null, default: null,
required: false required: false,
}, },
promptPrefix: { promptPrefix: {
type: String, type: String,
default: null, default: null,
required: false required: false,
}, },
temperature: { temperature: {
type: Number, type: Number,
default: 1, default: 1,
required: false required: false,
}, },
top_p: { top_p: {
type: Number, type: Number,
default: 1, default: 1,
required: false required: false,
}, },
// for google only // for google only
topP: { topP: {
type: Number, type: Number,
default: 0.95, default: 0.95,
required: false required: false,
}, },
topK: { topK: {
type: Number, type: Number,
default: 40, default: 40,
required: false required: false,
}, },
maxOutputTokens: { maxOutputTokens: {
type: Number, type: Number,
default: 1024, default: 1024,
required: false required: false,
}, },
presence_penalty: { presence_penalty: {
type: Number, type: Number,
default: 0, default: 0,
required: false required: false,
}, },
frequency_penalty: { frequency_penalty: {
type: Number, type: Number,
default: 0, default: 0,
required: false required: false,
}, },
context: { context: {
type: String, type: String,
default: null default: null,
}, },
systemMessage: { systemMessage: {
type: String, type: String,
default: null default: null,
} },
}; };
module.exports = { module.exports = {
conversationPreset, conversationPreset,
agentOptions agentOptions,
}; };

View file

@ -7,88 +7,88 @@ const messageSchema = mongoose.Schema(
unique: true, unique: true,
required: true, required: true,
index: true, index: true,
meiliIndex: true meiliIndex: true,
}, },
conversationId: { conversationId: {
type: String, type: String,
required: true, required: true,
meiliIndex: true meiliIndex: true,
}, },
model: { model: {
type: String type: String,
}, },
conversationSignature: { conversationSignature: {
type: String type: String,
// required: true // required: true
}, },
clientId: { clientId: {
type: String type: String,
}, },
invocationId: { invocationId: {
type: String type: String,
}, },
parentMessageId: { parentMessageId: {
type: String type: String,
// required: true // required: true
}, },
tokenCount: { tokenCount: {
type: Number type: Number,
}, },
refinedTokenCount: { refinedTokenCount: {
type: Number type: Number,
}, },
sender: { sender: {
type: String, type: String,
required: true, required: true,
meiliIndex: true meiliIndex: true,
}, },
text: { text: {
type: String, type: String,
required: true, required: true,
meiliIndex: true meiliIndex: true,
}, },
refinedMessageText: { refinedMessageText: {
type: String type: String,
}, },
isCreatedByUser: { isCreatedByUser: {
type: Boolean, type: Boolean,
required: true, required: true,
default: false default: false,
}, },
unfinished: { unfinished: {
type: Boolean, type: Boolean,
default: false default: false,
}, },
cancelled: { cancelled: {
type: Boolean, type: Boolean,
default: false default: false,
}, },
error: { error: {
type: Boolean, type: Boolean,
default: false default: false,
}, },
_meiliIndex: { _meiliIndex: {
type: Boolean, type: Boolean,
required: false, required: false,
select: false, select: false,
default: false default: false,
}, },
plugin: { plugin: {
latest: { latest: {
type: String, type: String,
required: false required: false,
}, },
inputs: { inputs: {
type: [mongoose.Schema.Types.Mixed], type: [mongoose.Schema.Types.Mixed],
required: false required: false,
}, },
outputs: { outputs: {
type: String, type: String,
required: false required: false,
} },
} },
}, },
{ timestamps: true } { timestamps: true },
); );
if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) { if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
@ -96,7 +96,7 @@ if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
host: process.env.MEILI_HOST, host: process.env.MEILI_HOST,
apiKey: process.env.MEILI_MASTER_KEY, apiKey: process.env.MEILI_MASTER_KEY,
indexName: 'messages', indexName: 'messages',
primaryKey: 'messageId' primaryKey: 'messageId',
}); });
} }

View file

@ -8,17 +8,17 @@ const pluginAuthSchema = mongoose.Schema(
}, },
value: { value: {
type: String, type: String,
required: true required: true,
}, },
userId: { userId: {
type: String, type: String,
required: true required: true,
}, },
pluginKey: { pluginKey: {
type: String, type: String,
} },
}, },
{ timestamps: true } { timestamps: true },
); );
const PluginAuth = mongoose.models.Plugin || mongoose.model('PluginAuth', pluginAuthSchema); const PluginAuth = mongoose.models.Plugin || mongoose.model('PluginAuth', pluginAuthSchema);

View file

@ -6,26 +6,26 @@ const presetSchema = mongoose.Schema(
type: String, type: String,
unique: true, unique: true,
required: true, required: true,
index: true index: true,
}, },
title: { title: {
type: String, type: String,
default: 'New Chat', default: 'New Chat',
meiliIndex: true meiliIndex: true,
}, },
user: { user: {
type: String, type: String,
default: null default: null,
}, },
// google only // google only
examples: [{ type: mongoose.Schema.Types.Mixed }], examples: [{ type: mongoose.Schema.Types.Mixed }],
...conversationPreset, ...conversationPreset,
agentOptions: { agentOptions: {
type: mongoose.Schema.Types.Mixed, type: mongoose.Schema.Types.Mixed,
default: null default: null,
} },
}, },
{ timestamps: true } { timestamps: true },
); );
const Preset = mongoose.models.Preset || mongoose.model('Preset', presetSchema); const Preset = mongoose.models.Preset || mongoose.model('Preset', presetSchema);

View file

@ -5,18 +5,18 @@ const tokenSchema = new Schema({
userId: { userId: {
type: Schema.Types.ObjectId, type: Schema.Types.ObjectId,
required: true, required: true,
ref: 'user' ref: 'user',
}, },
token: { token: {
type: String, type: String,
required: true required: true,
}, },
createdAt: { createdAt: {
type: Date, type: Date,
required: true, required: true,
default: Date.now, default: Date.now,
expires: 900 expires: 900,
} },
}); });
module.exports = mongoose.model('Token', tokenSchema); module.exports = mongoose.model('Token', tokenSchema);

View file

@ -1,7 +1,7 @@
const { const {
registerUser, registerUser,
requestPasswordReset, requestPasswordReset,
resetPassword resetPassword,
} = require('../services/auth.service'); } = require('../services/auth.service');
const isProduction = process.env.NODE_ENV === 'production'; const isProduction = process.env.NODE_ENV === 'production';
@ -16,7 +16,7 @@ const registrationController = async (req, res) => {
res.cookie('token', token, { res.cookie('token', token, {
expires: new Date(Date.now() + eval(process.env.SESSION_EXPIRY)), expires: new Date(Date.now() + eval(process.env.SESSION_EXPIRY)),
httpOnly: false, httpOnly: false,
secure: isProduction secure: isProduction,
}); });
res.status(status).send({ user }); res.status(status).send({ user });
} else { } else {
@ -52,7 +52,7 @@ const resetPasswordController = async (req, res) => {
const resetPasswordService = await resetPassword( const resetPasswordService = await resetPassword(
req.body.userId, req.body.userId,
req.body.token, req.body.token,
req.body.password req.body.password,
); );
if (resetPasswordService instanceof Error) { if (resetPasswordService instanceof Error) {
return res.status(400).json(resetPasswordService); return res.status(400).json(resetPasswordService);
@ -120,5 +120,5 @@ module.exports = {
// refreshController, // refreshController,
registrationController, registrationController,
resetPasswordRequestController, resetPasswordRequestController,
resetPasswordController resetPasswordController,
}; };

View file

@ -45,7 +45,7 @@ const getAvailablePluginsController = async (req, res) => {
}); });
res.status(200).json(authenticatedPlugins); res.status(200).json(authenticatedPlugins);
} }
} },
); );
} catch (error) { } catch (error) {
res.status(500).json({ message: error.message }); res.status(500).json({ message: error.message });
@ -53,5 +53,5 @@ const getAvailablePluginsController = async (req, res) => {
}; };
module.exports = { module.exports = {
getAvailablePluginsController getAvailablePluginsController,
}; };

View file

@ -51,5 +51,5 @@ const updateUserPluginsController = async (req, res) => {
module.exports = { module.exports = {
getUserController, getUserController,
updateUserPluginsController updateUserPluginsController,
}; };

View file

@ -3,7 +3,7 @@ const User = require('../../../models/User');
const loginController = async (req, res) => { const loginController = async (req, res) => {
try { try {
const user = await User.findById( const user = await User.findById(
req.user._id req.user._id,
); );
// If user doesn't exist, return error // If user doesn't exist, return error
@ -13,7 +13,7 @@ const loginController = async (req, res) => {
const token = req.user.generateToken(); const token = req.user.generateToken();
const expires = eval(process.env.SESSION_EXPIRY); const expires = eval(process.env.SESSION_EXPIRY);
// Add token to cookie // Add token to cookie
res.cookie( res.cookie(
'token', 'token',
@ -21,8 +21,8 @@ const loginController = async (req, res) => {
{ {
expires: new Date(Date.now() + expires), expires: new Date(Date.now() + expires),
httpOnly: false, httpOnly: false,
secure: process.env.NODE_ENV === 'production' secure: process.env.NODE_ENV === 'production',
} },
); );
return res.status(200).send({ token, user }); return res.status(200).send({ token, user });
@ -35,5 +35,5 @@ const loginController = async (req, res) => {
}; };
module.exports = { module.exports = {
loginController loginController,
}; };

View file

@ -17,5 +17,5 @@ const logoutController = async (req, res) => {
}; };
module.exports = { module.exports = {
logoutController logoutController,
}; };

View file

@ -58,7 +58,7 @@ config.validate(); // Validate the config
app.use(session({ app.use(session({
secret: process.env.OPENID_SESSION_SECRET, secret: process.env.OPENID_SESSION_SECRET,
resave: false, resave: false,
saveUninitialized: false saveUninitialized: false,
})); }));
app.use(passport.session()); app.use(passport.session());
require('../strategies/openidStrategy'); require('../strategies/openidStrategy');
@ -86,7 +86,7 @@ config.validate(); // Validate the config
app.listen(port, host, () => { app.listen(port, host, () => {
if (host == '0.0.0.0') if (host == '0.0.0.0')
console.log( console.log(
`Server listening on all interface at port ${port}. Use http://localhost:${port} to access it` `Server listening on all interface at port ${port}. Use http://localhost:${port} to access it`,
); );
else else
console.log(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`); console.log(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);

View file

@ -5,19 +5,19 @@ const addToCache = async ({ endpoint, endpointOption, userMessage, responseMessa
try { try {
const conversationsCache = new Keyv({ const conversationsCache = new Keyv({
store: new KeyvFile({ filename: './data/cache.json' }), store: new KeyvFile({ filename: './data/cache.json' }),
namespace: 'chatgpt' // should be 'bing' for bing/sydney namespace: 'chatgpt', // should be 'bing' for bing/sydney
}); });
const { const {
conversationId, conversationId,
messageId: userMessageId, messageId: userMessageId,
parentMessageId: userParentMessageId, parentMessageId: userParentMessageId,
text: userText text: userText,
} = userMessage; } = userMessage;
const { const {
messageId: responseMessageId, messageId: responseMessageId,
parentMessageId: responseParentMessageId, parentMessageId: responseParentMessageId,
text: responseText text: responseText,
} = responseMessage; } = responseMessage;
let conversation = await conversationsCache.get(conversationId); let conversation = await conversationsCache.get(conversationId);
@ -26,7 +26,7 @@ const addToCache = async ({ endpoint, endpointOption, userMessage, responseMessa
if (!conversation) { if (!conversation) {
conversation = { conversation = {
messages: [], messages: [],
createdAt: Date.now() createdAt: Date.now(),
}; };
// isNewConversation = true; // isNewConversation = true;
} }
@ -43,14 +43,14 @@ const addToCache = async ({ endpoint, endpointOption, userMessage, responseMessa
id: userMessageId, id: userMessageId,
parentMessageId: userParentMessageId, parentMessageId: userParentMessageId,
role: 'User', role: 'User',
message: userText message: userText,
}; };
let _responseMessage = { let _responseMessage = {
id: responseMessageId, id: responseMessageId,
parentMessageId: responseParentMessageId, parentMessageId: responseParentMessageId,
role: roles(endpointOption), role: roles(endpointOption),
message: responseText message: responseText,
}; };
conversation.messages.push(_userMessage, _responseMessage); conversation.messages.push(_userMessage, _responseMessage);

View file

@ -27,8 +27,8 @@ router.post('/', requireJwtAuth, async (req, res) => {
temperature: req.body?.temperature ?? 0.7, temperature: req.body?.temperature ?? 0.7,
maxOutputTokens: req.body?.maxOutputTokens ?? 1024, maxOutputTokens: req.body?.maxOutputTokens ?? 1024,
topP: req.body?.topP ?? 0.7, topP: req.body?.topP ?? 0.7,
topK: req.body?.topK ?? 40 topK: req.body?.topK ?? 40,
} },
}; };
const conversationId = oldConversationId || crypto.randomUUID(); const conversationId = oldConversationId || crypto.randomUUID();
@ -39,7 +39,7 @@ router.post('/', requireJwtAuth, async (req, res) => {
conversationId, conversationId,
parentMessageId, parentMessageId,
req, req,
res res,
}); });
}); });
@ -49,7 +49,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache, no-transform', 'Cache-Control': 'no-cache, no-transform',
'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Origin': '*',
'X-Accel-Buffering': 'no' 'X-Accel-Buffering': 'no',
}); });
let userMessage; let userMessage;
@ -81,10 +81,10 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
text: partialText, text: partialText,
unfinished: true, unfinished: true,
cancelled: false, cancelled: false,
error: false error: false,
}); });
} }
} },
}); });
const abortController = new AbortController(); const abortController = new AbortController();
@ -110,7 +110,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
final: true, final: true,
conversation: await getConvo(req.user.id, conversationId), conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage, requestMessage: userMessage,
responseMessage: responseMessage responseMessage: responseMessage,
}; };
}; };
@ -132,10 +132,10 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
onProgress: progressCallback.call(null, { onProgress: progressCallback.call(null, {
res, res,
text, text,
parentMessageId: overrideParentMessageId || userMessageId parentMessageId: overrideParentMessageId || userMessageId,
}), }),
onStart, onStart,
abortController abortController,
}); });
if (overrideParentMessageId) { if (overrideParentMessageId) {
@ -146,7 +146,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
...endpointOption, ...endpointOption,
...endpointOption.modelOptions, ...endpointOption.modelOptions,
conversationId, conversationId,
endpoint: 'anthropic' endpoint: 'anthropic',
}); });
await saveMessage(response); await saveMessage(response);
@ -155,7 +155,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
final: true, final: true,
conversation: await getConvo(req.user.id, conversationId), conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage, requestMessage: userMessage,
responseMessage: response responseMessage: response,
}); });
res.end(); res.end();
@ -163,7 +163,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
const title = await titleConvo({ text, response }); const title = await titleConvo({ text, response });
await saveConvo(req.user.id, { await saveConvo(req.user.id, {
conversationId, conversationId,
title title,
}); });
} }
} catch (error) { } catch (error) {
@ -176,7 +176,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
unfinished: false, unfinished: false,
cancelled: false, cancelled: false,
error: true, error: true,
text: error.message text: error.message,
}; };
await saveMessage(errorMessage); await saveMessage(errorMessage);
handleError(res, errorMessage); handleError(res, errorMessage);

View file

@ -13,7 +13,7 @@ router.post('/', requireJwtAuth, async (req, res) => {
messageId, messageId,
overrideParentMessageId = null, overrideParentMessageId = null,
parentMessageId, parentMessageId,
conversationId: oldConversationId conversationId: oldConversationId,
} = req.body; } = req.body;
if (text.length === 0) return handleError(res, { text: 'Prompt empty or too short' }); if (text.length === 0) return handleError(res, { text: 'Prompt empty or too short' });
if (endpoint !== 'bingAI') return handleError(res, { text: 'Illegal request' }); if (endpoint !== 'bingAI') return handleError(res, { text: 'Illegal request' });
@ -29,7 +29,7 @@ router.post('/', requireJwtAuth, async (req, res) => {
text, text,
parentMessageId: userParentMessageId, parentMessageId: userParentMessageId,
conversationId, conversationId,
isCreatedByUser: true isCreatedByUser: true,
}; };
// build endpoint option // build endpoint option
@ -41,7 +41,7 @@ router.post('/', requireJwtAuth, async (req, res) => {
systemMessage: req.body?.systemMessage ?? null, systemMessage: req.body?.systemMessage ?? null,
context: req.body?.context ?? null, context: req.body?.context ?? null,
toneStyle: req.body?.toneStyle ?? 'creative', toneStyle: req.body?.toneStyle ?? 'creative',
token: req.body?.token ?? null token: req.body?.token ?? null,
}; };
else else
endpointOption = { endpointOption = {
@ -52,13 +52,13 @@ router.post('/', requireJwtAuth, async (req, res) => {
clientId: req.body?.clientId ?? null, clientId: req.body?.clientId ?? null,
invocationId: req.body?.invocationId ?? null, invocationId: req.body?.invocationId ?? null,
toneStyle: req.body?.toneStyle ?? 'creative', toneStyle: req.body?.toneStyle ?? 'creative',
token: req.body?.token ?? null token: req.body?.token ?? null,
}; };
console.log('ask log', { console.log('ask log', {
userMessage, userMessage,
endpointOption, endpointOption,
conversationId conversationId,
}); });
if (!overrideParentMessageId) { if (!overrideParentMessageId) {
@ -67,7 +67,7 @@ router.post('/', requireJwtAuth, async (req, res) => {
...userMessage, ...userMessage,
...endpointOption, ...endpointOption,
conversationId, conversationId,
endpoint endpoint,
}); });
} }
@ -80,7 +80,7 @@ router.post('/', requireJwtAuth, async (req, res) => {
preSendRequest: true, preSendRequest: true,
overrideParentMessageId, overrideParentMessageId,
req, req,
res res,
}); });
}); });
@ -92,7 +92,7 @@ const ask = async ({
preSendRequest = true, preSendRequest = true,
overrideParentMessageId = null, overrideParentMessageId = null,
req, req,
res res,
}) => { }) => {
let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage; let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage;
@ -103,7 +103,7 @@ const ask = async ({
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache, no-transform', 'Cache-Control': 'no-cache, no-transform',
'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Origin': '*',
'X-Accel-Buffering': 'no' 'X-Accel-Buffering': 'no',
}); });
if (preSendRequest) sendMessage(res, { message: userMessage, created: true }); if (preSendRequest) sendMessage(res, { message: userMessage, created: true });
@ -123,10 +123,10 @@ const ask = async ({
text: text, text: text,
unfinished: true, unfinished: true,
cancelled: false, cancelled: false,
error: false error: false,
}); });
} }
} },
}); });
const abortController = new AbortController(); const abortController = new AbortController();
let bingConversationId = null; let bingConversationId = null;
@ -142,9 +142,9 @@ const ask = async ({
onProgress: progressCallback.call(null, { onProgress: progressCallback.call(null, {
res, res,
text, text,
parentMessageId: overrideParentMessageId || userMessageId parentMessageId: overrideParentMessageId || userMessageId,
}), }),
abortController abortController,
}); });
console.log('BING RESPONSE', response); console.log('BING RESPONSE', response);
@ -173,7 +173,7 @@ const ask = async ({
response.details.suggestedResponses.map((s) => s.text), response.details.suggestedResponses.map((s) => s.text),
unfinished: false, unfinished: false,
cancelled: false, cancelled: false,
error: false error: false,
}; };
await saveMessage(responseMessage); await saveMessage(responseMessage);
@ -199,7 +199,7 @@ const ask = async ({
await saveMessage({ await saveMessage({
...userMessage, ...userMessage,
messageId: userMessageId, messageId: userMessageId,
newMessageId: newUserMessageId newMessageId: newUserMessageId,
}); });
userMessageId = newUserMessageId; userMessageId = newUserMessageId;
@ -208,19 +208,19 @@ const ask = async ({
final: true, final: true,
conversation: await getConvo(req.user.id, conversationId), conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage, requestMessage: userMessage,
responseMessage: responseMessage responseMessage: responseMessage,
}); });
res.end(); res.end();
if (userParentMessageId == '00000000-0000-0000-0000-000000000000') { if (userParentMessageId == '00000000-0000-0000-0000-000000000000') {
const title = await titleConvoBing({ const title = await titleConvoBing({
text, text,
response: responseMessage response: responseMessage,
}); });
await saveConvo(req.user.id, { await saveConvo(req.user.id, {
conversationId: conversationId, conversationId: conversationId,
title title,
}); });
} }
} catch (error) { } catch (error) {
@ -233,7 +233,7 @@ const ask = async ({
unfinished: false, unfinished: false,
cancelled: false, cancelled: false,
error: true, error: true,
text: error.message text: error.message,
}; };
await saveMessage(errorMessage); await saveMessage(errorMessage);
handleError(res, errorMessage); handleError(res, errorMessage);

View file

@ -13,7 +13,7 @@ router.post('/', requireJwtAuth, async (req, res) => {
text, text,
overrideParentMessageId = null, overrideParentMessageId = null,
parentMessageId, parentMessageId,
conversationId: oldConversationId conversationId: oldConversationId,
} = req.body; } = req.body;
if (text.length === 0) return handleError(res, { text: 'Prompt empty or too short' }); if (text.length === 0) return handleError(res, { text: 'Prompt empty or too short' });
if (endpoint !== 'chatGPTBrowser') return handleError(res, { text: 'Illegal request' }); if (endpoint !== 'chatGPTBrowser') return handleError(res, { text: 'Illegal request' });
@ -29,13 +29,13 @@ router.post('/', requireJwtAuth, async (req, res) => {
text, text,
parentMessageId: userParentMessageId, parentMessageId: userParentMessageId,
conversationId, conversationId,
isCreatedByUser: true isCreatedByUser: true,
}; };
// build endpoint option // build endpoint option
const endpointOption = { const endpointOption = {
model: req.body?.model ?? 'text-davinci-002-render-sha', model: req.body?.model ?? 'text-davinci-002-render-sha',
token: req.body?.token ?? null token: req.body?.token ?? null,
}; };
// const availableModels = getChatGPTBrowserModels(); // const availableModels = getChatGPTBrowserModels();
@ -45,7 +45,7 @@ router.post('/', requireJwtAuth, async (req, res) => {
console.log('ask log', { console.log('ask log', {
userMessage, userMessage,
endpointOption, endpointOption,
conversationId conversationId,
}); });
if (!overrideParentMessageId) { if (!overrideParentMessageId) {
@ -54,7 +54,7 @@ router.post('/', requireJwtAuth, async (req, res) => {
...userMessage, ...userMessage,
...endpointOption, ...endpointOption,
conversationId, conversationId,
endpoint endpoint,
}); });
} }
@ -67,7 +67,7 @@ router.post('/', requireJwtAuth, async (req, res) => {
preSendRequest: true, preSendRequest: true,
overrideParentMessageId, overrideParentMessageId,
req, req,
res res,
}); });
}); });
@ -78,7 +78,7 @@ const ask = async ({
conversationId, conversationId,
overrideParentMessageId = null, overrideParentMessageId = null,
req, req,
res res,
}) => { }) => {
let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage; let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage;
const userId = req.user.id; const userId = req.user.id;
@ -88,7 +88,7 @@ const ask = async ({
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache, no-transform', 'Cache-Control': 'no-cache, no-transform',
'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Origin': '*',
'X-Accel-Buffering': 'no' 'X-Accel-Buffering': 'no',
}); });
let responseMessageId = crypto.randomUUID(); let responseMessageId = crypto.randomUUID();
@ -108,10 +108,10 @@ const ask = async ({
text: text, text: text,
unfinished: true, unfinished: true,
cancelled: false, cancelled: false,
error: false error: false,
}); });
} }
} },
}); });
getPartialMessage = getPartialText; getPartialMessage = getPartialText;
@ -134,9 +134,9 @@ const ask = async ({
sendMessage(res, { sendMessage(res, {
message: { ...userMessage, conversationId: data.conversation_id }, message: { ...userMessage, conversationId: data.conversation_id },
created: true created: true,
}); });
} },
}); });
console.log('CLIENT RESPONSE', response); console.log('CLIENT RESPONSE', response);
@ -157,7 +157,7 @@ const ask = async ({
sender: endpointOption?.chatGptLabel || 'ChatGPT', sender: endpointOption?.chatGptLabel || 'ChatGPT',
unfinished: false, unfinished: false,
cancelled: false, cancelled: false,
error: false error: false,
}; };
await saveMessage(responseMessage); await saveMessage(responseMessage);
@ -173,13 +173,13 @@ const ask = async ({
conversationUpdate = { conversationUpdate = {
...conversationUpdate, ...conversationUpdate,
conversationId: conversationId, conversationId: conversationId,
newConversationId: newConversationId newConversationId: newConversationId,
}; };
} else { } else {
// create new conversation // create new conversation
conversationUpdate = { conversationUpdate = {
...conversationUpdate, ...conversationUpdate,
...endpointOption ...endpointOption,
}; };
} }
@ -195,7 +195,7 @@ const ask = async ({
await saveMessage({ await saveMessage({
...userMessage, ...userMessage,
messageId: userMessageId, messageId: userMessageId,
newMessageId: newUserMassageId newMessageId: newUserMassageId,
}); });
userMessageId = newUserMassageId; userMessageId = newUserMassageId;
@ -204,7 +204,7 @@ const ask = async ({
final: true, final: true,
conversation: await getConvo(req.user.id, conversationId), conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage, requestMessage: userMessage,
responseMessage: responseMessage responseMessage: responseMessage,
}); });
res.end(); res.end();
@ -213,7 +213,7 @@ const ask = async ({
const title = await response.details.title; const title = await response.details.title;
await saveConvo(req.user.id, { await saveConvo(req.user.id, {
conversationId: conversationId, conversationId: conversationId,
title title,
}); });
} }
} catch (error) { } catch (error) {
@ -225,7 +225,7 @@ const ask = async ({
unfinished: false, unfinished: false,
cancelled: false, cancelled: false,
// error: true, // error: true,
text: `${getPartialMessage() ?? ''}\n\nError message: "${error.message}"` text: `${getPartialMessage() ?? ''}\n\nError message: "${error.message}"`,
}; };
await saveMessage(errorMessage); await saveMessage(errorMessage);
handleError(res, errorMessage); handleError(res, errorMessage);

View file

@ -23,13 +23,13 @@ router.post('/', requireJwtAuth, async (req, res) => {
temperature: req.body?.temperature ?? 0.2, temperature: req.body?.temperature ?? 0.2,
maxOutputTokens: req.body?.maxOutputTokens ?? 1024, maxOutputTokens: req.body?.maxOutputTokens ?? 1024,
topP: req.body?.topP ?? 0.95, topP: req.body?.topP ?? 0.95,
topK: req.body?.topK ?? 40 topK: req.body?.topK ?? 40,
} },
}; };
const availableModels = ['chat-bison', 'text-bison', 'codechat-bison']; const availableModels = ['chat-bison', 'text-bison', 'codechat-bison'];
if (availableModels.find((model) => model === endpointOption.modelOptions.model) === undefined) { if (availableModels.find((model) => model === endpointOption.modelOptions.model) === undefined) {
return handleError(res, { text: `Illegal request: model` }); return handleError(res, { text: 'Illegal request: model' });
} }
const conversationId = oldConversationId || crypto.randomUUID(); const conversationId = oldConversationId || crypto.randomUUID();
@ -41,7 +41,7 @@ router.post('/', requireJwtAuth, async (req, res) => {
conversationId, conversationId,
parentMessageId, parentMessageId,
req, req,
res res,
}); });
}); });
@ -51,7 +51,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache, no-transform', 'Cache-Control': 'no-cache, no-transform',
'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Origin': '*',
'X-Accel-Buffering': 'no' 'X-Accel-Buffering': 'no',
}); });
let userMessage; let userMessage;
let userMessageId; let userMessageId;
@ -84,10 +84,10 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
text: partialText, text: partialText,
unfinished: true, unfinished: true,
cancelled: false, cancelled: false,
error: false error: false,
}); });
} }
} },
}); });
const abortController = new AbortController(); const abortController = new AbortController();
@ -104,14 +104,14 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
key = require('../../../data/auth.json'); key = require('../../../data/auth.json');
} }
} catch (e) { } catch (e) {
console.log("No 'auth.json' file (service account key) found in /api/data/ for PaLM models"); console.log('No \'auth.json\' file (service account key) found in /api/data/ for PaLM models');
} }
const clientOptions = { const clientOptions = {
// debug: true, // for testing // debug: true, // for testing
reverseProxyUrl: process.env.GOOGLE_REVERSE_PROXY || null, reverseProxyUrl: process.env.GOOGLE_REVERSE_PROXY || null,
proxy: process.env.PROXY || null, proxy: process.env.PROXY || null,
...endpointOption ...endpointOption,
}; };
const client = new GoogleClient(key, clientOptions); const client = new GoogleClient(key, clientOptions);
@ -125,9 +125,9 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
onProgress: progressCallback.call(null, { onProgress: progressCallback.call(null, {
res, res,
text, text,
parentMessageId: overrideParentMessageId || userMessageId parentMessageId: overrideParentMessageId || userMessageId,
}), }),
abortController abortController,
}); });
if (overrideParentMessageId) { if (overrideParentMessageId) {
@ -138,7 +138,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
...endpointOption, ...endpointOption,
...endpointOption.modelOptions, ...endpointOption.modelOptions,
conversationId, conversationId,
endpoint: 'google' endpoint: 'google',
}); });
await saveMessage(response); await saveMessage(response);
@ -147,7 +147,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
final: true, final: true,
conversation: await getConvo(req.user.id, conversationId), conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage, requestMessage: userMessage,
responseMessage: response responseMessage: response,
}); });
res.end(); res.end();
@ -155,7 +155,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
const title = await titleConvo({ text, response }); const title = await titleConvo({ text, response });
await saveConvo(req.user.id, { await saveConvo(req.user.id, {
conversationId, conversationId,
title title,
}); });
} }
} catch (error) { } catch (error) {
@ -168,7 +168,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
unfinished: false, unfinished: false,
cancelled: false, cancelled: false,
error: true, error: true,
text: error.message text: error.message,
}; };
await saveMessage(errorMessage); await saveMessage(errorMessage);
handleError(res, errorMessage); handleError(res, errorMessage);

View file

@ -8,7 +8,7 @@ const {
sendMessage, sendMessage,
createOnProgress, createOnProgress,
formatSteps, formatSteps,
formatAction formatAction,
} = require('./handlers'); } = require('./handlers');
const requireJwtAuth = require('../../../middleware/requireJwtAuth'); const requireJwtAuth = require('../../../middleware/requireJwtAuth');
@ -44,12 +44,12 @@ router.post('/', requireJwtAuth, async (req, res) => {
temperature: req.body?.temperature ?? 0, temperature: req.body?.temperature ?? 0,
top_p: req.body?.top_p ?? 1, top_p: req.body?.top_p ?? 1,
presence_penalty: req.body?.presence_penalty ?? 0, presence_penalty: req.body?.presence_penalty ?? 0,
frequency_penalty: req.body?.frequency_penalty ?? 0 frequency_penalty: req.body?.frequency_penalty ?? 0,
}, },
agentOptions: { agentOptions: {
...agentOptions, ...agentOptions,
// agent: 'functions' // agent: 'functions'
} },
}; };
console.log('ask log'); console.log('ask log');
@ -63,7 +63,7 @@ router.post('/', requireJwtAuth, async (req, res) => {
conversationId, conversationId,
parentMessageId, parentMessageId,
req, req,
res res,
}); });
}); });
@ -73,7 +73,7 @@ const ask = async ({ text, endpoint, endpointOption, parentMessageId = null, con
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache, no-transform', 'Cache-Control': 'no-cache, no-transform',
'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Origin': '*',
'X-Accel-Buffering': 'no' 'X-Accel-Buffering': 'no',
}); });
let userMessage; let userMessage;
let userMessageId; let userMessageId;
@ -87,7 +87,7 @@ const ask = async ({ text, endpoint, endpointOption, parentMessageId = null, con
loading: true, loading: true,
inputs: [], inputs: [],
latest: null, latest: null,
outputs: null outputs: null,
}; };
try { try {
@ -119,10 +119,10 @@ const ask = async ({ text, endpoint, endpointOption, parentMessageId = null, con
model: endpointOption.modelOptions.model, model: endpointOption.modelOptions.model,
unfinished: true, unfinished: true,
cancelled: false, cancelled: false,
error: false error: false,
}); });
} }
} },
}); });
const abortController = new AbortController(); const abortController = new AbortController();
@ -149,7 +149,7 @@ const ask = async ({ text, endpoint, endpointOption, parentMessageId = null, con
final: true, final: true,
conversation: await getConvo(req.user.id, conversationId), conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage, requestMessage: userMessage,
responseMessage: responseMessage responseMessage: responseMessage,
}; };
}; };
@ -164,7 +164,7 @@ const ask = async ({ text, endpoint, endpointOption, parentMessageId = null, con
endpoint, endpoint,
reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null, reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null,
proxy: process.env.PROXY || null, proxy: process.env.PROXY || null,
...endpointOption ...endpointOption,
}; };
let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY; let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY;
@ -211,9 +211,9 @@ const ask = async ({ text, endpoint, endpointOption, parentMessageId = null, con
res, res,
text, text,
plugin, plugin,
parentMessageId: overrideParentMessageId || userMessageId parentMessageId: overrideParentMessageId || userMessageId,
}), }),
abortController abortController,
}); });
if (overrideParentMessageId) { if (overrideParentMessageId) {
@ -230,7 +230,7 @@ const ask = async ({ text, endpoint, endpointOption, parentMessageId = null, con
final: true, final: true,
conversation: await getConvo(req.user.id, conversationId), conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage, requestMessage: userMessage,
responseMessage: response responseMessage: response,
}); });
res.end(); res.end();
@ -243,7 +243,7 @@ const ask = async ({ text, endpoint, endpointOption, parentMessageId = null, con
}); });
await saveConvo(req.user.id, { await saveConvo(req.user.id, {
conversationId: conversationId, conversationId: conversationId,
title title,
}); });
} }
} catch (error) { } catch (error) {
@ -256,7 +256,7 @@ const ask = async ({ text, endpoint, endpointOption, parentMessageId = null, con
unfinished: false, unfinished: false,
cancelled: false, cancelled: false,
error: true, error: true,
text: error.message text: error.message,
}; };
await saveMessage(errorMessage); await saveMessage(errorMessage);
handleError(res, errorMessage); handleError(res, errorMessage);

View file

@ -134,7 +134,7 @@ function formatAction(action) {
input: getString(action.toolInput), input: getString(action.toolInput),
thought: action.log.includes('Thought: ') thought: action.log.includes('Thought: ')
? action.log.split('\n')[0].replace('Thought: ', '') ? action.log.split('\n')[0].replace('Thought: ', '')
: action.log.split('\n')[0] : action.log.split('\n')[0],
}; };
formattedAction.thought = getString(formattedAction.thought); formattedAction.thought = getString(formattedAction.thought);
@ -161,5 +161,5 @@ module.exports = {
createOnProgress, createOnProgress,
handleText, handleText,
formatSteps, formatSteps,
formatAction formatAction,
}; };

View file

@ -31,8 +31,8 @@ router.post('/', requireJwtAuth, async (req, res) => {
temperature: req.body?.temperature ?? 1, temperature: req.body?.temperature ?? 1,
top_p: req.body?.top_p ?? 1, top_p: req.body?.top_p ?? 1,
presence_penalty: req.body?.presence_penalty ?? 0, presence_penalty: req.body?.presence_penalty ?? 0,
frequency_penalty: req.body?.frequency_penalty ?? 0 frequency_penalty: req.body?.frequency_penalty ?? 0,
} },
}; };
console.log('ask log'); console.log('ask log');
@ -46,7 +46,7 @@ router.post('/', requireJwtAuth, async (req, res) => {
parentMessageId, parentMessageId,
endpoint, endpoint,
req, req,
res res,
}); });
}); });
@ -56,7 +56,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, endpoint, con
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache, no-transform', 'Cache-Control': 'no-cache, no-transform',
'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Origin': '*',
'X-Accel-Buffering': 'no' 'X-Accel-Buffering': 'no',
}); });
let userMessage; let userMessage;
let userMessageId; let userMessageId;
@ -90,10 +90,10 @@ const ask = async ({ text, endpointOption, parentMessageId = null, endpoint, con
model: endpointOption.modelOptions.model, model: endpointOption.modelOptions.model,
unfinished: true, unfinished: true,
cancelled: false, cancelled: false,
error: false error: false,
}); });
} }
} },
}); });
const abortController = new AbortController(); const abortController = new AbortController();
@ -119,7 +119,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, endpoint, con
final: true, final: true,
conversation: await getConvo(req.user.id, conversationId), conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage, requestMessage: userMessage,
responseMessage: responseMessage responseMessage: responseMessage,
}; };
}; };
@ -135,7 +135,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, endpoint, con
reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null, reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null,
proxy: process.env.PROXY || null, proxy: process.env.PROXY || null,
endpoint, endpoint,
...endpointOption ...endpointOption,
}; };
let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY; let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY;
@ -157,9 +157,9 @@ const ask = async ({ text, endpointOption, parentMessageId = null, endpoint, con
onProgress: progressCallback.call(null, { onProgress: progressCallback.call(null, {
res, res,
text, text,
parentMessageId: overrideParentMessageId || userMessageId parentMessageId: overrideParentMessageId || userMessageId,
}), }),
abortController abortController,
}); });
if (overrideParentMessageId) { if (overrideParentMessageId) {
@ -174,7 +174,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, endpoint, con
final: true, final: true,
conversation: await getConvo(req.user.id, conversationId), conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage, requestMessage: userMessage,
responseMessage: response responseMessage: response,
}); });
res.end(); res.end();
@ -183,11 +183,11 @@ const ask = async ({ text, endpointOption, parentMessageId = null, endpoint, con
text, text,
response, response,
openAIApiKey, openAIApiKey,
azure: endpoint === 'azureOpenAI' azure: endpoint === 'azureOpenAI',
}); });
await saveConvo(req.user.id, { await saveConvo(req.user.id, {
conversationId, conversationId,
title title,
}); });
} }
} catch (error) { } catch (error) {
@ -204,7 +204,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, endpoint, con
unfinished: false, unfinished: false,
cancelled: false, cancelled: false,
error: true, error: true,
text: error.message text: error.message,
}; };
await saveMessage(errorMessage); await saveMessage(errorMessage);
handleError(res, errorMessage); handleError(res, errorMessage);

View file

@ -3,7 +3,7 @@ const {
resetPasswordRequestController, resetPasswordRequestController,
resetPasswordController, resetPasswordController,
// refreshController, // refreshController,
registrationController registrationController,
} = require('../controllers/AuthController'); } = require('../controllers/AuthController');
const { loginController } = require('../controllers/auth/LoginController'); const { loginController } = require('../controllers/auth/LoginController');
const { logoutController } = require('../controllers/auth/LogoutController'); const { logoutController } = require('../controllers/auth/LogoutController');

View file

@ -5,9 +5,9 @@ router.get('/', async function (req, res) {
try { try {
const appTitle = process.env.APP_TITLE || 'LibreChat'; const appTitle = process.env.APP_TITLE || 'LibreChat';
const googleLoginEnabled = !!process.env.GOOGLE_CLIENT_ID && !!process.env.GOOGLE_CLIENT_SECRET; const googleLoginEnabled = !!process.env.GOOGLE_CLIENT_ID && !!process.env.GOOGLE_CLIENT_SECRET;
const openidLoginEnabled = !!process.env.OPENID_CLIENT_ID const openidLoginEnabled = !!process.env.OPENID_CLIENT_ID
&& !!process.env.OPENID_CLIENT_SECRET && !!process.env.OPENID_CLIENT_SECRET
&& !!process.env.OPENID_ISSUER && !!process.env.OPENID_ISSUER
&& !!process.env.OPENID_SESSION_SECRET; && !!process.env.OPENID_SESSION_SECRET;
const openidLabel = process.env.OPENID_BUTTON_LABEL || 'Login with OpenID'; const openidLabel = process.env.OPENID_BUTTON_LABEL || 'Login with OpenID';
const openidImageUrl = process.env.OPENID_IMAGE_URL; const openidImageUrl = process.env.OPENID_IMAGE_URL;
@ -16,7 +16,7 @@ router.get('/', async function (req, res) {
const serverDomain = process.env.DOMAIN_SERVER || 'http://localhost:3080'; const serverDomain = process.env.DOMAIN_SERVER || 'http://localhost:3080';
const registrationEnabled = process.env.ALLOW_REGISTRATION === 'true'; const registrationEnabled = process.env.ALLOW_REGISTRATION === 'true';
const socialLoginEnabled = process.env.ALLOW_SOCIAL_LOGIN === 'true'; const socialLoginEnabled = process.env.ALLOW_SOCIAL_LOGIN === 'true';
return res.status(200).send({ return res.status(200).send({
appTitle, appTitle,
googleLoginEnabled, googleLoginEnabled,
@ -27,12 +27,12 @@ router.get('/', async function (req, res) {
discordLoginEnabled, discordLoginEnabled,
serverDomain, serverDomain,
registrationEnabled, registrationEnabled,
socialLoginEnabled socialLoginEnabled,
}); });
} catch (err) { } catch (err) {
console.error(err); console.error(err);
return res.status(500).send({error: err.message}); return res.status(500).send({ error: err.message });
} }
}); });

View file

@ -72,13 +72,13 @@ router.get('/', async function (req, res) {
const chatGPTBrowser = process.env.CHATGPT_TOKEN const chatGPTBrowser = process.env.CHATGPT_TOKEN
? { ? {
userProvide: process.env.CHATGPT_TOKEN == 'user_provided', userProvide: process.env.CHATGPT_TOKEN == 'user_provided',
availableModels: getChatGPTBrowserModels() availableModels: getChatGPTBrowserModels(),
} }
: false; : false;
const anthropic = process.env.ANTHROPIC_API_KEY const anthropic = process.env.ANTHROPIC_API_KEY
? { ? {
userProvide: process.env.ANTHROPIC_API_KEY == 'user_provided', userProvide: process.env.ANTHROPIC_API_KEY == 'user_provided',
availableModels: getAnthropicModels() availableModels: getAnthropicModels(),
} }
: false; : false;

View file

@ -25,5 +25,5 @@ module.exports = {
tokenizer, tokenizer,
endpoints, endpoints,
plugins, plugins,
config config,
}; };

View file

@ -12,8 +12,8 @@ router.get(
'/google', '/google',
passport.authenticate('google', { passport.authenticate('google', {
scope: ['openid', 'profile', 'email'], scope: ['openid', 'profile', 'email'],
session: false session: false,
}) }),
); );
router.get( router.get(
@ -22,25 +22,25 @@ router.get(
failureRedirect: `${domains.client}/login`, failureRedirect: `${domains.client}/login`,
failureMessage: true, failureMessage: true,
session: false, session: false,
scope: ['openid', 'profile', 'email'] scope: ['openid', 'profile', 'email'],
}), }),
(req, res) => { (req, res) => {
const token = req.user.generateToken(); const token = req.user.generateToken();
res.cookie('token', token, { res.cookie('token', token, {
expires: new Date(Date.now() + eval(process.env.SESSION_EXPIRY)), expires: new Date(Date.now() + eval(process.env.SESSION_EXPIRY)),
httpOnly: false, httpOnly: false,
secure: isProduction secure: isProduction,
}); });
res.redirect(domains.client); res.redirect(domains.client);
} },
); );
router.get( router.get(
'/facebook', '/facebook',
passport.authenticate('facebook', { passport.authenticate('facebook', {
scope: ['public_profile', 'email'], scope: ['public_profile', 'email'],
session: false session: false,
}) }),
); );
router.get( router.get(
@ -49,24 +49,24 @@ router.get(
failureRedirect: `${domains.client}/login`, failureRedirect: `${domains.client}/login`,
failureMessage: true, failureMessage: true,
session: false, session: false,
scope: ['public_profile', 'email'] scope: ['public_profile', 'email'],
}), }),
(req, res) => { (req, res) => {
const token = req.user.generateToken(); const token = req.user.generateToken();
res.cookie('token', token, { res.cookie('token', token, {
expires: new Date(Date.now() + eval(process.env.SESSION_EXPIRY)), expires: new Date(Date.now() + eval(process.env.SESSION_EXPIRY)),
httpOnly: false, httpOnly: false,
secure: isProduction secure: isProduction,
}); });
res.redirect(domains.client); res.redirect(domains.client);
} },
); );
router.get( router.get(
'/openid', '/openid',
passport.authenticate('openid', { passport.authenticate('openid', {
session: false session: false,
}) }),
); );
router.get( router.get(
@ -74,26 +74,25 @@ router.get(
passport.authenticate('openid', { passport.authenticate('openid', {
failureRedirect: `${domains.client}/login`, failureRedirect: `${domains.client}/login`,
failureMessage: true, failureMessage: true,
session: false session: false,
}), }),
(req, res) => { (req, res) => {
const token = req.user.generateToken(); const token = req.user.generateToken();
res.cookie('token', token, { res.cookie('token', token, {
expires: new Date(Date.now() + eval(process.env.SESSION_EXPIRY)), expires: new Date(Date.now() + eval(process.env.SESSION_EXPIRY)),
httpOnly: false, httpOnly: false,
secure: isProduction secure: isProduction,
}); });
res.redirect(domains.client); res.redirect(domains.client);
} },
); );
router.get( router.get(
'/github', '/github',
passport.authenticate('github', { passport.authenticate('github', {
scope: ['user:email', 'read:user'], scope: ['user:email', 'read:user'],
session: false session: false,
}) }),
); );
router.get( router.get(
@ -102,26 +101,25 @@ router.get(
failureRedirect: `${domains.client}/login`, failureRedirect: `${domains.client}/login`,
failureMessage: true, failureMessage: true,
session: false, session: false,
scope: ['user:email', 'read:user'] scope: ['user:email', 'read:user'],
}), }),
(req, res) => { (req, res) => {
const token = req.user.generateToken(); const token = req.user.generateToken();
res.cookie('token', token, { res.cookie('token', token, {
expires: new Date(Date.now() + eval(process.env.SESSION_EXPIRY)), expires: new Date(Date.now() + eval(process.env.SESSION_EXPIRY)),
httpOnly: false, httpOnly: false,
secure: isProduction secure: isProduction,
}); });
res.redirect(domains.client); res.redirect(domains.client);
} },
); );
router.get( router.get(
'/discord', '/discord',
passport.authenticate('discord', { passport.authenticate('discord', {
scope: ['identify', 'email'], scope: ['identify', 'email'],
session: false session: false,
}) }),
); );
router.get( router.get(
@ -130,17 +128,17 @@ router.get(
failureRedirect: `${domains.client}/login`, failureRedirect: `${domains.client}/login`,
failureMessage: true, failureMessage: true,
session: false, session: false,
scope: ['identify', 'email'] scope: ['identify', 'email'],
}), }),
(req, res) => { (req, res) => {
const token = req.user.generateToken(); const token = req.user.generateToken();
res.cookie('token', token, { res.cookie('token', token, {
expires: new Date(Date.now() + eval(process.env.SESSION_EXPIRY)), expires: new Date(Date.now() + eval(process.env.SESSION_EXPIRY)),
httpOnly: false, httpOnly: false,
secure: isProduction secure: isProduction,
}); });
res.redirect(domains.client); res.redirect(domains.client);
} },
); );
module.exports = router; module.exports = router;

View file

@ -1,6 +1,6 @@
const express = require('express'); const express = require('express');
const router = express.Router(); const router = express.Router();
const { savePrompt, getPrompts, deletePrompts } = require('../../models/Prompt'); const { getPrompts } = require('../../models/Prompt');
router.get('/', async (req, res) => { router.get('/', async (req, res) => {
let filter = {}; let filter = {};

View file

@ -42,16 +42,16 @@ router.get('/', requireJwtAuth, async function (req, res) {
{ {
attributesToHighlight: ['text'], attributesToHighlight: ['text'],
highlightPreTag: '**', highlightPreTag: '**',
highlightPostTag: '**' highlightPostTag: '**',
}, },
true true,
) )
).hits.map((message) => { ).hits.map((message) => {
const { _formatted, ...rest } = message; const { _formatted, ...rest } = message;
return { return {
...rest, ...rest,
searchResult: true, searchResult: true,
text: _formatted.text text: _formatted.text,
}; };
}); });
const titles = (await Conversation.meiliSearch(q)).hits; const titles = (await Conversation.meiliSearch(q)).hits;
@ -111,7 +111,7 @@ router.get('/enable', async function (req, res) {
try { try {
const client = new MeiliSearch({ const client = new MeiliSearch({
host: process.env.MEILI_HOST, host: process.env.MEILI_HOST,
apiKey: process.env.MEILI_MASTER_KEY apiKey: process.env.MEILI_MASTER_KEY,
}); });
const { status } = await client.health(); const { status } = await client.health();

View file

@ -47,7 +47,7 @@ const updateUserPluginAuth = async (userId, authField, pluginKey, value) => {
if (pluginAuth) { if (pluginAuth) {
const pluginAuth = await PluginAuth.updateOne( const pluginAuth = await PluginAuth.updateOne(
{ userId, authField }, { userId, authField },
{ $set: { value: encryptedValue } } { $set: { value: encryptedValue } },
); );
return pluginAuth; return pluginAuth;
} else { } else {
@ -55,7 +55,7 @@ const updateUserPluginAuth = async (userId, authField, pluginKey, value) => {
userId, userId,
authField, authField,
value: encryptedValue, value: encryptedValue,
pluginKey pluginKey,
}); });
newPluginAuth.save(); newPluginAuth.save();
return newPluginAuth; return newPluginAuth;
@ -79,5 +79,5 @@ const deleteUserPluginAuth = async (userId, authField) => {
module.exports = { module.exports = {
getUserPluginAuthValue, getUserPluginAuthValue,
updateUserPluginAuth, updateUserPluginAuth,
deleteUserPluginAuth deleteUserPluginAuth,
}; };

View file

@ -5,13 +5,13 @@ const updateUserPluginsService = async (user, pluginKey, action) => {
if (action === 'install') { if (action === 'install') {
const response = await User.updateOne( const response = await User.updateOne(
{ _id: user._id }, { _id: user._id },
{ $set: { plugins: [...user.plugins, pluginKey] } } { $set: { plugins: [...user.plugins, pluginKey] } },
); );
return response; return response;
} else if (action === 'uninstall') { } else if (action === 'uninstall') {
const response = await User.updateOne( const response = await User.updateOne(
{ _id: user._id }, { _id: user._id },
{ $set: { plugins: user.plugins.filter((plugin) => plugin !== pluginKey) } } { $set: { plugins: user.plugins.filter((plugin) => plugin !== pluginKey) } },
); );
return response; return response;
} }

View file

@ -18,7 +18,7 @@ const logoutUser = async (user, refreshToken) => {
try { try {
const userFound = await User.findById(user._id); const userFound = await User.findById(user._id);
const tokenIndex = userFound.refreshToken.findIndex( const tokenIndex = userFound.refreshToken.findIndex(
(item) => item.refreshToken === refreshToken (item) => item.refreshToken === refreshToken,
); );
if (tokenIndex !== -1) { if (tokenIndex !== -1) {
@ -45,7 +45,7 @@ const registerUser = async (user) => {
console.info( console.info(
'Route: register - Joi Validation Error', 'Route: register - Joi Validation Error',
{ name: 'Request params:', value: user }, { name: 'Request params:', value: user },
{ name: 'Validation error:', value: error.details } { name: 'Validation error:', value: error.details },
); );
return { status: 422, message: error.details[0].message }; return { status: 422, message: error.details[0].message };
@ -60,7 +60,7 @@ const registerUser = async (user) => {
console.info( console.info(
'Register User - Email in use', 'Register User - Email in use',
{ name: 'Request params:', value: user }, { name: 'Request params:', value: user },
{ name: 'Existing user:', value: existingUser } { name: 'Existing user:', value: existingUser },
); );
// Sleep for 1 second // Sleep for 1 second
@ -80,7 +80,7 @@ const registerUser = async (user) => {
username, username,
name, name,
avatar: null, avatar: null,
role: isFirstRegisteredUser ? 'ADMIN' : 'USER' role: isFirstRegisteredUser ? 'ADMIN' : 'USER',
}); });
// todo: implement refresh token // todo: implement refresh token
@ -118,7 +118,7 @@ const requestPasswordReset = async (email) => {
await new Token({ await new Token({
userId: user._id, userId: user._id,
token: hash, token: hash,
createdAt: Date.now() createdAt: Date.now(),
}).save(); }).save();
const link = `${domains.client}/reset-password?token=${resetToken}&userId=${user._id}`; const link = `${domains.client}/reset-password?token=${resetToken}&userId=${user._id}`;
@ -128,9 +128,9 @@ const requestPasswordReset = async (email) => {
'Password Reset Request', 'Password Reset Request',
{ {
name: user.name, name: user.name,
link: link link: link,
}, },
'./template/requestResetPassword.handlebars' './template/requestResetPassword.handlebars',
); );
return { link }; return { link };
}; };
@ -166,9 +166,9 @@ const resetPassword = async (userId, token, password) => {
user.email, user.email,
'Password Reset Successfully', 'Password Reset Successfully',
{ {
name: user.name name: user.name,
}, },
'./template/resetPassword.handlebars' './template/resetPassword.handlebars',
); );
await passwordResetToken.deleteOne(); await passwordResetToken.deleteOne();
@ -180,5 +180,5 @@ module.exports = {
registerUser, registerUser,
logoutUser, logoutUser,
requestPasswordReset, requestPasswordReset,
resetPassword resetPassword,
}; };

View file

@ -10,7 +10,7 @@ const discordLogin = new DiscordStrategy(
clientSecret: process.env.DISCORD_CLIENT_SECRET, clientSecret: process.env.DISCORD_CLIENT_SECRET,
callbackURL: `${domains.server}${process.env.DISCORD_CALLBACK_URL}`, callbackURL: `${domains.server}${process.env.DISCORD_CALLBACK_URL}`,
scope: ['identify', 'email'], // Request scopes scope: ['identify', 'email'], // Request scopes
authorizationURL: 'https://discord.com/api/oauth2/authorize?prompt=none' // Add the prompt query parameter authorizationURL: 'https://discord.com/api/oauth2/authorize?prompt=none', // Add the prompt query parameter
}, },
async (accessToken, refreshToken, profile, cb) => { async (accessToken, refreshToken, profile, cb) => {
try { try {
@ -37,7 +37,7 @@ const discordLogin = new DiscordStrategy(
username: profile.username, username: profile.username,
email, email,
name: profile.global_name, name: profile.global_name,
avatar: avatarURL avatar: avatarURL,
}); });
cb(null, newUser); cb(null, newUser);
@ -45,7 +45,7 @@ const discordLogin = new DiscordStrategy(
console.error(err); console.error(err);
cb(err); cb(err);
} }
} },
); );
passport.use(discordLogin); passport.use(discordLogin);

View file

@ -10,7 +10,7 @@ const facebookLogin = new FacebookStrategy(
clientID: process.env.FACEBOOK_APP_ID, clientID: process.env.FACEBOOK_APP_ID,
clientSecret: process.env.FACEBOOK_SECRET, clientSecret: process.env.FACEBOOK_SECRET,
callbackURL: `${domains.server}${process.env.FACEBOOK_CALLBACK_URL}`, callbackURL: `${domains.server}${process.env.FACEBOOK_CALLBACK_URL}`,
proxy: true proxy: true,
// profileFields: [ // profileFields: [
// 'id', // 'id',
// 'email', // 'email',
@ -46,14 +46,14 @@ const facebookLogin = new FacebookStrategy(
username: profile.name.givenName + profile.name.familyName, username: profile.name.givenName + profile.name.familyName,
email: profile.emails[0].value, email: profile.emails[0].value,
name: profile.displayName, name: profile.displayName,
avatar: profile.photos[0].value avatar: profile.photos[0].value,
}).save(); }).save();
done(null, newUser); done(null, newUser);
} catch (err) { } catch (err) {
console.log(err); console.log(err);
} }
} },
); );
passport.use(facebookLogin); passport.use(facebookLogin);

View file

@ -12,7 +12,7 @@ const githubLogin = new GitHubStrategy(
clientSecret: process.env.GITHUB_CLIENT_SECRET, clientSecret: process.env.GITHUB_CLIENT_SECRET,
callbackURL: `${domains.server}${process.env.GITHUB_CALLBACK_URL}`, callbackURL: `${domains.server}${process.env.GITHUB_CALLBACK_URL}`,
proxy: false, proxy: false,
scope: ['user:email'] // Request email scope scope: ['user:email'], // Request email scope
}, },
async (accessToken, refreshToken, profile, cb) => { async (accessToken, refreshToken, profile, cb) => {
try { try {
@ -33,7 +33,7 @@ const githubLogin = new GitHubStrategy(
email, email,
emailVerified: profile.emails[0].verified, emailVerified: profile.emails[0].verified,
name: profile.displayName, name: profile.displayName,
avatar: profile.photos[0].value avatar: profile.photos[0].value,
}).save(); }).save();
cb(null, newUser); cb(null, newUser);
@ -41,7 +41,7 @@ const githubLogin = new GitHubStrategy(
console.error(err); console.error(err);
cb(err); cb(err);
} }
} },
); );
passport.use(githubLogin); passport.use(githubLogin);

View file

@ -11,7 +11,7 @@ const googleLogin = new GoogleStrategy(
clientID: process.env.GOOGLE_CLIENT_ID, clientID: process.env.GOOGLE_CLIENT_ID,
clientSecret: process.env.GOOGLE_CLIENT_SECRET, clientSecret: process.env.GOOGLE_CLIENT_SECRET,
callbackURL: `${domains.server}${process.env.GOOGLE_CALLBACK_URL}`, callbackURL: `${domains.server}${process.env.GOOGLE_CALLBACK_URL}`,
proxy: true proxy: true,
}, },
async (accessToken, refreshToken, profile, cb) => { async (accessToken, refreshToken, profile, cb) => {
try { try {
@ -31,13 +31,13 @@ const googleLogin = new GoogleStrategy(
email: profile.emails[0].value, email: profile.emails[0].value,
emailVerified: profile.emails[0].verified, emailVerified: profile.emails[0].verified,
name: `${profile.name.givenName} ${profile.name.familyName}`, name: `${profile.name.givenName} ${profile.name.familyName}`,
avatar: profile.photos[0].value avatar: profile.photos[0].value,
}).save(); }).save();
cb(null, newUser); cb(null, newUser);
} catch (err) { } catch (err) {
console.log(err); console.log(err);
} }
} },
); );
passport.use(googleLogin); passport.use(googleLogin);

View file

@ -6,7 +6,7 @@ const User = require('../models/User');
const jwtLogin = new JwtStrategy( const jwtLogin = new JwtStrategy(
{ {
jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(), jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(),
secretOrKey: process.env.JWT_SECRET secretOrKey: process.env.JWT_SECRET,
}, },
async (payload, done) => { async (payload, done) => {
try { try {
@ -20,7 +20,7 @@ const jwtLogin = new JwtStrategy(
} catch (err) { } catch (err) {
done(err, false); done(err, false);
} }
} },
); );
passport.use(jwtLogin); passport.use(jwtLogin);

View file

@ -10,14 +10,14 @@ const passportLogin = new PassportLocalStrategy(
usernameField: 'email', usernameField: 'email',
passwordField: 'password', passwordField: 'password',
session: false, session: false,
passReqToCallback: true passReqToCallback: true,
}, },
async (req, email, password, done) => { async (req, email, password, done) => {
const { error } = loginSchema.validate(req.body); const { error } = loginSchema.validate(req.body);
if (error) { if (error) {
log({ log({
title: 'Passport Local Strategy - Validation Error', title: 'Passport Local Strategy - Validation Error',
parameters: [{ name: 'req.body', value: req.body }] parameters: [{ name: 'req.body', value: req.body }],
}); });
return done(null, false, { message: error.details[0].message }); return done(null, false, { message: error.details[0].message });
} }
@ -27,7 +27,7 @@ const passportLogin = new PassportLocalStrategy(
if (!user) { if (!user) {
log({ log({
title: 'Passport Local Strategy - User Not Found', title: 'Passport Local Strategy - User Not Found',
parameters: [{ name: 'email', value: email }] parameters: [{ name: 'email', value: email }],
}); });
return done(null, false, { message: 'Email does not exists.' }); return done(null, false, { message: 'Email does not exists.' });
} }
@ -36,14 +36,14 @@ const passportLogin = new PassportLocalStrategy(
if (err) { if (err) {
log({ log({
title: 'Passport Local Strategy - Compare password error', title: 'Passport Local Strategy - Compare password error',
parameters: [{ name: 'error', value: err }] parameters: [{ name: 'error', value: err }],
}); });
return done(err); return done(err);
} }
if (!isMatch) { if (!isMatch) {
log({ log({
title: 'Passport Local Strategy - Password does not match', title: 'Passport Local Strategy - Password does not match',
parameters: [{ name: 'isMatch', value: isMatch }] parameters: [{ name: 'isMatch', value: isMatch }],
}); });
return done(null, false, { message: 'Incorrect password.' }); return done(null, false, { message: 'Incorrect password.' });
} }
@ -53,7 +53,7 @@ const passportLogin = new PassportLocalStrategy(
} catch (err) { } catch (err) {
return done(err); return done(err);
} }
} },
); );
passport.use(passportLogin); passport.use(passportLogin);

View file

@ -1,5 +1,4 @@
const passport = require('passport'); const passport = require('passport');
const jwt = require('jsonwebtoken');
const { Issuer, Strategy: OpenIDStrategy } = require('openid-client'); const { Issuer, Strategy: OpenIDStrategy } = require('openid-client');
const axios = require('axios'); const axios = require('axios');
const fs = require('fs'); const fs = require('fs');
@ -20,11 +19,11 @@ const downloadImage = async (url, imagePath, accessToken) => {
try { try {
const response = await axios.get(url, { const response = await axios.get(url, {
headers: { headers: {
'Authorization': `Bearer ${accessToken}` 'Authorization': `Bearer ${accessToken}`,
}, },
responseType: 'arraybuffer' responseType: 'arraybuffer',
}); });
fs.mkdirSync(path.dirname(imagePath), { recursive: true }); fs.mkdirSync(path.dirname(imagePath), { recursive: true });
fs.writeFileSync(imagePath, response.data); fs.writeFileSync(imagePath, response.data);
@ -42,15 +41,15 @@ Issuer.discover(process.env.OPENID_ISSUER)
const client = new issuer.Client({ const client = new issuer.Client({
client_id: process.env.OPENID_CLIENT_ID, client_id: process.env.OPENID_CLIENT_ID,
client_secret: process.env.OPENID_CLIENT_SECRET, client_secret: process.env.OPENID_CLIENT_SECRET,
redirect_uris: [domains.server + process.env.OPENID_CALLBACK_URL] redirect_uris: [domains.server + process.env.OPENID_CALLBACK_URL],
}); });
const openidLogin = new OpenIDStrategy( const openidLogin = new OpenIDStrategy(
{ {
client, client,
params: { params: {
scope: process.env.OPENID_SCOPE scope: process.env.OPENID_SCOPE,
} },
}, },
async (tokenset, userinfo, done) => { async (tokenset, userinfo, done) => {
try { try {
@ -68,7 +67,7 @@ Issuer.discover(process.env.OPENID_ISSUER)
} else if (userinfo.family_name) { } else if (userinfo.family_name) {
fullName = userinfo.family_name; fullName = userinfo.family_name;
} }
if (!user) { if (!user) {
user = new User({ user = new User({
provider: 'openid', provider: 'openid',
@ -76,7 +75,7 @@ Issuer.discover(process.env.OPENID_ISSUER)
username: userinfo.given_name || '', username: userinfo.given_name || '',
email: userinfo.email || '', email: userinfo.email || '',
emailVerified: userinfo.email_verified || false, emailVerified: userinfo.email_verified || false,
name: fullName name: fullName,
}); });
} else { } else {
user.provider = 'openid'; user.provider = 'openid';
@ -105,14 +104,14 @@ Issuer.discover(process.env.OPENID_ISSUER)
} else { } else {
user.avatar = ''; user.avatar = '';
} }
await user.save(); await user.save();
done(null, user); done(null, user);
} catch (err) { } catch (err) {
done(err); done(err);
} }
} },
); );
passport.use('openid', openidLogin); passport.use('openid', openidLogin);

View file

@ -2,7 +2,7 @@ const Joi = require('joi');
const loginSchema = Joi.object().keys({ const loginSchema = Joi.object().keys({
email: Joi.string().trim().email().required(), email: Joi.string().trim().email().required(),
password: Joi.string().trim().min(8).max(128).required() password: Joi.string().trim().min(8).max(128).required(),
}); });
const registerSchema = Joi.object().keys({ const registerSchema = Joi.object().keys({
@ -15,10 +15,10 @@ const registerSchema = Joi.object().keys({
.required(), .required(),
email: Joi.string().trim().email().required(), email: Joi.string().trim().email().required(),
password: Joi.string().trim().min(8).max(128).required(), password: Joi.string().trim().min(8).max(128).required(),
confirm_password: Joi.string().trim().min(8).max(128).required() confirm_password: Joi.string().trim().min(8).max(128).required(),
}); });
module.exports = { module.exports = {
loginSchema, loginSchema,
registerSchema registerSchema,
}; };

View file

@ -13,10 +13,10 @@ const logger = pino({
'env.JWT_SECRET', 'env.JWT_SECRET',
'env.JWT_SECRET_DEV', 'env.JWT_SECRET_DEV',
'env.JWT_SECRET_PROD', 'env.JWT_SECRET_PROD',
'newUser.password' 'newUser.password',
], // See example to filter object class instances ], // See example to filter object class instances
censor: '***' // Redaction character censor: '***', // Redaction character
} },
}); });
// Sanitize outside the logger paths. This is useful for sanitizing variables directly with Regex and patterns. // Sanitize outside the logger paths. This is useful for sanitizing variables directly with Regex and patterns.
@ -33,7 +33,7 @@ const redactPatterns = [
/authorization[-_]?login[-_]?hint/i, /authorization[-_]?login[-_]?hint/i,
/authorization[-_]?acr[-_]?values/i, /authorization[-_]?acr[-_]?values/i,
/authorization[-_]?response[-_]?mode/i, /authorization[-_]?response[-_]?mode/i,
/authorization[-_]?nonce/i /authorization[-_]?nonce/i,
]; ];
/* /*
@ -58,7 +58,7 @@ const levels = {
INFO: 30, INFO: 30,
WARN: 40, WARN: 40,
ERROR: 50, ERROR: 50,
FATAL: 60 FATAL: 60,
}; };
let level = levels.INFO; let level = levels.INFO;
@ -121,6 +121,6 @@ module.exports = {
if (level < levels.DEBUG) return next(); if (level < levels.DEBUG) return next();
logger.debug({ query: req.query, body: req.body }, `Hit URL ${req.url} with following`); logger.debug({ query: req.query, body: req.body }, `Hit URL ${req.url} with following`);
return next(); return next();
} },
} },
}; };

View file

@ -1,6 +1,6 @@
async function abortMessage(req, res, abortControllers) { async function abortMessage(req, res, abortControllers) {
const { abortKey } = req.body; const { abortKey } = req.body;
console.log(`req.body`, req.body); console.log('req.body', req.body);
if (!abortControllers.has(abortKey)) { if (!abortControllers.has(abortKey)) {
return res.status(404).send('Request not found'); return res.status(404).send('Request not found');
} }

View file

@ -5,7 +5,7 @@ const genAzureEndpoint = ({ azureOpenAIApiInstanceName, azureOpenAIApiDeployment
const genAzureChatCompletion = ({ const genAzureChatCompletion = ({
azureOpenAIApiInstanceName, azureOpenAIApiInstanceName,
azureOpenAIApiDeploymentName, azureOpenAIApiDeploymentName,
azureOpenAIApiVersion azureOpenAIApiVersion,
}) => { }) => {
return `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${azureOpenAIApiDeploymentName}/chat/completions?api-version=${azureOpenAIApiVersion}`; return `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${azureOpenAIApiDeploymentName}/chat/completions?api-version=${azureOpenAIApiVersion}`;
} }
@ -15,7 +15,7 @@ const getAzureCredentials = () => {
azureOpenAIApiKey: process.env.AZURE_API_KEY ?? process.env.AZURE_OPENAI_API_KEY, azureOpenAIApiKey: process.env.AZURE_API_KEY ?? process.env.AZURE_OPENAI_API_KEY,
azureOpenAIApiInstanceName: process.env.AZURE_OPENAI_API_INSTANCE_NAME, azureOpenAIApiInstanceName: process.env.AZURE_OPENAI_API_INSTANCE_NAME,
azureOpenAIApiDeploymentName: process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME, azureOpenAIApiDeploymentName: process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME,
azureOpenAIApiVersion: process.env.AZURE_OPENAI_API_VERSION azureOpenAIApiVersion: process.env.AZURE_OPENAI_API_VERSION,
} }
} }

View file

@ -2,7 +2,7 @@ const levels = {
NONE: 0, NONE: 0,
LOW: 1, LOW: 1,
MEDIUM: 2, MEDIUM: 2,
HIGH: 3 HIGH: 3,
}; };
let level = levels.HIGH; let level = levels.HIGH;
@ -41,6 +41,6 @@ module.exports = {
console.log('Body:', req.body); console.log('Body:', req.body);
console.groupEnd(); console.groupEnd();
return next(); return next();
} },
} },
}; };

View file

@ -10,5 +10,5 @@ module.exports = {
maxTokensMap, maxTokensMap,
tiktokenModels, tiktokenModels,
sendEmail, sendEmail,
abortMessage abortMessage,
} }

Some files were not shown because too many files have changed in this diff Show more