feat(GPT/Anthropic): Continue Regenerating & Generation Buttons (#808)

* feat(useMessageHandler.js/ts): Refactor and add features to handle user messages, support multiple endpoints/models, generate placeholder responses, regeneration, and stopGeneration function

fix(conversation.ts, buildTree.ts): Import TMessage type, handle null parentMessageId

feat(schemas.ts): Update and add schemas for various AI services, add default values, optional fields, and endpoint-to-schema mapping, create parseConvo function

chore(useMessageHandler.js, schemas.ts): Remove unused imports, variables, and chatGPT enum

* wip: add generation buttons

* refactor(cleanupPreset.ts): simplify cleanupPreset function
refactor(getDefaultConversation.js): remove unused code and simplify getDefaultConversation function

feat(utils): add getDefaultConversation function

This commit adds a new utility function called `getDefaultConversation` to the `client/src/utils/getDefaultConversation.ts` file. This function is responsible for generating a default conversation object based on the provided parameters.

The `getDefaultConversation` function takes in an object with the following properties:
- `conversation`: The conversation object to be used as a base.
- `endpointsConfig`: The configuration object containing information about the available endpoints.
- `preset`: An optional preset object that can be used to override the default behavior.

The function first tries to determine the target endpoint based on the preset object. If a valid endpoint is found, it is used as the target endpoint. If not, the function tries to retrieve the last conversation setup from the local storage and uses its endpoint if it is valid. If neither the preset nor the local storage contains a valid endpoint, the function falls back to a default endpoint.

Once the target endpoint is determined,

* fix(utils): remove console.error statement in buildDefaultConversation function
fix(schemas): add default values for catch blocks in openAISchema, googleSchema, bingAISchema, anthropicSchema, chatGPTBrowserSchema, and gptPluginsSchema

* fix: endpoint not changing on change of preset from other endpoint, wip: refactor

* refactor: preset items to TSX

* refactor: convert resetConvo to TS

* refactor(getDefaultConversation.ts): move defaultEndpoints array to the top of the file for better readability
refactor(getDefaultConversation.ts): extract getDefaultEndpoint function for better code organization and reusability

* feat(svg): add ContinueIcon component
feat(svg): add RegenerateIcon component
feat(svg): add ContinueIcon and RegenerateIcon components to index.ts

* feat(Button.tsx): add onClick and className props to Button component
feat(GenerationButtons.tsx): add logic to display Regenerate or StopGenerating button based on isSubmitting and messages
feat(Regenerate.tsx): create Regenerate component with RegenerateIcon and handleRegenerate function
feat(StopGenerating.tsx): create StopGenerating component with StopGeneratingIcon and handleStopGenerating function

* fix(TextChat.jsx): reorder imports and variables for better readability
fix(TextChat.jsx): fix typo in condition for isNotAppendable variable
fix(TextChat.jsx): remove unused handleStopGenerating function
fix(ContinueIcon.tsx): remove unnecessary closing tags for polygon elements
fix(useMessageHandler.ts): add missing type annotations for handleStopGenerating and handleRegenerate functions
fix(useMessageHandler.ts): remove unused variables in return statement

* fix(getDefaultConversation.ts): refactor code to use getLocalStorageItems function
feat(getLocalStorageItems.ts): add utility function to retrieve items from local storage

* fix(OpenAIClient.js): add support for streaming result in sendCompletion method
feat(OpenAIClient.js): add finish_reason metadata to opts in sendCompletion method
feat(Message.js): add finish_reason field to Message model
feat(messageSchema.js): add finish_reason field to messageSchema
feat(openAI.js): parse chatGptLabel and promptPrefix from req.body and pass rest of the modelOptions to endpointOption
feat(openAI.js): add addMetadata function to store metadata in ask function
feat(openAI.js): add metadata to response if available
feat(schemas.ts): add finish_reason field to tMessageSchema

* feat(types.ts): add TOnClick and TGenButtonProps types for button components
feat(Continue.tsx): create Continue component for generating button
feat(GenerationButtons.tsx): update GenerationButtons component to use Continue component
feat(Regenerate.tsx): create Regenerate component for regenerating button
feat(Stop.tsx): create Stop component for stop generating button

* feat(MessageHandler.jsx): add MessageHandler component to handle messages and conversations
fix(Root.jsx): fix import paths for Nav and MessageHandler components

* feat(useMessageHandler.ts): add support for generation parameter in ask function
feat(useMessageHandler.ts): add support for isEdited parameter in ask function
feat(useMessageHandler.ts): add support for continueGeneration function
fix(createPayload.ts): replace endpoint URL when isEdited parameter is true

* chore(client): set skipLibCheck to true in tsconfig.json

* fix(useMessageHandler.ts): remove unused clientId variable
fix(schemas.ts): make clientId field in tMessageSchema nullable and optional

* wip: edit route for continue generation

* refactor(api): move handlers to root of routes dir

* fix(useMessageHandler.ts): initialize currentMessages to an empty array if messages is null
fix(useMessageHandler.ts): update initialResponse text to use responseText variable
fix(useMessageHandler.ts): update setMessages logic for isRegenerate case
fix(MessageHandler.jsx): update setMessages logic for cancelHandler, createdHandler, and finalHandler

* fix(schemas.ts): make createdAt and updatedAt fields optional and set default values using new Date().toISOString()
fix(schemas.ts): change type annotation of TMessage from infer to input

* refactor(useMessageHandler.ts): rename AskProps type to TAskProps
refactor(useMessageHandler.ts): remove generation property from ask function arguments
refactor(useMessageHandler.ts): use nullish coalescing operator (??) instead of logical OR (||)
refactor(useMessageHandler.ts): pass the responseMessageId to message prop of submission

* fix(BaseClient.js): use nullish coalescing operator (??) instead of logical OR (||) for default values

* fix(BaseClient.js): fix responseMessageId assignment in handleStartMethods method
feat(BaseClient.js): add support for isEdited flag in sendMessage method
feat(BaseClient.js): add generation to responseMessage text in sendMessage method

* fix(openAI.js): remove unused imports and commented out code
feat(openAI.js): add support for generation parameter in request body
fix(openAI.js): remove console.log statement
fix(openAI.js): remove unused variables and parameters
fix(openAI.js): update response text in case of error
fix(openAI.js): handle error and abort message in case of error
fix(handlers.js): add generation parameter to createOnProgress function
fix(useMessageHandler.ts): update responseText variable to use generation parameter

* refactor(api/middleware): move inside server dir

* refactor: add endpoint specific, modular functions to build options and initialize clients, create server/utils, move middleware, separate utils into api general utils and server specific utils

* fix(abortMiddleware.js): import getConvo and getConvoTitle functions from models
feat(abortMiddleware.js): add abortAsk function to abortController to handle aborting of requests
fix(openAI.js): import buildOptions and initializeClient functions from endpoints/openAI
refactor(openAI.js): use getAbortData function to get data for abortAsk function

* refactor: move endpoint specific logic to an endpoints dir

* refactor(PluginService.js): fix import path for encrypt and decrypt functions in PluginService.js

* feat(openAI): add new endpoint for adding a title to a conversation

- Added a new file `addTitle.js` in the `api/server/routes/endpoints/openAI` directory.
- The `addTitle.js` file exports a function `addTitle` that takes in request parameters and performs the following actions:
  - If the `parentMessageId` is `'00000000-0000-0000-0000-000000000000'` and `newConvo` is true, it proceeds with the following steps:
    - Calls the `titleConvo` function from the `titleConvo` module, passing in the necessary parameters.
    - Calls the `saveConvo` function from the `saveConvo` module, passing in the user ID and conversation details.
- Updated the `index.js` file in the `api/server/routes/endpoints/openAI` directory to export the `addTitle` function.
- This change adds

* fix(abortMiddleware.js): remove console.log statement
refactor(gptPlugins.js): update imports and function parameters
feat(gptPlugins.js): add support for abortController and getAbortData
refactor(openAI.js): update imports and function parameters
feat(openAI.js): add support for abortController and getAbortData

fix(openAI.js): refactor code to use modularized functions and middleware
fix(buildOptions.js): refactor code to use destructuring and update variable names

* refactor(askChatGPTBrowser.js, bingAI.js, google.js): remove duplicate code for setting response headers
feat(askChatGPTBrowser.js, bingAI.js, google.js): add setHeaders middleware to set response headers

* feat(middleware): validateEndpoint, refactor buildOption to only be concerned of endpointOption

* fix(abortMiddleware.js): add 'finish_reason' property with value 'incomplete' to responseMessage object
fix(abortMessage.js): remove console.log statement for aborted message
fix(handlers.js): modify tokens assignment to handle empty generation string and trailing space

* fix(BaseClient.js): import addSpaceIfNeeded function from server/utils
fix(BaseClient.js): add space before generation in text property
fix(index.js): remove getCitations and citeText exports
feat(buildEndpointOption.js): add buildEndpointOption middleware
fix(index.js): import buildEndpointOption middleware
fix(anthropic.js): remove buildOptions function and use endpointOption from req.body
fix(gptPlugins.js): remove buildOptions function and use endpointOption from req.body
fix(openAI.js): remove buildOptions function and use endpointOption from req.body

feat(utils): add citations.js and handleText.js modules
fix(utils): fix import statements in index.js module

* refactor(gptPlugins.js): use getResponseSender function from librechat-data-provider

* feat(gptPlugins): complete 'continue generating'

* wip: anthropic continue regen

* feat(middleware): add validateRegistration middleware

A new middleware function called `validateRegistration` has been added to the list of exported middleware functions in `index.js`. This middleware is responsible for validating registration data before allowing the registration process to proceed.

* feat(Anthropic): complete continue regen

* chore: add librechat-data-provider to api/package.json

* fix(ci): backend-review will mock meilisearch, also installs data-provider as now needed

* chore(ci): remove unneeded SEARCH env var

* style(GenerationButtons): make text shorter for sake of space economy, even though this diverges from chat.openai.com

* style(GenerationButtons/ScrollToBottom): adjust visibility/position based on screen size

* chore(client): 'Editting' typo

* feat(GenerationButtons.tsx): add support for endpoint prop in GenerationButtons component
feat(OptionsBar.tsx): pass endpoint prop to GenerationButtons component
feat(useGenerations.ts): create useGenerations hook to handle generation logic
fix(schemas.ts): add searchResult field to tMessageSchema

* refactor(HoverButtons): convert to TSX and utilize new useGenerations hook

* fix(abortMiddleware): handle error with res headers set, or abortController not found, to ensure proper API error is sent to the client, chore(BaseClient): remove console log for onStart message meant for debugging

* refactor(api): remove librechat-data-provider dep for now as it complicates deployed docker build stage, re-use code in CJS, located in server/endpoints/schemas

* chore: remove console.logs from test files

* ci: add backend tests for AnthropicClient, focusing on new buildMessages logic

* refactor(FakeClient): use actual BaseClient sendMessage method for testing

* test(BaseClient.test.js): add test for loading chat history
test(BaseClient.test.js): add test for sendMessage logic with isEdited flag

* fix(buildEndpointOption.js): add support for azureOpenAI in buildFunction object
wip(endpoints.js): fetch Azure models from Azure OpenAI API if opts.azure is true

* fix(Button.tsx): add data-testid attribute to button component
fix(SelectDropDown.tsx): add data-testid attribute to Listbox.Button component
fix(messages.spec.ts): add waitForServerStream function to consolidate logic for awaiting the server response
feat(messages.spec.ts): add test for stopping and continuing message and improve browser/page context order and closing

* refactor(onProgress): speed up time to save initial message for editable routes

* chore: disable AI message editing (for now), was accidentally allowed

* refactor: ensure continue is only supported for latest message style: improve styling in dark mode and across all hover buttons/icons, including making edit icon for AI invisible (for now)

* fix: add test id to generation buttons so they never resolve to 2+ items

* chore(package.json): add 'packages/' to the list of ignored directories
chore(data-provider/package.json): bump version to 0.1.5
This commit is contained in:
Danny Avila 2023-08-17 12:50:05 -04:00 committed by GitHub
parent ae5b7d3d53
commit afd43afb60
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
113 changed files with 3023 additions and 1543 deletions

View file

@ -1,10 +1,5 @@
name: Backend Unit Tests
on:
# push:
# branches:
# - main
# - dev
# - release/*
pull_request:
branches:
- main
@ -23,6 +18,7 @@ jobs:
JWT_SECRET: ${{ secrets.JWT_SECRET }}
CREDS_KEY: ${{ secrets.CREDS_KEY }}
CREDS_IV: ${{ secrets.CREDS_IV }}
NODE_ENV: ci
steps:
- uses: actions/checkout@v2
- name: Use Node.js 20.x
@ -34,8 +30,8 @@ jobs:
- name: Install dependencies
run: npm ci
# - name: Install Linux X64 Sharp
# run: npm install --platform=linux --arch=x64 --verbose sharp
- name: Install Data Provider
run: npm run build:data-provider
- name: Run unit tests
run: cd api && npm run test:ci

View file

@ -1,4 +1,3 @@
const Keyv = require('keyv');
// const { Agent, ProxyAgent } = require('undici');
const BaseClient = require('./BaseClient');
const {
@ -15,8 +14,6 @@ const tokenizersCache = {};
class AnthropicClient extends BaseClient {
constructor(apiKey, options = {}, cacheOptions = {}) {
super(apiKey, options, cacheOptions);
cacheOptions.namespace = cacheOptions.namespace || 'anthropic';
this.conversationsCache = new Keyv(cacheOptions);
this.apiKey = apiKey || process.env.ANTHROPIC_API_KEY;
this.sender = 'Anthropic';
this.userLabel = HUMAN_PROMPT;
@ -107,6 +104,23 @@ class AnthropicClient extends BaseClient {
content: message?.content ?? message.text,
}));
let lastAuthor = '';
let groupedMessages = [];
for (let message of formattedMessages) {
// If last author is not same as current author, add to new group
if (lastAuthor !== message.author) {
groupedMessages.push({
author: message.author,
content: [message.content],
});
lastAuthor = message.author;
// If same author, append content to the last group
} else {
groupedMessages[groupedMessages.length - 1].content.push(message.content);
}
}
let identityPrefix = '';
if (this.options.userLabel) {
identityPrefix = `\nHuman's name: ${this.options.userLabel}`;
@ -129,8 +143,12 @@ class AnthropicClient extends BaseClient {
promptPrefix = `${identityPrefix}${promptPrefix}`;
}
const promptSuffix = `${promptPrefix}${this.assistantLabel}\n`; // Prompt AI to respond.
let currentTokenCount = this.getTokenCount(promptSuffix);
// Prompt AI to respond, empty if last message was from AI
let isEdited = lastAuthor === this.assistantLabel;
const promptSuffix = isEdited ? '' : `${promptPrefix}${this.assistantLabel}\n`;
let currentTokenCount = isEdited
? this.getTokenCount(promptPrefix)
: this.getTokenCount(promptSuffix);
let promptBody = '';
const maxTokenCount = this.maxPromptTokens;
@ -148,10 +166,13 @@ class AnthropicClient extends BaseClient {
};
const buildPromptBody = async () => {
if (currentTokenCount < maxTokenCount && formattedMessages.length > 0) {
const message = formattedMessages.pop();
if (currentTokenCount < maxTokenCount && groupedMessages.length > 0) {
const message = groupedMessages.pop();
const isCreatedByUser = message.author === this.userLabel;
const messageString = `${message.author}\n${message.content}${this.endToken}\n`;
// Use promptPrefix if message is edited assistant'
const messagePrefix =
isCreatedByUser || !isEdited ? message.author : `${promptPrefix}${message.author}`;
const messageString = `${messagePrefix}\n${message.content}${this.endToken}\n`;
let newPromptBody = `${messageString}${promptBody}`;
context.unshift(message);
@ -182,6 +203,12 @@ class AnthropicClient extends BaseClient {
}
promptBody = newPromptBody;
currentTokenCount = newTokenCount;
// Switch off isEdited after using it for the first time
if (isEdited) {
isEdited = false;
}
// wait for next tick to avoid blocking the event loop
await new Promise((resolve) => setImmediate(resolve));
return buildPromptBody();
@ -197,7 +224,8 @@ class AnthropicClient extends BaseClient {
context.shift();
}
const prompt = `${promptBody}${promptSuffix}`;
let prompt = `${promptBody}${promptSuffix}`;
// Add 2 tokens for metadata after all messages have been counted.
currentTokenCount += 2;

View file

@ -5,11 +5,12 @@ const { ChatOpenAI } = require('langchain/chat_models/openai');
const { loadSummarizationChain } = require('langchain/chains');
const { refinePrompt } = require('./prompts/refinePrompt');
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('../../models');
const { addSpaceIfNeeded } = require('../../server/utils');
class BaseClient {
constructor(apiKey, options = {}) {
this.apiKey = apiKey;
this.sender = options.sender || 'AI';
this.sender = options.sender ?? 'AI';
this.contextStrategy = null;
this.currentDateString = new Date().toLocaleDateString('en-us', {
year: 'numeric',
@ -51,18 +52,20 @@ class BaseClient {
if (opts && typeof opts === 'object') {
this.setOptions(opts);
}
const user = opts.user || null;
const conversationId = opts.conversationId || crypto.randomUUID();
const parentMessageId = opts.parentMessageId || '00000000-0000-0000-0000-000000000000';
const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
const responseMessageId = crypto.randomUUID();
const user = opts.user ?? null;
const conversationId = opts.conversationId ?? crypto.randomUUID();
const parentMessageId = opts.parentMessageId ?? '00000000-0000-0000-0000-000000000000';
const userMessageId = opts.overrideParentMessageId ?? crypto.randomUUID();
const responseMessageId = opts.responseMessageId ?? crypto.randomUUID();
const saveOptions = this.getSaveOptions();
this.abortController = opts.abortController || new AbortController();
this.currentMessages = (await this.loadHistory(conversationId, parentMessageId)) ?? [];
const head = opts.isEdited ? responseMessageId : parentMessageId;
this.currentMessages = (await this.loadHistory(conversationId, head)) ?? [];
this.abortController = opts.abortController ?? new AbortController();
return {
...opts,
user,
head,
conversationId,
parentMessageId,
userMessageId,
@ -72,7 +75,7 @@ class BaseClient {
}
createUserMessage({ messageId, parentMessageId, conversationId, text }) {
const userMessage = {
return {
messageId,
parentMessageId,
conversationId,
@ -80,19 +83,27 @@ class BaseClient {
text,
isCreatedByUser: true,
};
return userMessage;
}
async handleStartMethods(message, opts) {
const { user, conversationId, parentMessageId, userMessageId, responseMessageId, saveOptions } =
await this.setMessageOptions(opts);
const userMessage = this.createUserMessage({
messageId: userMessageId,
parentMessageId,
const {
user,
head,
conversationId,
text: message,
});
parentMessageId,
userMessageId,
responseMessageId,
saveOptions,
} = await this.setMessageOptions(opts);
const userMessage = opts.isEdited
? this.currentMessages[this.currentMessages.length - 2]
: this.createUserMessage({
messageId: userMessageId,
parentMessageId,
conversationId,
text: message,
});
if (typeof opts?.getIds === 'function') {
opts.getIds({
@ -109,6 +120,7 @@ class BaseClient {
return {
...opts,
user,
head,
conversationId,
responseMessageId,
saveOptions,
@ -373,7 +385,7 @@ class BaseClient {
if (this.options.debug) {
console.debug('<-------------------------PAYLOAD/TOKEN COUNT MAP------------------------->');
console.debug('Payload:', payload);
// console.debug('Payload:', payload);
console.debug('Token Count Map:', tokenCountMap);
console.debug('Prompt Tokens', promptTokens, remainingContextTokens, this.maxContextTokens);
}
@ -382,13 +394,16 @@ class BaseClient {
}
async sendMessage(message, opts = {}) {
const { user, conversationId, responseMessageId, saveOptions, userMessage } =
const { user, head, isEdited, conversationId, responseMessageId, saveOptions, userMessage } =
await this.handleStartMethods(message, opts);
this.user = user;
// It's not necessary to push to currentMessages
// depending on subclass implementation of handling messages
this.currentMessages.push(userMessage);
// When this is an edit, all messages are already in currentMessages, both user and response
if (!isEdited) {
this.currentMessages.push(userMessage);
}
let {
prompt: payload,
@ -398,13 +413,13 @@ class BaseClient {
this.currentMessages,
// When the userMessage is pushed to currentMessages, the parentMessage is the userMessageId.
// this only matters when buildMessages is utilizing the parentMessageId, and may vary on implementation
userMessage.messageId,
isEdited ? head : userMessage.messageId,
this.getBuildMessagesOptions(opts),
);
if (this.options.debug) {
console.debug('payload');
console.debug(payload);
// console.debug(payload);
}
if (tokenCountMap) {
@ -423,7 +438,11 @@ class BaseClient {
this.handleTokenCountMap(tokenCountMap);
}
await this.saveMessageToDatabase(userMessage, saveOptions, user);
if (!isEdited) {
await this.saveMessageToDatabase(userMessage, saveOptions, user);
}
const generation = isEdited ? this.currentMessages[this.currentMessages.length - 1].text : '';
const responseMessage = {
messageId: responseMessageId,
conversationId,
@ -431,7 +450,7 @@ class BaseClient {
isCreatedByUser: false,
model: this.modelOptions.model,
sender: this.sender,
text: await this.sendCompletion(payload, opts),
text: addSpaceIfNeeded(generation) + (await this.sendCompletion(payload, opts)),
promptTokens,
};
@ -453,7 +472,7 @@ class BaseClient {
console.debug('Loading history for conversation', conversationId, parentMessageId);
}
const messages = (await getMessages({ conversationId })) || [];
const messages = (await getMessages({ conversationId })) ?? [];
if (messages.length === 0) {
return [];

View file

@ -314,6 +314,7 @@ class OpenAIClient extends BaseClient {
async sendCompletion(payload, opts = {}) {
let reply = '';
let result = null;
let streamResult = null;
if (typeof opts.onProgress === 'function') {
await this.getCompletion(
payload,
@ -321,6 +322,10 @@ class OpenAIClient extends BaseClient {
if (progressMessage === '[DONE]') {
return;
}
if (progressMessage.choices) {
streamResult = progressMessage;
}
const token = this.isChatCompletion
? progressMessage.choices?.[0]?.delta?.content
: progressMessage.choices?.[0]?.text;
@ -355,6 +360,10 @@ class OpenAIClient extends BaseClient {
}
}
if (streamResult && typeof opts.addMetadata === 'function') {
const { finish_reason } = streamResult.choices[0];
opts.addMetadata({ finish_reason });
}
return reply.trim();
}

View file

@ -345,7 +345,8 @@ Only respond with your conversational reply to the following User Message:
}
async sendMessage(message, opts = {}) {
const completionMode = this.options.tools.length === 0;
// If a message is edited, no tools can be used.
const completionMode = this.options.tools.length === 0 || opts.isEdited;
if (completionMode) {
this.setOptions(opts);
return super.sendMessage(message, opts);

View file

@ -0,0 +1,139 @@
const AnthropicClient = require('../AnthropicClient');
const HUMAN_PROMPT = '\n\nHuman:';
const AI_PROMPT = '\n\nAssistant:';
describe('AnthropicClient', () => {
let client;
const model = 'claude-2';
const parentMessageId = '1';
const messages = [
{ role: 'user', isCreatedByUser: true, text: 'Hello', messageId: parentMessageId },
{ role: 'assistant', isCreatedByUser: false, text: 'Hi', messageId: '2', parentMessageId },
{
role: 'user',
isCreatedByUser: true,
text: 'What\'s up',
messageId: '3',
parentMessageId: '2',
},
];
beforeEach(() => {
const options = {
modelOptions: {
model,
temperature: 0.7,
},
};
client = new AnthropicClient('test-api-key');
client.setOptions(options);
});
describe('setOptions', () => {
it('should set the options correctly', () => {
expect(client.apiKey).toBe('test-api-key');
expect(client.modelOptions.model).toBe(model);
expect(client.modelOptions.temperature).toBe(0.7);
});
});
describe('getSaveOptions', () => {
it('should return the correct save options', () => {
const options = client.getSaveOptions();
expect(options).toHaveProperty('modelLabel');
expect(options).toHaveProperty('promptPrefix');
});
});
describe('buildMessages', () => {
it('should handle promptPrefix from options when promptPrefix argument is not provided', async () => {
client.options.promptPrefix = 'Test Prefix from options';
const result = await client.buildMessages(messages, parentMessageId);
const { prompt } = result;
expect(prompt).toContain('Test Prefix from options');
});
it('should build messages correctly for chat completion', async () => {
const result = await client.buildMessages(messages, '2');
expect(result).toHaveProperty('prompt');
expect(result.prompt).toContain(HUMAN_PROMPT);
expect(result.prompt).toContain('Hello');
expect(result.prompt).toContain(AI_PROMPT);
expect(result.prompt).toContain('Hi');
});
it('should group messages by the same author', async () => {
const groupedMessages = messages.map((m) => ({ ...m, isCreatedByUser: true, role: 'user' }));
const result = await client.buildMessages(groupedMessages, '3');
expect(result.context).toHaveLength(1);
// Check that HUMAN_PROMPT appears only once in the prompt
const matches = result.prompt.match(new RegExp(HUMAN_PROMPT, 'g'));
expect(matches).toHaveLength(1);
groupedMessages.push({
role: 'assistant',
isCreatedByUser: false,
text: 'I heard you the first time',
messageId: '4',
parentMessageId: '3',
});
const result2 = await client.buildMessages(groupedMessages, '4');
expect(result2.context).toHaveLength(2);
// Check that HUMAN_PROMPT appears only once in the prompt
const human_matches = result2.prompt.match(new RegExp(HUMAN_PROMPT, 'g'));
const ai_matches = result2.prompt.match(new RegExp(AI_PROMPT, 'g'));
expect(human_matches).toHaveLength(1);
expect(ai_matches).toHaveLength(1);
});
it('should handle isEdited condition', async () => {
const editedMessages = [
{ role: 'user', isCreatedByUser: true, text: 'Hello', messageId: '1' },
{ role: 'assistant', isCreatedByUser: false, text: 'Hi', messageId: '2', parentMessageId },
];
const trimmedLabel = AI_PROMPT.trim();
const result = await client.buildMessages(editedMessages, '2');
expect(result.prompt.trim().endsWith(trimmedLabel)).toBeFalsy();
// Add a human message at the end to test the opposite
editedMessages.push({
role: 'user',
isCreatedByUser: true,
text: 'Hi again',
messageId: '3',
parentMessageId: '2',
});
const result2 = await client.buildMessages(editedMessages, '3');
expect(result2.prompt.trim().endsWith(trimmedLabel)).toBeTruthy();
});
it('should build messages correctly with a promptPrefix', async () => {
const promptPrefix = 'Test Prefix';
client.options.promptPrefix = promptPrefix;
const result = await client.buildMessages(messages, parentMessageId);
const { prompt } = result;
expect(prompt).toBeDefined();
expect(prompt).toContain(promptPrefix);
const textAfterPrefix = prompt.split(promptPrefix)[1];
expect(textAfterPrefix).toContain(AI_PROMPT);
const editedMessages = messages.slice(0, -1);
const result2 = await client.buildMessages(editedMessages, parentMessageId);
const textAfterPrefix2 = result2.prompt.split(promptPrefix)[1];
expect(textAfterPrefix2).toContain(AI_PROMPT);
});
it('should handle identityPrefix from options', async () => {
client.options.userLabel = 'John';
client.options.modelLabel = 'Claude-2';
const result = await client.buildMessages(messages, parentMessageId);
const { prompt } = result;
expect(prompt).toContain('Human\'s name: John');
expect(prompt).toContain('You are Claude-2');
});
});
});

View file

@ -45,6 +45,18 @@ const fakeMessages = [];
const userMessage = 'Hello, ChatGPT!';
const apiKey = 'fake-api-key';
const messageHistory = [
{ role: 'user', isCreatedByUser: true, text: 'Hello', messageId: '1' },
{ role: 'assistant', isCreatedByUser: false, text: 'Hi', messageId: '2', parentMessageId: '1' },
{
role: 'user',
isCreatedByUser: true,
text: 'What\'s up',
messageId: '3',
parentMessageId: '2',
},
];
describe('BaseClient', () => {
let TestClient;
const options = {
@ -277,9 +289,54 @@ describe('BaseClient', () => {
});
test('should return chat history', async () => {
const chatMessages = await TestClient.loadHistory(conversationId, parentMessageId);
expect(TestClient.currentMessages).toHaveLength(4);
expect(chatMessages[0].text).toEqual(userMessage);
TestClient = initializeFakeClient(apiKey, options, messageHistory);
const chatMessages = await TestClient.loadHistory(conversationId, '2');
expect(TestClient.currentMessages).toHaveLength(2);
expect(chatMessages[0].text).toEqual('Hello');
const chatMessages2 = await TestClient.loadHistory(conversationId, '3');
expect(TestClient.currentMessages).toHaveLength(3);
expect(chatMessages2[chatMessages2.length - 1].text).toEqual('What\'s up');
});
/* Most of the new sendMessage logic revolving around edited/continued AI messages
* can be summarized by the following test. The condition will load the entire history up to
* the message that is being edited, which will trigger the AI API to 'continue' the response.
* The 'userMessage' is only passed by convention and is not necessary for the generation.
*/
it('should not push userMessage to currentMessages when isEdited is true and vice versa', async () => {
const overrideParentMessageId = 'user-message-id';
const responseMessageId = 'response-message-id';
const newHistory = messageHistory.slice();
newHistory.push({
role: 'assistant',
isCreatedByUser: false,
text: 'test message',
messageId: responseMessageId,
parentMessageId: '3',
});
TestClient = initializeFakeClient(apiKey, options, newHistory);
const sendMessageOptions = {
isEdited: true,
overrideParentMessageId,
parentMessageId: '3',
responseMessageId,
};
await TestClient.sendMessage('test message', sendMessageOptions);
const currentMessages = TestClient.currentMessages;
expect(currentMessages[currentMessages.length - 1].messageId).not.toEqual(
overrideParentMessageId,
);
// Test the opposite case
sendMessageOptions.isEdited = false;
await TestClient.sendMessage('test message', sendMessageOptions);
const currentMessages2 = TestClient.currentMessages;
expect(currentMessages2[currentMessages2.length - 1].messageId).toEqual(
overrideParentMessageId,
);
});
test('setOptions is called with the correct arguments', async () => {

View file

@ -1,4 +1,3 @@
const crypto = require('crypto');
const BaseClient = require('../BaseClient');
const { maxTokensMap } = require('../../../utils');
@ -87,86 +86,6 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
return 'Mock response text';
});
TestClient.sendMessage = jest.fn().mockImplementation(async (message, opts = {}) => {
if (opts && typeof opts === 'object') {
TestClient.setOptions(opts);
}
const user = opts.user || null;
const conversationId = opts.conversationId || crypto.randomUUID();
const parentMessageId = opts.parentMessageId || '00000000-0000-0000-0000-000000000000';
const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
const saveOptions = TestClient.getSaveOptions();
this.pastMessages = await TestClient.loadHistory(
conversationId,
TestClient.options?.parentMessageId,
);
const userMessage = {
text: message,
sender: TestClient.sender,
isCreatedByUser: true,
messageId: userMessageId,
parentMessageId,
conversationId,
};
const response = {
sender: TestClient.sender,
text: 'Hello, User!',
isCreatedByUser: false,
messageId: crypto.randomUUID(),
parentMessageId: userMessage.messageId,
conversationId,
};
fakeMessages.push(userMessage);
fakeMessages.push(response);
if (typeof opts.getIds === 'function') {
opts.getIds({
userMessage,
conversationId,
responseMessageId: response.messageId,
});
}
if (typeof opts.onStart === 'function') {
opts.onStart(userMessage);
}
let { prompt: payload, tokenCountMap } = await TestClient.buildMessages(
this.currentMessages,
userMessage.messageId,
TestClient.getBuildMessagesOptions(opts),
);
if (tokenCountMap) {
payload = payload.map((message, i) => {
const { tokenCount, ...messageWithoutTokenCount } = message;
// userMessage is always the last one in the payload
if (i === payload.length - 1) {
userMessage.tokenCount = message.tokenCount;
console.debug(
`Token count for user message: ${tokenCount}`,
`Instruction Tokens: ${tokenCountMap.instructions || 'N/A'}`,
);
}
return messageWithoutTokenCount;
});
TestClient.handleTokenCountMap(tokenCountMap);
}
await TestClient.saveMessageToDatabase(userMessage, saveOptions, user);
response.text = await TestClient.sendCompletion(payload, opts);
if (tokenCountMap && TestClient.getTokenCountForResponse) {
response.tokenCount = TestClient.getTokenCountForResponse(response);
}
await TestClient.saveMessageToDatabase(response, saveOptions, user);
return response;
});
TestClient.buildMessages = jest.fn(async (messages, parentMessageId) => {
const orderedMessages = TestClient.constructor.getMessagesForConversation(
messages,

View file

@ -1,5 +1,7 @@
const OpenAIClient = require('../OpenAIClient');
jest.mock('meilisearch');
describe('OpenAIClient', () => {
let client, client2;
const model = 'gpt-4';
@ -25,6 +27,9 @@ describe('OpenAIClient', () => {
content: 'Refined answer',
tokenCount: 30,
});
client.buildPrompt = jest
.fn()
.mockResolvedValue({ prompt: messages.map((m) => m.text).join('\n') });
client.constructor.freeAndResetAllEncoders();
});

View file

@ -111,7 +111,6 @@ describe('PluginsClient', () => {
});
const response = await TestAgent.sendMessage(userMessage);
console.log(response);
parentMessageId = response.messageId;
conversationId = response.conversationId;
expect(response).toEqual(expectedResult);

View file

@ -20,7 +20,6 @@ async function addOpenAPISpecs(availableTools) {
}
return availableTools;
} catch (error) {
console.log('addOpenAPISpecs error', error);
return availableTools;
}
}

View file

@ -83,7 +83,6 @@ describe('Tool Handlers', () => {
it('returns valid tools given input tools and user authentication', async () => {
const validTools = await validateTools(fakeUser._id, initialTools);
expect(validTools).toBeDefined();
console.log('validateTools: validTools', validTools);
expect(validTools.some((tool) => tool === pluginKey)).toBeTruthy();
expect(validTools.length).toBeGreaterThan(0);
});

View file

@ -3,15 +3,11 @@ const { askBing } = require('./bingai');
const clients = require('./clients');
const titleConvo = require('./titleConvo');
const titleConvoBing = require('./titleConvoBing');
const getCitations = require('../lib/parse/getCitations');
const citeText = require('../lib/parse/citeText');
module.exports = {
browserClient,
askBing,
titleConvo,
titleConvoBing,
getCitations,
citeText,
...clients,
};

View file

@ -1,18 +0,0 @@
// const regex = / \[\d+\..*?\]\(.*?\)/g;
const regex = / \[.*?]\(.*?\)/g;
const getCitations = (res) => {
const adaptiveCards = res.details.adaptiveCards;
const textBlocks = adaptiveCards && adaptiveCards[0].body;
if (!textBlocks) {
return '';
}
let links = textBlocks[textBlocks.length - 1]?.text.match(regex);
if (links?.length === 0 || !links) {
return '';
}
links = links.map((link) => link.trim());
return links.join('\n - ');
};
module.exports = getCitations;

View file

@ -14,6 +14,7 @@ module.exports = {
error,
unfinished,
cancelled,
finish_reason = null,
tokenCount = null,
plugin = null,
model = null,
@ -29,6 +30,7 @@ module.exports = {
sender,
text,
isCreatedByUser,
finish_reason,
error,
unfinished,
cancelled,

View file

@ -67,6 +67,9 @@ const messageSchema = mongoose.Schema(
type: Boolean,
default: false,
},
finish_reason: {
type: String,
},
_meiliIndex: {
type: Boolean,
required: false,

View file

@ -84,6 +84,7 @@ config.validate(); // Validate the config
app.use('/api/user', routes.user);
app.use('/api/search', routes.search);
app.use('/api/ask', routes.ask);
app.use('/api/edit', routes.edit);
app.use('/api/messages', routes.messages);
app.use('/api/convos', routes.convos);
app.use('/api/presets', routes.presets);

View file

@ -0,0 +1,2 @@
// abortControllers.js
module.exports = new Map();

View file

@ -0,0 +1,106 @@
const { saveMessage, getConvo, getConvoTitle } = require('../../models');
const { sendMessage, handleError } = require('../utils');
const abortControllers = require('./abortControllers');
async function abortMessage(req, res) {
const { abortKey } = req.body;
if (!abortControllers.has(abortKey) && !res.headersSent) {
return res.status(404).send('Request not found');
}
const { abortController } = abortControllers.get(abortKey);
const ret = await abortController.abortCompletion();
console.log('Aborted request', abortKey);
abortControllers.delete(abortKey);
res.send(JSON.stringify(ret));
}
const handleAbort = () => {
return async (req, res) => {
try {
return await abortMessage(req, res);
} catch (err) {
console.error(err);
}
};
};
const createAbortController = (res, req, endpointOption, getAbortData) => {
const abortController = new AbortController();
const onStart = (userMessage) => {
sendMessage(res, { message: userMessage, created: true });
abortControllers.set(userMessage.conversationId, { abortController, ...endpointOption });
res.on('finish', function () {
abortControllers.delete(userMessage.conversationId);
});
};
abortController.abortCompletion = async function () {
abortController.abort();
const { conversationId, userMessage, ...responseData } = getAbortData();
const responseMessage = {
...responseData,
finish_reason: 'incomplete',
model: endpointOption.modelOptions.model,
unfinished: false,
cancelled: true,
error: false,
};
saveMessage(responseMessage);
return {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: responseMessage,
};
};
return { abortController, onStart };
};
const handleAbortError = async (res, req, error, data) => {
console.error(error);
const { sender, conversationId, messageId, parentMessageId, partialText } = data;
const respondWithError = async () => {
const errorMessage = {
sender,
messageId,
conversationId,
parentMessageId,
unfinished: false,
cancelled: false,
error: true,
text: error.message,
};
if (abortControllers.has(conversationId)) {
const { abortController } = abortControllers.get(conversationId);
abortController.abort();
abortControllers.delete(conversationId);
}
await saveMessage(errorMessage);
handleError(res, errorMessage);
};
if (partialText?.length > 2) {
try {
return await abortMessage(req, res);
} catch (err) {
return respondWithError();
}
} else {
return respondWithError();
}
};
module.exports = {
handleAbort,
createAbortController,
handleAbortError,
};

View file

@ -0,0 +1,20 @@
const openAI = require('../routes/endpoints/openAI');
const gptPlugins = require('../routes/endpoints/gptPlugins');
const anthropic = require('../routes/endpoints/anthropic');
const { parseConvo } = require('../routes/endpoints/schemas');
const buildFunction = {
openAI: openAI.buildOptions,
azureOpenAI: openAI.buildOptions,
gptPlugins: gptPlugins.buildOptions,
anthropic: anthropic.buildOptions,
};
function buildEndpointOption(req, res, next) {
const { endpoint } = req.body;
const parsedBody = parseConvo(endpoint, req.body);
req.body.endpointOption = buildFunction[endpoint](endpoint, parsedBody);
next();
}
module.exports = buildEndpointOption;

View file

@ -0,0 +1,17 @@
const abortMiddleware = require('./abortMiddleware');
const setHeaders = require('./setHeaders');
const requireJwtAuth = require('./requireJwtAuth');
const requireLocalAuth = require('./requireLocalAuth');
const validateEndpoint = require('./validateEndpoint');
const buildEndpointOption = require('./buildEndpointOption');
const validateRegistration = require('./validateRegistration');
module.exports = {
...abortMiddleware,
setHeaders,
requireJwtAuth,
requireLocalAuth,
validateEndpoint,
buildEndpointOption,
validateRegistration,
};

View file

@ -1,5 +1,5 @@
const passport = require('passport');
const DebugControl = require('../utils/debug.js');
const DebugControl = require('../../utils/debug.js');
function log({ title, parameters }) {
DebugControl.log.functionName(title);

View file

@ -0,0 +1,12 @@
function setHeaders(req, res, next) {
res.writeHead(200, {
Connection: 'keep-alive',
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache, no-transform',
'Access-Control-Allow-Origin': '*',
'X-Accel-Buffering': 'no',
});
next();
}
module.exports = setHeaders;

View file

@ -0,0 +1,19 @@
const { handleError } = require('../utils');
function validateEndpoint(req, res, next) {
const { endpoint } = req.body;
if (!req.body.text || req.body.text.length === 0) {
return handleError(res, { text: 'Prompt empty or too short' });
}
const pathEndpoint = req.baseUrl.split('/')[3];
if (endpoint !== pathEndpoint) {
return handleError(res, { text: 'Illegal request: Endpoint mismatch' });
}
next();
}
module.exports = validateEndpoint;

View file

@ -1,72 +1,43 @@
const express = require('express');
const router = express.Router();
const crypto = require('crypto');
const { titleConvo, AnthropicClient } = require('../../../app');
const requireJwtAuth = require('../../../middleware/requireJwtAuth');
const { abortMessage } = require('../../../utils');
const { getResponseSender } = require('../endpoints/schemas');
const { initializeClient } = require('../endpoints/anthropic');
const {
handleAbort,
createAbortController,
handleAbortError,
setHeaders,
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress } = require('./handlers');
const { sendMessage, createOnProgress } = require('../../utils');
const abortControllers = new Map();
router.post('/abort', requireJwtAuth, handleAbort());
router.post('/abort', requireJwtAuth, async (req, res) => {
try {
return await abortMessage(req, res, abortControllers);
} catch (err) {
console.error(err);
}
});
router.post(
'/',
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
endpointOption,
conversationId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('ask log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
let userMessage;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
let saveDelay = 100;
router.post('/', requireJwtAuth, async (req, res) => {
const { endpoint, text, parentMessageId, conversationId: oldConversationId } = req.body;
if (text.length === 0) {
return handleError(res, { text: 'Prompt empty or too short' });
}
if (endpoint !== 'anthropic') {
return handleError(res, { text: 'Illegal request' });
}
const endpointOption = {
promptPrefix: req.body?.promptPrefix ?? null,
modelLabel: req.body?.modelLabel ?? null,
token: req.body?.token ?? null,
modelOptions: {
model: req.body?.model ?? 'claude-1',
temperature: req.body?.temperature ?? 1,
maxOutputTokens: req.body?.maxOutputTokens ?? 1024,
topP: req.body?.topP ?? 0.7,
topK: req.body?.topK ?? 5,
},
};
const conversationId = oldConversationId || crypto.randomUUID();
return await ask({
text,
endpointOption,
conversationId,
parentMessageId,
req,
res,
});
});
const ask = async ({ text, endpointOption, parentMessageId = null, conversationId, req, res }) => {
res.writeHead(200, {
Connection: 'keep-alive',
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache, no-transform',
'Access-Control-Allow-Origin': '*',
'X-Accel-Buffering': 'no',
});
let userMessage;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
const { overrideParentMessageId = null } = req.body;
try {
const getIds = (data) => {
userMessage = data.userMessage;
userMessageId = data.userMessage.messageId;
@ -79,116 +50,95 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > 500) {
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: 'Anthropic',
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
unfinished: true,
cancelled: false,
error: false,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
const abortController = new AbortController();
abortController.abortAsk = async function () {
this.abort();
const responseMessage = {
messageId: responseMessageId,
sender: 'Anthropic',
try {
const getAbortData = () => ({
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
model: endpointOption.modelOptions.model,
unfinished: false,
cancelled: true,
error: false,
};
userMessage,
});
saveMessage(responseMessage);
const { abortController, onStart } = createAbortController(
res,
req,
endpointOption,
getAbortData,
);
return {
const { client } = initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
getIds,
debug: false,
user: req.user.id,
conversationId,
parentMessageId,
overrideParentMessageId,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId ?? userMessageId,
}),
onStart,
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
await saveConvo(req.user.id, {
...endpointOption,
...endpointOption.modelOptions,
conversationId,
endpoint: 'anthropic',
});
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: responseMessage,
};
};
responseMessage: response,
});
res.end();
const onStart = (userMessage) => {
sendMessage(res, { message: userMessage, created: true });
abortControllers.set(userMessage.conversationId, { abortController, ...endpointOption });
};
const client = new AnthropicClient(endpointOption.token);
let response = await client.sendMessage(text, {
getIds,
debug: false,
user: req.user.id,
conversationId,
parentMessageId,
overrideParentMessageId,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId || userMessageId,
}),
onStart,
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
await saveConvo(req.user.id, {
...endpointOption,
...endpointOption.modelOptions,
conversationId,
endpoint: 'anthropic',
});
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
if (parentMessageId == '00000000-0000-0000-0000-000000000000') {
const title = await titleConvo({ text, response });
await saveConvo(req.user.id, {
// TODO: add anthropic titling
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
title,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId,
});
}
} catch (error) {
console.error(error);
const errorMessage = {
messageId: responseMessageId,
sender: 'Anthropic',
conversationId,
parentMessageId,
unfinished: false,
cancelled: false,
error: true,
text: error.message,
};
await saveMessage(errorMessage);
handleError(res, errorMessage);
}
};
},
);
module.exports = router;

View file

@ -1,13 +1,12 @@
const express = require('express');
const crypto = require('crypto');
const router = express.Router();
// const { getChatGPTBrowserModels } = require('../endpoints');
const { browserClient } = require('../../../app/');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers');
const requireJwtAuth = require('../../../middleware/requireJwtAuth');
const { handleError, sendMessage, createOnProgress, handleText } = require('../../utils');
const { requireJwtAuth, setHeaders } = require('../../middleware');
router.post('/', requireJwtAuth, async (req, res) => {
router.post('/', requireJwtAuth, setHeaders, async (req, res) => {
const {
endpoint,
text,
@ -86,15 +85,6 @@ const ask = async ({
}) => {
let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage;
const userId = req.user.id;
res.writeHead(200, {
Connection: 'keep-alive',
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache, no-transform',
'Access-Control-Allow-Origin': '*',
'X-Accel-Buffering': 'no',
});
let responseMessageId = crypto.randomUUID();
let getPartialMessage = null;
try {

View file

@ -3,10 +3,10 @@ const crypto = require('crypto');
const router = express.Router();
const { titleConvoBing, askBing } = require('../../../app');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers');
const requireJwtAuth = require('../../../middleware/requireJwtAuth');
const { handleError, sendMessage, createOnProgress, handleText } = require('../../utils');
const { requireJwtAuth, setHeaders } = require('../../middleware');
router.post('/', requireJwtAuth, async (req, res) => {
router.post('/', requireJwtAuth, setHeaders, async (req, res) => {
const {
endpoint,
text,
@ -103,14 +103,6 @@ const ask = async ({
let responseMessageId = crypto.randomUUID();
res.writeHead(200, {
Connection: 'keep-alive',
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache, no-transform',
'Access-Control-Allow-Origin': '*',
'X-Accel-Buffering': 'no',
});
if (preSendRequest) {
sendMessage(res, { message: userMessage, created: true });
}

View file

@ -2,12 +2,11 @@ const express = require('express');
const router = express.Router();
const crypto = require('crypto');
const { titleConvo, GoogleClient } = require('../../../app');
// const GoogleClient = require('../../../app/google/GoogleClient');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress } = require('./handlers');
const requireJwtAuth = require('../../../middleware/requireJwtAuth');
const { handleError, sendMessage, createOnProgress } = require('../../utils');
const { requireJwtAuth, setHeaders } = require('../../middleware');
router.post('/', requireJwtAuth, async (req, res) => {
router.post('/', requireJwtAuth, setHeaders, async (req, res) => {
const { endpoint, text, parentMessageId, conversationId: oldConversationId } = req.body;
if (text.length === 0) {
return handleError(res, { text: 'Prompt empty or too short' });
@ -50,13 +49,6 @@ router.post('/', requireJwtAuth, async (req, res) => {
});
const ask = async ({ text, endpointOption, parentMessageId = null, conversationId, req, res }) => {
res.writeHead(200, {
Connection: 'keep-alive',
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache, no-transform',
'Access-Control-Allow-Origin': '*',
'X-Accel-Buffering': 'no',
});
let userMessage;
let userMessageId;
let responseMessageId;

View file

@ -1,112 +1,56 @@
const express = require('express');
const router = express.Router();
const { titleConvo, validateTools, PluginsClient } = require('../../../app');
const { abortMessage, getAzureCredentials } = require('../../../utils');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { getResponseSender } = require('../endpoints/schemas');
const { validateTools } = require('../../../app');
const { addTitle } = require('../endpoints/openAI');
const { initializeClient } = require('../endpoints/gptPlugins');
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('../../utils');
const {
handleError,
sendMessage,
createOnProgress,
formatSteps,
formatAction,
} = require('./handlers');
const requireJwtAuth = require('../../../middleware/requireJwtAuth');
handleAbort,
createAbortController,
handleAbortError,
setHeaders,
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
const abortControllers = new Map();
router.post('/abort', requireJwtAuth, handleAbort());
router.post('/abort', requireJwtAuth, async (req, res) => {
try {
return await abortMessage(req, res, abortControllers);
} catch (err) {
console.error(err);
}
});
router.post(
'/',
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
endpointOption,
conversationId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('ask log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const newConvo = !conversationId;
const user = req.user.id;
router.post('/', requireJwtAuth, async (req, res) => {
const { endpoint, text, parentMessageId, conversationId } = req.body;
if (text.length === 0) {
return handleError(res, { text: 'Prompt empty or too short' });
}
if (endpoint !== 'gptPlugins') {
return handleError(res, { text: 'Illegal request' });
}
const plugin = {
loading: true,
inputs: [],
latest: null,
outputs: null,
};
const agentOptions = req.body?.agentOptions ?? {
agent: 'functions',
skipCompletion: true,
model: 'gpt-3.5-turbo',
temperature: 0,
// top_p: 1,
// presence_penalty: 0,
// frequency_penalty: 0
};
const tools = req.body?.tools.map((tool) => tool.pluginKey) ?? [];
// build endpoint option
const endpointOption = {
chatGptLabel: tools.length === 0 ? req.body?.chatGptLabel ?? null : null,
promptPrefix: tools.length === 0 ? req.body?.promptPrefix ?? null : null,
tools,
modelOptions: {
model: req.body?.model ?? 'gpt-4',
temperature: req.body?.temperature ?? 0,
top_p: req.body?.top_p ?? 1,
presence_penalty: req.body?.presence_penalty ?? 0,
frequency_penalty: req.body?.frequency_penalty ?? 0,
},
agentOptions: {
...agentOptions,
// agent: 'functions'
},
};
console.log('ask log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
// eslint-disable-next-line no-use-before-define
return await ask({
text,
endpoint,
endpointOption,
conversationId,
parentMessageId,
req,
res,
});
});
const ask = async ({
text,
endpoint,
endpointOption,
parentMessageId = null,
conversationId,
req,
res,
}) => {
res.writeHead(200, {
Connection: 'keep-alive',
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache, no-transform',
'Access-Control-Allow-Origin': '*',
'X-Accel-Buffering': 'no',
});
let userMessage;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
const newConvo = !conversationId;
const { overrideParentMessageId = null } = req.body;
const user = req.user.id;
const plugin = {
loading: true,
inputs: [],
latest: null,
outputs: null,
};
try {
const addMetadata = (data) => (metadata = data);
const getIds = (data) => {
userMessage = data.userMessage;
userMessageId = userMessage.messageId;
@ -128,11 +72,11 @@ const ask = async ({
plugin.loading = false;
}
if (currentTimestamp - lastSavedTimestamp > 500) {
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: 'ChatGPT',
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
@ -142,63 +86,13 @@ const ask = async ({
error: false,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
const abortController = new AbortController();
abortController.abortAsk = async function () {
this.abort();
const responseMessage = {
messageId: responseMessageId,
sender: endpointOption?.chatGptLabel || 'ChatGPT',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: getPartialText(),
plugin: { ...plugin, loading: false },
model: endpointOption.modelOptions.model,
unfinished: false,
cancelled: true,
error: false,
};
saveMessage(responseMessage);
return {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: responseMessage,
};
};
const onStart = (userMessage) => {
sendMessage(res, { message: userMessage, created: true });
abortControllers.set(userMessage.conversationId, { abortController, ...endpointOption });
};
endpointOption.tools = await validateTools(user, endpointOption.tools);
const clientOptions = {
debug: true,
endpoint,
reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null,
proxy: process.env.PROXY || null,
...endpointOption,
};
let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY;
if (process.env.PLUGINS_USE_AZURE) {
clientOptions.azure = getAzureCredentials();
openAIApiKey = clientOptions.azure.azureOpenAIApiKey;
}
if (openAIApiKey && openAIApiKey.includes('azure') && !clientOptions.azure) {
clientOptions.azure = JSON.parse(req.body?.token) ?? getAzureCredentials();
openAIApiKey = clientOptions.azure.azureOpenAIApiKey;
}
const chatAgent = new PluginsClient(openAIApiKey, clientOptions);
const onAgentAction = (action, start = false) => {
const formattedAction = formatAction(action);
plugin.inputs.push(formattedAction);
@ -219,70 +113,86 @@ const ask = async ({
// console.log('CHAIN END', plugin.outputs);
};
let response = await chatAgent.sendMessage(text, {
getIds,
user,
parentMessageId,
const getAbortData = () => ({
sender: getResponseSender(endpointOption),
conversationId,
overrideParentMessageId,
onAgentAction,
onChainEnd,
onStart,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
plugin,
parentMessageId: overrideParentMessageId || userMessageId,
}),
abortController,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
plugin: { ...plugin, loading: false },
userMessage,
});
const { abortController, onStart } = createAbortController(
res,
req,
endpointOption,
getAbortData,
);
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client, azure, openAIApiKey } = initializeClient(req, endpointOption);
console.log('CLIENT RESPONSE');
console.dir(response, { depth: null });
response.plugin = { ...plugin, loading: false };
await saveMessage(response);
let response = await client.sendMessage(text, {
user,
conversationId,
parentMessageId,
overrideParentMessageId,
getIds,
onAgentAction,
onChainEnd,
onStart,
addMetadata,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
plugin,
parentMessageId: overrideParentMessageId || userMessageId,
}),
abortController,
});
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) {
const title = await titleConvo({
if (metadata) {
response = { ...response, ...metadata };
}
console.log('CLIENT RESPONSE');
console.dir(response, { depth: null });
response.plugin = { ...plugin, loading: false };
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
addTitle(req, {
text,
newConvo,
response,
openAIApiKey,
azure: !!clientOptions.azure,
parentMessageId,
azure: !!azure,
});
await saveConvo(req.user.id, {
conversationId: conversationId,
title,
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId,
});
}
} catch (error) {
console.error(error);
const errorMessage = {
messageId: responseMessageId,
sender: 'ChatGPT',
conversationId,
parentMessageId: userMessageId,
unfinished: false,
cancelled: false,
error: true,
text: error.message,
};
await saveMessage(errorMessage);
handleError(res, errorMessage);
}
};
},
);
module.exports = router;

View file

@ -1,7 +1,5 @@
const express = require('express');
const router = express.Router();
// const askAzureOpenAI = require('./askAzureOpenAI';)
// const askOpenAI = require('./askOpenAI');
const openAI = require('./openAI');
const google = require('./google');
const bingAI = require('./bingAI');
@ -9,7 +7,6 @@ const gptPlugins = require('./gptPlugins');
const askChatGPTBrowser = require('./askChatGPTBrowser');
const anthropic = require('./anthropic');
// router.use('/azureOpenAI', askAzureOpenAI);
router.use(['/azureOpenAI', '/openAI'], openAI);
router.use('/google', google);
router.use('/bingAI', bingAI);

View file

@ -1,231 +1,160 @@
const express = require('express');
const router = express.Router();
const { titleConvo, OpenAIClient } = require('../../../app');
const { getAzureCredentials, abortMessage } = require('../../../utils');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress } = require('./handlers');
const requireJwtAuth = require('../../../middleware/requireJwtAuth');
const { getResponseSender } = require('../endpoints/schemas');
const { sendMessage, createOnProgress } = require('../../utils');
const { addTitle, initializeClient } = require('../endpoints/openAI');
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
const {
handleAbort,
createAbortController,
handleAbortError,
setHeaders,
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
const abortControllers = new Map();
router.post('/abort', requireJwtAuth, handleAbort());
router.post('/abort', requireJwtAuth, async (req, res) => {
try {
return await abortMessage(req, res, abortControllers);
} catch (err) {
console.error(err);
}
});
router.post(
'/',
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
endpointOption,
conversationId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('ask log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const newConvo = !conversationId;
const user = req.user.id;
router.post('/', requireJwtAuth, async (req, res) => {
const { endpoint, text, parentMessageId, conversationId } = req.body;
if (text.length === 0) {
return handleError(res, { text: 'Prompt empty or too short' });
}
const isOpenAI = endpoint === 'openAI' || endpoint === 'azureOpenAI';
if (!isOpenAI) {
return handleError(res, { text: 'Illegal request' });
}
const addMetadata = (data) => (metadata = data);
// build endpoint option
const endpointOption = {
chatGptLabel: req.body?.chatGptLabel ?? null,
promptPrefix: req.body?.promptPrefix ?? null,
modelOptions: {
model: req.body?.model ?? 'gpt-3.5-turbo',
temperature: req.body?.temperature ?? 1,
top_p: req.body?.top_p ?? 1,
presence_penalty: req.body?.presence_penalty ?? 0,
frequency_penalty: req.body?.frequency_penalty ?? 0,
},
};
console.log('ask log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
// eslint-disable-next-line no-use-before-define
return await ask({
text,
endpointOption,
conversationId,
parentMessageId,
endpoint,
req,
res,
});
});
const ask = async ({
text,
endpointOption,
parentMessageId = null,
endpoint,
conversationId,
req,
res,
}) => {
res.writeHead(200, {
Connection: 'keep-alive',
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache, no-transform',
'Access-Control-Allow-Origin': '*',
'X-Accel-Buffering': 'no',
});
let userMessage;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
const newConvo = !conversationId;
const { overrideParentMessageId = null } = req.body;
const user = req.user.id;
const getIds = (data) => {
userMessage = data.userMessage;
userMessageId = userMessage.messageId;
responseMessageId = data.responseMessageId;
if (!conversationId) {
conversationId = data.conversationId;
}
};
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > 500) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: 'ChatGPT',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
error: false,
});
const getIds = (data) => {
userMessage = data.userMessage;
userMessageId = userMessage.messageId;
responseMessageId = data.responseMessageId;
if (!conversationId) {
conversationId = data.conversationId;
}
},
});
};
const abortController = new AbortController();
abortController.abortAsk = async function () {
this.abort();
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
const responseMessage = {
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
error: false,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
const getAbortData = () => ({
sender: getResponseSender(endpointOption),
conversationId,
messageId: responseMessageId,
sender: endpointOption?.chatGptLabel || 'ChatGPT',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
model: endpointOption.modelOptions.model,
unfinished: false,
cancelled: true,
error: false,
};
saveMessage(responseMessage);
return {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: responseMessage,
};
};
const onStart = (userMessage) => {
sendMessage(res, { message: userMessage, created: true });
abortControllers.set(userMessage.conversationId, { abortController, ...endpointOption });
};
try {
const clientOptions = {
// debug: true,
// contextStrategy: 'refine',
reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null,
proxy: process.env.PROXY || null,
endpoint,
...endpointOption,
};
let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY;
if (process.env.AZURE_API_KEY && endpoint === 'azureOpenAI') {
clientOptions.azure = JSON.parse(req.body?.token) ?? getAzureCredentials();
openAIApiKey = clientOptions.azure.azureOpenAIApiKey;
}
const client = new OpenAIClient(openAIApiKey, clientOptions);
let response = await client.sendMessage(text, {
user,
parentMessageId,
conversationId,
overrideParentMessageId,
getIds,
onStart,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId || userMessageId,
}),
abortController,
userMessage,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
console.log(
'promptTokens, completionTokens:',
response.promptTokens,
response.completionTokens,
const { abortController, onStart } = createAbortController(
res,
req,
endpointOption,
getAbortData,
);
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
try {
const { client, openAIApiKey } = initializeClient(req, endpointOption);
if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) {
const title = await titleConvo({
let response = await client.sendMessage(text, {
user,
parentMessageId,
conversationId,
overrideParentMessageId,
getIds,
onStart,
addMetadata,
abortController,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId || userMessageId,
}),
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
if (metadata) {
response = { ...response, ...metadata };
}
console.log(
'promptTokens, completionTokens:',
response.promptTokens,
response.completionTokens,
);
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
addTitle(req, {
text,
newConvo,
response,
openAIApiKey,
azure: endpoint === 'azureOpenAI',
parentMessageId,
azure: endpointOption.endpoint === 'azureOpenAI',
});
await saveConvo(req.user.id, {
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
title,
});
}
} catch (error) {
console.error(error);
const partialText = getPartialText();
if (partialText?.length > 2) {
return await abortMessage(req, res, abortControllers);
} else {
const errorMessage = {
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
sender: 'ChatGPT',
conversationId,
parentMessageId: userMessageId,
unfinished: false,
cancelled: false,
error: true,
text: error.message,
};
await saveMessage(errorMessage);
handleError(res, errorMessage);
});
}
}
};
},
);
module.exports = router;

View file

@ -7,9 +7,7 @@ const {
} = require('../controllers/AuthController');
const { loginController } = require('../controllers/auth/LoginController');
const { logoutController } = require('../controllers/auth/LogoutController');
const requireJwtAuth = require('../../middleware/requireJwtAuth');
const requireLocalAuth = require('../../middleware/requireLocalAuth');
const validateRegistration = require('../../middleware/validateRegistration');
const { requireJwtAuth, requireLocalAuth, validateRegistration } = require('../middleware');
const router = express.Router();

View file

@ -2,7 +2,7 @@ const express = require('express');
const router = express.Router();
const { getConvo, saveConvo } = require('../../models');
const { getConvosByPage, deleteConvos } = require('../../models/Conversation');
const requireJwtAuth = require('../../middleware/requireJwtAuth');
const requireJwtAuth = require('../middleware/requireJwtAuth');
router.get('/', requireJwtAuth, async (req, res) => {
const pageNumber = req.query.pageNumber || 1;

View file

@ -0,0 +1,139 @@
const express = require('express');
const router = express.Router();
const { getResponseSender } = require('../endpoints/schemas');
const { initializeClient } = require('../endpoints/anthropic');
const {
handleAbort,
createAbortController,
handleAbortError,
setHeaders,
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
const { sendMessage, createOnProgress } = require('../../utils');
router.post('/abort', requireJwtAuth, handleAbort());
router.post(
'/',
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
generation,
endpointOption,
conversationId,
responseMessageId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('edit log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const userMessageId = parentMessageId;
const addMetadata = (data) => (metadata = data);
const getIds = (data) => (userMessage = data.userMessage);
const { onProgress: progressCallback, getPartialText } = createOnProgress({
generation,
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
unfinished: true,
cancelled: false,
error: false,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
try {
const getAbortData = () => ({
conversationId,
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
});
const { abortController, onStart } = createAbortController(
res,
req,
endpointOption,
getAbortData,
);
const { client } = initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
user: req.user.id,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId ?? userMessageId,
}),
getIds,
onStart,
addMetadata,
abortController,
});
if (metadata) {
response = { ...response, ...metadata };
}
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
// TODO: add anthropic titling
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId,
});
}
},
);
module.exports = router;

View file

@ -0,0 +1,185 @@
const express = require('express');
const router = express.Router();
const { getResponseSender } = require('../endpoints/schemas');
const { validateTools } = require('../../../app');
const { initializeClient } = require('../endpoints/gptPlugins');
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('../../utils');
const {
handleAbort,
createAbortController,
handleAbortError,
setHeaders,
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
router.post('/abort', requireJwtAuth, handleAbort());
router.post(
'/',
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
generation,
endpointOption,
conversationId,
responseMessageId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('edit log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const userMessageId = parentMessageId;
const user = req.user.id;
const plugin = {
loading: true,
inputs: [],
latest: null,
outputs: null,
};
const addMetadata = (data) => (metadata = data);
const getIds = (data) => (userMessage = data.userMessage);
const {
onProgress: progressCallback,
sendIntermediateMessage,
getPartialText,
} = createOnProgress({
generation,
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (plugin.loading === true) {
plugin.loading = false;
}
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
error: false,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
const onAgentAction = (action, start = false) => {
const formattedAction = formatAction(action);
plugin.inputs.push(formattedAction);
plugin.latest = formattedAction.plugin;
if (!start) {
saveMessage(userMessage);
}
sendIntermediateMessage(res, { plugin });
// console.log('PLUGIN ACTION', formattedAction);
};
const onChainEnd = (data) => {
let { intermediateSteps: steps } = data;
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
plugin.loading = false;
saveMessage(userMessage);
sendIntermediateMessage(res, { plugin });
// console.log('CHAIN END', plugin.outputs);
};
const getAbortData = () => ({
sender: getResponseSender(endpointOption),
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
plugin: { ...plugin, loading: false },
userMessage,
});
const { abortController, onStart } = createAbortController(
res,
req,
endpointOption,
getAbortData,
);
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
user,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
getIds,
onAgentAction,
onChainEnd,
onStart,
addMetadata,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
plugin,
parentMessageId: overrideParentMessageId || userMessageId,
}),
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
if (metadata) {
response = { ...response, ...metadata };
}
console.log('CLIENT RESPONSE');
console.dir(response, { depth: null });
response.plugin = { ...plugin, loading: false };
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId,
});
}
},
);
module.exports = router;

View file

@ -0,0 +1,13 @@
const express = require('express');
const router = express.Router();
const openAI = require('./openAI');
const gptPlugins = require('./gptPlugins');
const anthropic = require('./anthropic');
// const google = require('./google');
router.use(['/azureOpenAI', '/openAI'], openAI);
router.use('/gptPlugins', gptPlugins);
router.use('/anthropic', anthropic);
// router.use('/google', google);
module.exports = router;

View file

@ -0,0 +1,141 @@
const express = require('express');
const router = express.Router();
const { getResponseSender } = require('../endpoints/schemas');
const { initializeClient } = require('../endpoints/openAI');
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
const { sendMessage, createOnProgress } = require('../../utils');
const {
handleAbort,
createAbortController,
handleAbortError,
setHeaders,
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
router.post('/abort', requireJwtAuth, handleAbort());
router.post(
'/',
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
generation,
endpointOption,
conversationId,
responseMessageId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('edit log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const userMessageId = parentMessageId;
const addMetadata = (data) => (metadata = data);
const getIds = (data) => (userMessage = data.userMessage);
const { onProgress: progressCallback, getPartialText } = createOnProgress({
generation,
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
error: false,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
const getAbortData = () => ({
sender: getResponseSender(endpointOption),
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
});
const { abortController, onStart } = createAbortController(
res,
req,
endpointOption,
getAbortData,
);
try {
const { client } = initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
user: req.user.id,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
getIds,
onStart,
addMetadata,
abortController,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId || userMessageId,
}),
});
if (metadata) {
response = { ...response, ...metadata };
}
console.log(
'promptTokens, completionTokens:',
response.promptTokens,
response.completionTokens,
);
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId,
});
}
},
);
module.exports = router;

View file

@ -3,37 +3,45 @@ const express = require('express');
const router = express.Router();
const { availableTools } = require('../../app/clients/tools');
const { addOpenAPISpecs } = require('../../app/clients/tools/util/addOpenAPISpecs');
// const { getAzureCredentials, genAzureChatCompletion } = require('../../utils/');
const openAIApiKey = process.env.OPENAI_API_KEY;
const azureOpenAIApiKey = process.env.AZURE_API_KEY;
const useAzurePlugins = !!process.env.PLUGINS_USE_AZURE;
const userProvidedOpenAI = openAIApiKey
? openAIApiKey === 'user_provided'
: azureOpenAIApiKey === 'user_provided';
const fetchOpenAIModels = async (opts = { azure: false, plugins: false }, _models = []) => {
let models = _models.slice() ?? [];
let apiKey = openAIApiKey;
let basePath = 'https://api.openai.com/v1';
if (opts.azure) {
/* TODO: Add Azure models from api/models */
return models;
// const azure = getAzureCredentials();
// basePath = (genAzureChatCompletion(azure))
// .split('/deployments')[0]
// .concat(`/models?api-version=${azure.azureOpenAIApiVersion}`);
// apiKey = azureOpenAIApiKey;
}
let basePath = 'https://api.openai.com/v1/';
const reverseProxyUrl = process.env.OPENAI_REVERSE_PROXY;
if (reverseProxyUrl) {
basePath = reverseProxyUrl.match(/.*v1/)[0];
}
if (basePath.includes('v1')) {
if (basePath.includes('v1') || opts.azure) {
try {
const res = await axios.get(`${basePath}/models`, {
const res = await axios.get(`${basePath}${opts.azure ? '' : '/models'}`, {
headers: {
Authorization: `Bearer ${openAIApiKey}`,
Authorization: `Bearer ${apiKey}`,
},
});
models = res.data.data.map((item) => item.id);
// console.log(`Fetched ${models.length} models from ${opts.azure ? 'Azure ' : ''}OpenAI API`);
} catch (err) {
console.error(err);
console.log(`Failed to fetch models from ${opts.azure ? 'Azure ' : ''}OpenAI API`);
}
}
@ -149,7 +157,7 @@ router.get('/', async function (req, res) {
const gptPlugins =
openAIApiKey || azureOpenAIApiKey
? {
availableModels: await getOpenAIModels({ plugins: true }),
availableModels: await getOpenAIModels({ azure: useAzurePlugins, plugins: true }),
plugins,
availableAgents: ['classic', 'functions'],
userProvide: userProvidedOpenAI,

View file

@ -0,0 +1,15 @@
const buildOptions = (endpoint, parsedBody) => {
const { modelLabel, promptPrefix, ...rest } = parsedBody;
const endpointOption = {
endpoint,
modelLabel,
promptPrefix,
modelOptions: {
...rest,
},
};
return endpointOption;
};
module.exports = buildOptions;

View file

@ -0,0 +1,8 @@
const buildOptions = require('./buildOptions');
const initializeClient = require('./initializeClient');
module.exports = {
// addTitle, // todo
buildOptions,
initializeClient,
};

View file

@ -0,0 +1,12 @@
const { AnthropicClient } = require('../../../../app');
const initializeClient = (req) => {
let anthropicApiKey = req.body?.token ?? process.env.ANTHROPIC_API_KEY;
const client = new AnthropicClient(anthropicApiKey);
return {
client,
anthropicApiKey,
};
};
module.exports = initializeClient;

View file

@ -0,0 +1,31 @@
const buildOptions = (endpoint, parsedBody) => {
const {
chatGptLabel,
promptPrefix,
agentOptions,
tools,
model,
temperature,
top_p,
presence_penalty,
frequency_penalty,
} = parsedBody;
const endpointOption = {
endpoint,
tools: tools.map((tool) => tool.pluginKey) ?? [],
chatGptLabel,
promptPrefix,
agentOptions,
modelOptions: {
model,
temperature,
top_p,
presence_penalty,
frequency_penalty,
},
};
return endpointOption;
};
module.exports = buildOptions;

View file

@ -0,0 +1,7 @@
const buildOptions = require('./buildOptions');
const initializeClient = require('./initializeClient');
module.exports = {
buildOptions,
initializeClient,
};

View file

@ -0,0 +1,30 @@
const { PluginsClient } = require('../../../../app');
const { getAzureCredentials } = require('../../../../utils');
const initializeClient = (req, endpointOption) => {
const clientOptions = {
debug: true,
reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null,
proxy: process.env.PROXY || null,
...endpointOption,
};
let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY;
if (process.env.PLUGINS_USE_AZURE) {
clientOptions.azure = getAzureCredentials();
openAIApiKey = clientOptions.azure.azureOpenAIApiKey;
}
if (openAIApiKey && openAIApiKey.includes('azure') && !clientOptions.azure) {
clientOptions.azure = JSON.parse(req.body?.token) ?? getAzureCredentials();
openAIApiKey = clientOptions.azure.azureOpenAIApiKey;
}
const client = new PluginsClient(openAIApiKey, clientOptions);
return {
client,
azure: clientOptions.azure,
openAIApiKey,
};
};
module.exports = initializeClient;

View file

@ -0,0 +1,22 @@
const { titleConvo } = require('../../../../app');
const { saveConvo } = require('../../../../models');
const addTitle = async (
req,
{ text, azure, response, newConvo, parentMessageId, openAIApiKey },
) => {
if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) {
const title = await titleConvo({
text,
azure,
response,
openAIApiKey,
});
await saveConvo(req.user.id, {
conversationId: response.conversationId,
title,
});
}
};
module.exports = addTitle;

View file

@ -0,0 +1,15 @@
const buildOptions = (endpoint, parsedBody) => {
const { chatGptLabel, promptPrefix, ...rest } = parsedBody;
const endpointOption = {
endpoint,
chatGptLabel,
promptPrefix,
modelOptions: {
...rest,
},
};
return endpointOption;
};
module.exports = buildOptions;

View file

@ -0,0 +1,9 @@
const addTitle = require('./addTitle');
const buildOptions = require('./buildOptions');
const initializeClient = require('./initializeClient');
module.exports = {
addTitle,
buildOptions,
initializeClient,
};

View file

@ -0,0 +1,27 @@
const { OpenAIClient } = require('../../../../app');
const { getAzureCredentials } = require('../../../../utils');
const initializeClient = (req, endpointOption) => {
const clientOptions = {
// debug: true,
// contextStrategy: 'refine',
reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null,
proxy: process.env.PROXY || null,
...endpointOption,
};
let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY;
if (process.env.AZURE_API_KEY && endpointOption.endpoint === 'azureOpenAI') {
clientOptions.azure = JSON.parse(req.body?.token) ?? getAzureCredentials();
openAIApiKey = clientOptions.azure.azureOpenAIApiKey;
}
const client = new OpenAIClient(openAIApiKey, clientOptions);
return {
client,
openAIApiKey,
};
};
module.exports = initializeClient;

View file

@ -0,0 +1,369 @@
const { z } = require('zod');
const EModelEndpoint = {
azureOpenAI: 'azureOpenAI',
openAI: 'openAI',
bingAI: 'bingAI',
chatGPTBrowser: 'chatGPTBrowser',
google: 'google',
gptPlugins: 'gptPlugins',
anthropic: 'anthropic',
};
const eModelEndpointSchema = z.nativeEnum(EModelEndpoint);
/*
const tMessageSchema = z.object({
messageId: z.string(),
clientId: z.string().nullable().optional(),
conversationId: z.string().nullable(),
parentMessageId: z.string().nullable(),
sender: z.string(),
text: z.string(),
isCreatedByUser: z.boolean(),
error: z.boolean(),
createdAt: z
.string()
.optional()
.default(() => new Date().toISOString()),
updatedAt: z
.string()
.optional()
.default(() => new Date().toISOString()),
current: z.boolean().optional(),
unfinished: z.boolean().optional(),
submitting: z.boolean().optional(),
searchResult: z.boolean().optional(),
finish_reason: z.string().optional(),
});
const tPresetSchema = tConversationSchema
.omit({
conversationId: true,
createdAt: true,
updatedAt: true,
title: true,
})
.merge(
z.object({
conversationId: z.string().optional(),
presetId: z.string().nullable().optional(),
title: z.string().nullable().optional(),
}),
);
*/
const tPluginAuthConfigSchema = z.object({
authField: z.string(),
label: z.string(),
description: z.string(),
});
const tPluginSchema = z.object({
name: z.string(),
pluginKey: z.string(),
description: z.string(),
icon: z.string(),
authConfig: z.array(tPluginAuthConfigSchema),
authenticated: z.boolean().optional(),
isButton: z.boolean().optional(),
});
const tExampleSchema = z.object({
input: z.object({
content: z.string(),
}),
output: z.object({
content: z.string(),
}),
});
const tAgentOptionsSchema = z.object({
agent: z.string(),
skipCompletion: z.boolean(),
model: z.string(),
temperature: z.number(),
});
const tConversationSchema = z.object({
conversationId: z.string().nullable(),
title: z.string(),
user: z.string().optional(),
endpoint: eModelEndpointSchema.nullable(),
suggestions: z.array(z.string()).optional(),
messages: z.array(z.string()).optional(),
tools: z.array(tPluginSchema).optional(),
createdAt: z.string(),
updatedAt: z.string(),
systemMessage: z.string().nullable().optional(),
modelLabel: z.string().nullable().optional(),
examples: z.array(tExampleSchema).optional(),
chatGptLabel: z.string().nullable().optional(),
userLabel: z.string().optional(),
model: z.string().nullable().optional(),
promptPrefix: z.string().nullable().optional(),
temperature: z.number().optional(),
topP: z.number().optional(),
topK: z.number().optional(),
context: z.string().nullable().optional(),
top_p: z.number().optional(),
frequency_penalty: z.number().optional(),
presence_penalty: z.number().optional(),
jailbreak: z.boolean().optional(),
jailbreakConversationId: z.string().nullable().optional(),
conversationSignature: z.string().nullable().optional(),
parentMessageId: z.string().optional(),
clientId: z.string().nullable().optional(),
invocationId: z.number().nullable().optional(),
toneStyle: z.string().nullable().optional(),
maxOutputTokens: z.number().optional(),
agentOptions: tAgentOptionsSchema.nullable().optional(),
});
const openAISchema = tConversationSchema
.pick({
model: true,
chatGptLabel: true,
promptPrefix: true,
temperature: true,
top_p: true,
presence_penalty: true,
frequency_penalty: true,
})
.transform((obj) => ({
...obj,
model: obj.model ?? 'gpt-3.5-turbo',
chatGptLabel: obj.chatGptLabel ?? null,
promptPrefix: obj.promptPrefix ?? null,
temperature: obj.temperature ?? 1,
top_p: obj.top_p ?? 1,
presence_penalty: obj.presence_penalty ?? 0,
frequency_penalty: obj.frequency_penalty ?? 0,
}))
.catch(() => ({
model: 'gpt-3.5-turbo',
chatGptLabel: null,
promptPrefix: null,
temperature: 1,
top_p: 1,
presence_penalty: 0,
frequency_penalty: 0,
}));
const googleSchema = tConversationSchema
.pick({
model: true,
modelLabel: true,
promptPrefix: true,
examples: true,
temperature: true,
maxOutputTokens: true,
topP: true,
topK: true,
})
.transform((obj) => ({
...obj,
model: obj.model ?? 'chat-bison',
modelLabel: obj.modelLabel ?? null,
promptPrefix: obj.promptPrefix ?? null,
temperature: obj.temperature ?? 0.2,
maxOutputTokens: obj.maxOutputTokens ?? 1024,
topP: obj.topP ?? 0.95,
topK: obj.topK ?? 40,
}))
.catch(() => ({
model: 'chat-bison',
modelLabel: null,
promptPrefix: null,
temperature: 0.2,
maxOutputTokens: 1024,
topP: 0.95,
topK: 40,
}));
const bingAISchema = tConversationSchema
.pick({
jailbreak: true,
systemMessage: true,
context: true,
toneStyle: true,
jailbreakConversationId: true,
conversationSignature: true,
clientId: true,
invocationId: true,
})
.transform((obj) => ({
...obj,
model: '',
jailbreak: obj.jailbreak ?? false,
systemMessage: obj.systemMessage ?? null,
context: obj.context ?? null,
toneStyle: obj.toneStyle ?? 'creative',
jailbreakConversationId: obj.jailbreakConversationId ?? null,
conversationSignature: obj.conversationSignature ?? null,
clientId: obj.clientId ?? null,
invocationId: obj.invocationId ?? 1,
}))
.catch(() => ({
model: '',
jailbreak: false,
systemMessage: null,
context: null,
toneStyle: 'creative',
jailbreakConversationId: null,
conversationSignature: null,
clientId: null,
invocationId: 1,
}));
const anthropicSchema = tConversationSchema
.pick({
model: true,
modelLabel: true,
promptPrefix: true,
temperature: true,
maxOutputTokens: true,
topP: true,
topK: true,
})
.transform((obj) => ({
...obj,
model: obj.model ?? 'claude-1',
modelLabel: obj.modelLabel ?? null,
promptPrefix: obj.promptPrefix ?? null,
temperature: obj.temperature ?? 1,
maxOutputTokens: obj.maxOutputTokens ?? 1024,
topP: obj.topP ?? 0.7,
topK: obj.topK ?? 5,
}))
.catch(() => ({
model: 'claude-1',
modelLabel: null,
promptPrefix: null,
temperature: 1,
maxOutputTokens: 1024,
topP: 0.7,
topK: 5,
}));
const chatGPTBrowserSchema = tConversationSchema
.pick({
model: true,
})
.transform((obj) => ({
...obj,
model: obj.model ?? 'text-davinci-002-render-sha',
}))
.catch(() => ({
model: 'text-davinci-002-render-sha',
}));
const gptPluginsSchema = tConversationSchema
.pick({
model: true,
chatGptLabel: true,
promptPrefix: true,
temperature: true,
top_p: true,
presence_penalty: true,
frequency_penalty: true,
tools: true,
agentOptions: true,
})
.transform((obj) => ({
...obj,
model: obj.model ?? 'gpt-3.5-turbo',
chatGptLabel: obj.chatGptLabel ?? null,
promptPrefix: obj.promptPrefix ?? null,
temperature: obj.temperature ?? 0.8,
top_p: obj.top_p ?? 1,
presence_penalty: obj.presence_penalty ?? 0,
frequency_penalty: obj.frequency_penalty ?? 0,
tools: obj.tools ?? [],
agentOptions: obj.agentOptions ?? {
agent: 'functions',
skipCompletion: true,
model: 'gpt-3.5-turbo',
temperature: 0,
},
}))
.catch(() => ({
model: 'gpt-3.5-turbo',
chatGptLabel: null,
promptPrefix: null,
temperature: 0.8,
top_p: 1,
presence_penalty: 0,
frequency_penalty: 0,
tools: [],
agentOptions: {
agent: 'functions',
skipCompletion: true,
model: 'gpt-3.5-turbo',
temperature: 0,
},
}));
const endpointSchemas = {
openAI: openAISchema,
azureOpenAI: openAISchema,
google: googleSchema,
bingAI: bingAISchema,
anthropic: anthropicSchema,
chatGPTBrowser: chatGPTBrowserSchema,
gptPlugins: gptPluginsSchema,
};
function getFirstDefinedValue(possibleValues) {
let returnValue;
for (const value of possibleValues) {
if (value) {
returnValue = value;
break;
}
}
return returnValue;
}
const parseConvo = (endpoint, conversation, possibleValues) => {
const schema = endpointSchemas[endpoint];
if (!schema) {
throw new Error(`Unknown endpoint: ${endpoint}`);
}
const convo = schema.parse(conversation);
if (possibleValues && convo) {
convo.model = getFirstDefinedValue(possibleValues.model) ?? convo.model;
}
return convo;
};
const getResponseSender = (endpointOption) => {
const { endpoint, chatGptLabel, modelLabel, jailbreak } = endpointOption;
if (['openAI', 'azureOpenAI', 'gptPlugins', 'chatGPTBrowser'].includes(endpoint)) {
return chatGptLabel ?? 'ChatGPT';
}
if (endpoint === 'bingAI') {
return jailbreak ? 'Sydney' : 'BingAI';
}
if (endpoint === 'anthropic') {
return modelLabel ?? 'Anthropic';
}
if (endpoint === 'google') {
return modelLabel ?? 'PaLM2';
}
return '';
};
module.exports = {
parseConvo,
getResponseSender,
};

View file

@ -1,4 +1,5 @@
const ask = require('./ask');
const edit = require('./edit');
const messages = require('./messages');
const convos = require('./convos');
const presets = require('./presets');
@ -15,6 +16,7 @@ const config = require('./config');
module.exports = {
search,
ask,
edit,
messages,
convos,
presets,

View file

@ -1,7 +1,7 @@
const express = require('express');
const router = express.Router();
const { getMessages } = require('../../models/Message');
const requireJwtAuth = require('../../middleware/requireJwtAuth');
const requireJwtAuth = require('../middleware/requireJwtAuth');
router.get('/:conversationId', requireJwtAuth, async (req, res) => {
const { conversationId } = req.params;

View file

@ -1,6 +1,6 @@
const express = require('express');
const { getAvailablePluginsController } = require('../controllers/PluginController');
const requireJwtAuth = require('../../middleware/requireJwtAuth');
const requireJwtAuth = require('../middleware/requireJwtAuth');
const router = express.Router();

View file

@ -2,7 +2,7 @@ const express = require('express');
const router = express.Router();
const { getPresets, savePreset, deletePresets } = require('../../models');
const crypto = require('crypto');
const requireJwtAuth = require('../../middleware/requireJwtAuth');
const requireJwtAuth = require('../middleware/requireJwtAuth');
router.get('/', requireJwtAuth, async (req, res) => {
const presets = (await getPresets(req.user.id)).map((preset) => {

View file

@ -5,7 +5,7 @@ const { Message } = require('../../models/Message');
const { Conversation, getConvosQueried } = require('../../models/Conversation');
const { reduceHits } = require('../../lib/utils/reduceHits');
const { cleanUpPrimaryKeyValue } = require('../../lib/utils/misc');
const requireJwtAuth = require('../../middleware/requireJwtAuth');
const requireJwtAuth = require('../middleware/requireJwtAuth');
const cache = new Map();

View file

@ -4,7 +4,7 @@ const { Tiktoken } = require('@dqbd/tiktoken/lite');
const { load } = require('@dqbd/tiktoken/load');
const registry = require('@dqbd/tiktoken/registry.json');
const models = require('@dqbd/tiktoken/model_to_encoding.json');
const requireJwtAuth = require('../../middleware/requireJwtAuth');
const requireJwtAuth = require('../middleware/requireJwtAuth');
router.post('/', requireJwtAuth, async (req, res) => {
try {

View file

@ -1,5 +1,5 @@
const express = require('express');
const requireJwtAuth = require('../../middleware/requireJwtAuth');
const requireJwtAuth = require('../middleware/requireJwtAuth');
const { getUserController, updateUserPluginsController } = require('../controllers/UserController');
const router = express.Router();

View file

@ -1,10 +1,10 @@
const User = require('../../models/User');
const Token = require('../../models/schema/tokenSchema');
const crypto = require('crypto');
const bcrypt = require('bcryptjs');
const User = require('../../models/User');
const Token = require('../../models/schema/tokenSchema');
const { registerSchema } = require('../../strategies/validators');
const { sendEmail } = require('../../utils');
const config = require('../../../config/loader');
const { sendEmail } = require('../utils');
const domains = config.domains;
/**

View file

@ -1,5 +1,5 @@
const PluginAuth = require('../../models/schema/pluginAuthSchema');
const { encrypt, decrypt } = require('../../utils/');
const { encrypt, decrypt } = require('../utils/');
const getUserPluginAuthValue = async (user, authField) => {
try {

View file

@ -1,4 +1,19 @@
const citationRegex = /\[\^\d+?\^\]/g;
const regex = / \[.*?]\(.*?\)/g;
const getCitations = (res) => {
const adaptiveCards = res.details.adaptiveCards;
const textBlocks = adaptiveCards && adaptiveCards[0].body;
if (!textBlocks) {
return '';
}
let links = textBlocks[textBlocks.length - 1]?.text.match(regex);
if (links?.length === 0 || !links) {
return '';
}
links = links.map((link) => link.trim());
return links.join('\n - ');
};
const citeText = (res, noLinks = false) => {
let result = res.text || res;
@ -32,4 +47,4 @@ const citeText = (res, noLinks = false) => {
return result;
};
module.exports = citeText;
module.exports = { getCitations, citeText };

View file

@ -1,8 +1,10 @@
const _ = require('lodash');
const citationRegex = /\[\^\d+?\^]/g;
const { getCitations, citeText } = require('../../../app');
const { getCitations, citeText } = require('./citations');
const cursor = '<span className="result-streaming">█</span>';
const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? text + ' ' : text);
const handleError = (res, message) => {
res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`);
res.end();
@ -15,12 +17,12 @@ const sendMessage = (res, message, event = 'message') => {
res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`);
};
const createOnProgress = ({ onProgress: _onProgress }) => {
const createOnProgress = ({ generation = '', onProgress: _onProgress }) => {
let i = 0;
let code = '';
let tokens = '';
let precode = '';
let codeBlock = false;
let tokens = addSpaceIfNeeded(generation);
const progressCallback = async (partial, { res, text, plugin, bing = false, ...rest }) => {
let chunk = partial === text ? '' : partial;
@ -155,4 +157,5 @@ module.exports = {
handleText,
formatSteps,
formatAction,
addSpaceIfNeeded,
};

11
api/server/utils/index.js Normal file
View file

@ -0,0 +1,11 @@
const cryptoUtils = require('./crypto');
const handleText = require('./handleText');
const citations = require('./citations');
const sendEmail = require('./sendEmail');
module.exports = {
...cryptoUtils,
...handleText,
...citations,
sendEmail,
};

View file

@ -1,18 +0,0 @@
async function abortMessage(req, res, abortControllers) {
const { abortKey } = req.body;
console.log('req.body', req.body);
if (!abortControllers.has(abortKey)) {
return res.status(404).send('Request not found');
}
const { abortController } = abortControllers.get(abortKey);
abortControllers.delete(abortKey);
const ret = await abortController.abortAsk();
console.log('Aborted request', abortKey);
console.log('Aborted message:', ret);
res.send(JSON.stringify(ret));
}
module.exports = abortMessage;

View file

@ -1,16 +1,10 @@
const azureUtils = require('./azureUtils');
const cryptoUtils = require('./crypto');
const { tiktokenModels, maxTokensMap } = require('./tokens');
const sendEmail = require('./sendEmail');
const abortMessage = require('./abortMessage');
const findMessageContent = require('./findMessageContent');
module.exports = {
...cryptoUtils,
...azureUtils,
maxTokensMap,
tiktokenModels,
sendEmail,
abortMessage,
findMessageContent,
};

View file

@ -48,3 +48,17 @@ export type TSetOptionsPayload = {
checkPluginSelection: (value: string) => boolean;
setTools: (newValue: string) => void;
};
export type TPresetItemProps = {
preset: TPreset;
value: TPreset;
onSelect: (preset: TPreset) => void;
onChangePreset: (preset: TPreset) => void;
onDeletePreset: (preset: TPreset) => void;
};
export type TOnClick = (e: React.MouseEvent<HTMLButtonElement>) => void;
export type TGenButtonProps = {
onClick: TOnClick;
};

View file

@ -103,10 +103,18 @@ export default function NewConversationMenu() {
};
// set the current model
const isModular = modularEndpoints.has(endpoint);
const onSelectPreset = (newPreset) => {
setMenuOpen(false);
if (!newPreset) {
return;
}
if (modularEndpoints.has(endpoint) && modularEndpoints.has(newPreset?.endpoint)) {
if (
isModular &&
modularEndpoints.has(newPreset?.endpoint) &&
endpoint === newPreset?.endpoint
) {
const currentConvo = getDefaultConversation({
conversation,
endpointsConfig,
@ -118,10 +126,6 @@ export default function NewConversationMenu() {
return;
}
if (!newPreset) {
return;
}
newConversation({}, newPreset);
};

View file

@ -1,7 +1,14 @@
import type { TPresetItemProps } from '~/common';
import type { TPreset } from 'librechat-data-provider';
import { DropdownMenuRadioItem, EditIcon, TrashIcon } from '~/components';
import { getIcon } from '~/components/Endpoints';
export default function PresetItem({ preset = {}, value, onChangePreset, onDeletePreset }) {
export default function PresetItem({
preset = {} as TPreset,
value,
onChangePreset,
onDeletePreset,
}: TPresetItemProps) {
const { endpoint } = preset;
const icon = getIcon({
@ -14,9 +21,9 @@ export default function PresetItem({ preset = {}, value, onChangePreset, onDelet
const getPresetTitle = () => {
let _title = `${endpoint}`;
const { chatGptLabel, modelLabel, model, jailbreak, toneStyle } = preset;
if (endpoint === 'azureOpenAI' || endpoint === 'openAI') {
const { chatGptLabel, model } = preset;
if (model) {
_title += `: ${model}`;
}
@ -24,7 +31,6 @@ export default function PresetItem({ preset = {}, value, onChangePreset, onDelet
_title += ` as ${chatGptLabel}`;
}
} else if (endpoint === 'google') {
const { modelLabel, model } = preset;
if (model) {
_title += `: ${model}`;
}
@ -32,7 +38,6 @@ export default function PresetItem({ preset = {}, value, onChangePreset, onDelet
_title += ` as ${modelLabel}`;
}
} else if (endpoint === 'bingAI') {
const { jailbreak, toneStyle } = preset;
if (toneStyle) {
_title += `: ${toneStyle}`;
}
@ -40,12 +45,10 @@ export default function PresetItem({ preset = {}, value, onChangePreset, onDelet
_title += ' as Sydney';
}
} else if (endpoint === 'chatGPTBrowser') {
const { model } = preset;
if (model) {
_title += `: ${model}`;
}
} else if (endpoint === 'gptPlugins') {
const { model } = preset;
if (model) {
_title += `: ${model}`;
}
@ -60,6 +63,7 @@ export default function PresetItem({ preset = {}, value, onChangePreset, onDelet
// regular model
return (
<DropdownMenuRadioItem
/* @ts-ignore, value can be an object as well */
value={value}
className="group flex h-10 max-h-[44px] flex-row justify-between dark:font-semibold dark:text-gray-100 dark:hover:bg-gray-800 sm:h-auto"
>

View file

@ -1,10 +1,11 @@
import React from 'react';
import PresetItem from './PresetItem';
import type { TPreset } from 'librechat-data-provider';
export default function PresetItems({ presets, onSelect, onChangePreset, onDeletePreset }) {
return (
<>
{presets.map((preset) => (
{presets.map((preset: TPreset) => (
<PresetItem
key={preset?.presetId ?? Math.random()}
value={preset}

View file

@ -0,0 +1,27 @@
import { cn, removeFocusOutlines } from '~/utils/';
export default function Button({
type = 'regenerate',
children,
onClick,
className = '',
}: {
type?: 'regenerate' | 'continue' | 'stop';
children: React.ReactNode;
onClick: (e: React.MouseEvent<HTMLButtonElement>) => void;
className?: string;
}) {
return (
<button
data-testid={`${type}-generation-button`}
className={cn(
'custom-btn btn-neutral relative -z-0 whitespace-nowrap border-0 md:border',
removeFocusOutlines,
className,
)}
onClick={onClick}
>
<div className="flex w-full items-center justify-center gap-2">{children}</div>
</button>
);
}

View file

@ -0,0 +1,12 @@
import type { TGenButtonProps } from '~/common';
import { ContinueIcon } from '~/components/svg';
import Button from './Button';
export default function Continue({ onClick }: TGenButtonProps) {
return (
<Button type="continue" onClick={onClick}>
<ContinueIcon className="text-gray-600/90" />
Continue
</Button>
);
}

View file

@ -0,0 +1,61 @@
import type { TMessage } from 'librechat-data-provider';
import { useMessageHandler, useMediaQuery, useGenerations } from '~/hooks';
import { cn } from '~/utils';
import Regenerate from './Regenerate';
import Continue from './Continue';
import Stop from './Stop';
type GenerationButtonsProps = {
endpoint: string;
showPopover: boolean;
opacityClass: string;
};
export default function GenerationButtons({
endpoint,
showPopover,
opacityClass,
}: GenerationButtonsProps) {
const {
messages,
isSubmitting,
latestMessage,
handleContinue,
handleRegenerate,
handleStopGenerating,
} = useMessageHandler();
const isSmallScreen = useMediaQuery('(max-width: 768px)');
const { continueSupported, regenerateEnabled } = useGenerations({
endpoint,
message: latestMessage as TMessage,
isSubmitting,
});
if (isSmallScreen) {
return null;
}
let button: React.ReactNode = null;
if (isSubmitting) {
button = <Stop onClick={handleStopGenerating} />;
} else if (continueSupported) {
button = <Continue onClick={handleContinue} />;
} else if (messages && messages.length > 0 && regenerateEnabled) {
button = <Regenerate onClick={handleRegenerate} />;
}
return (
<div className="absolute bottom-4 right-0 z-[62]">
<div className="grow" />
<div className="flex items-center md:items-end">
<div
className={cn('option-buttons', showPopover ? '' : opacityClass)}
data-projection-id="173"
>
{button}
</div>
</div>
</div>
);
}

View file

@ -0,0 +1,12 @@
import type { TGenButtonProps } from '~/common';
import { RegenerateIcon } from '~/components/svg';
import Button from './Button';
export default function Regenerate({ onClick }: TGenButtonProps) {
return (
<Button onClick={onClick}>
<RegenerateIcon className="h-3 w-3 flex-shrink-0 text-gray-600/90" />
Regenerate
</Button>
);
}

View file

@ -0,0 +1,12 @@
import type { TGenButtonProps } from '~/common';
import { StopGeneratingIcon } from '~/components/svg';
import Button from './Button';
export default function Stop({ onClick }: TGenButtonProps) {
return (
<Button type="stop" onClick={onClick}>
<StopGeneratingIcon className="text-gray-600/90" />
Stop
</Button>
);
}

View file

@ -0,0 +1 @@
export { default as GenerationButtons } from './GenerationButtons';

View file

@ -12,7 +12,7 @@ import { Button } from '~/components/ui';
import { cn, cardStyle } from '~/utils/';
import { useSetOptions } from '~/hooks';
import { ModelSelect } from './ModelSelect';
import GenerationButtons from './GenerationButtons';
import { GenerationButtons } from './Generations';
import store from '~/store';
export default function OptionsBar() {
@ -76,7 +76,11 @@ export default function OptionsBar() {
: () => setShowPopover((prev) => !prev);
return (
<div className="relative py-2 last:mb-2 md:mx-4 md:mb-[-16px] md:py-4 md:pt-2 md:last:mb-6 lg:mx-auto lg:mb-[-32px] lg:max-w-2xl lg:pt-6 xl:max-w-3xl">
<GenerationButtons showPopover={showPopover} opacityClass={opacityClass} />
<GenerationButtons
endpoint={endpoint}
showPopover={showPopover}
opacityClass={opacityClass}
/>
<span className="flex w-full flex-col items-center justify-center gap-0 md:order-none md:m-auto md:gap-2">
<div
className={cn(

View file

@ -1,16 +0,0 @@
import React from 'react';
export default function RowButton({ onClick, children, text, className }) {
return (
<button
onClick={onClick}
className={`input-panel-button btn btn-neutral flex justify-center gap-2 border-0 md:border ${className}`}
type="button"
>
{children}
<span className="hidden md:block">{text}</span>
{/* <RegenerateIcon />
<span className="hidden md:block">Regenerate response</span> */}
</button>
);
}

View file

@ -10,22 +10,18 @@ import { cn } from '~/utils';
import store from '~/store';
export default function TextChat({ isSearchView = false }) {
const inputRef = useRef(null);
const isComposing = useRef(false);
const { ask, isSubmitting, handleStopGenerating, latestMessage, endpointsConfig } =
useMessageHandler();
const conversation = useRecoilValue(store.conversation);
const setShowBingToneSetting = useSetRecoilState(store.showBingToneSetting);
const [text, setText] = useRecoilState(store.text);
const { theme } = useContext(ThemeContext);
const conversation = useRecoilValue(store.conversation);
const latestMessage = useRecoilValue(store.latestMessage);
const endpointsConfig = useRecoilValue(store.endpointsConfig);
const isSubmitting = useRecoilValue(store.isSubmitting);
const setShowBingToneSetting = useSetRecoilState(store.showBingToneSetting);
const isComposing = useRef(false);
const inputRef = useRef(null);
// TODO: do we need this?
const disabled = false;
const { ask, stopGenerating } = useMessageHandler();
const isNotAppendable = latestMessage?.unfinished & !isSubmitting || latestMessage?.error;
const { conversationId, jailbreak } = conversation || {};
@ -60,11 +56,6 @@ export default function TextChat({ isSearchView = false }) {
setText('');
};
const handleStopGenerating = (e) => {
e.preventDefault();
stopGenerating();
};
const handleKeyDown = (e) => {
if (e.key === 'Enter' && isSubmitting) {
return;

View file

@ -1,87 +0,0 @@
import React from 'react';
import { cn } from '~/utils/';
import Clipboard from '../svg/Clipboard';
import CheckMark from '../svg/CheckMark';
import EditIcon from '../svg/EditIcon';
import RegenerateIcon from '../svg/RegenerateIcon';
export default function HoverButtons({
isEditting,
enterEdit,
copyToClipboard,
conversation,
isSubmitting,
message,
regenerate,
}) {
const { endpoint } = conversation;
const [isCopied, setIsCopied] = React.useState(false);
const branchingSupported =
// azureOpenAI, openAI, chatGPTBrowser support branching, so edit enabled // 5/21/23: Bing is allowing editing and Message regenerating
!![
'azureOpenAI',
'openAI',
'chatGPTBrowser',
'google',
'bingAI',
'gptPlugins',
'anthropic',
].find((e) => e === endpoint);
// Sydney in bingAI supports branching, so edit enabled
const editEnabled =
!message?.error &&
message?.isCreatedByUser &&
!message?.searchResult &&
!isEditting &&
branchingSupported;
// for now, once branching is supported, regerate will be enabled
let regenerateEnabled =
// !message?.error &&
!message?.isCreatedByUser &&
!message?.searchResult &&
!isEditting &&
!isSubmitting &&
branchingSupported;
return (
<div className="visible mt-2 flex justify-center gap-3 self-end text-gray-400 md:gap-4 lg:absolute lg:right-0 lg:top-0 lg:mt-0 lg:translate-x-full lg:gap-1 lg:self-center lg:pl-2">
{editEnabled ? (
<button
className="hover-button rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible"
onClick={enterEdit}
type="button"
title="edit"
>
{/* <button className="rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400"> */}
<EditIcon />
</button>
) : null}
{regenerateEnabled ? (
<button
className="hover-button active rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible"
onClick={regenerate}
type="button"
title="regenerate"
>
{/* <button className="rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400"> */}
<RegenerateIcon />
</button>
) : null}
<button
className={cn(
'hover-button rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible',
message?.isCreatedByUser ? '' : 'active',
)}
onClick={() => copyToClipboard(setIsCopied)}
type="button"
title={isCopied ? 'Copied to clipboard' : 'Copy to clipboard'}
>
{isCopied ? <CheckMark /> : <Clipboard />}
</button>
</div>
);
}

View file

@ -0,0 +1,86 @@
import { useState } from 'react';
import type { TConversation, TMessage } from 'librechat-data-provider';
import { Clipboard, CheckMark, EditIcon, RegenerateIcon, ContinueIcon } from '~/components/svg';
import { useGenerations } from '~/hooks';
import { cn } from '~/utils';
type THoverButtons = {
isEditing: boolean;
enterEdit: () => void;
copyToClipboard: (setIsCopied: (isCopied: boolean) => void) => void;
conversation: TConversation;
isSubmitting: boolean;
message: TMessage;
regenerate: () => void;
handleContinue: (e: React.MouseEvent<HTMLButtonElement>) => void;
};
export default function HoverButtons({
isEditing,
enterEdit,
copyToClipboard,
conversation,
isSubmitting,
message,
regenerate,
handleContinue,
}: THoverButtons) {
const { endpoint } = conversation;
const [isCopied, setIsCopied] = useState(false);
const { editEnabled, regenerateEnabled, continueSupported } = useGenerations({
isEditing,
isSubmitting,
message,
endpoint: endpoint ?? '',
});
return (
<div className="visible mt-2 flex justify-center gap-3 self-end text-gray-400 md:gap-4 lg:absolute lg:right-0 lg:top-0 lg:mt-0 lg:translate-x-full lg:gap-1 lg:self-center lg:pl-2">
<button
className={cn(
'hover-button rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible',
message?.isCreatedByUser ? '' : 'opacity-0',
)}
onClick={enterEdit}
type="button"
title="edit"
disabled={!editEnabled}
>
{/* <button className="rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400"> */}
<EditIcon />
</button>
<button
className={cn(
'hover-button rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible',
message?.isCreatedByUser ? '' : 'active',
)}
onClick={() => copyToClipboard(setIsCopied)}
type="button"
title={isCopied ? 'Copied to clipboard' : 'Copy to clipboard'}
>
{isCopied ? <CheckMark /> : <Clipboard />}
</button>
{regenerateEnabled ? (
<button
className="hover-button active rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible"
onClick={regenerate}
type="button"
title="regenerate"
>
{/* <button className="rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400"> */}
<RegenerateIcon />
</button>
) : null}
{continueSupported ? (
<button
className="hover-button active rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible "
onClick={handleContinue}
type="button"
title="continue"
>
<ContinueIcon className="h-4 w-4" />
</button>
) : null}
</div>
);
}

View file

@ -1,6 +1,6 @@
/* eslint-disable react-hooks/exhaustive-deps */
import { useState, useEffect, useRef } from 'react';
import { useRecoilValue, useSetRecoilState } from 'recoil';
import { useSetRecoilState } from 'recoil';
import copy from 'copy-to-clipboard';
import Plugin from './Plugin';
import SubRow from './Content/SubRow';
@ -25,13 +25,12 @@ export default function Message({
setSiblingIdx,
}) {
const { text, searchResult, isCreatedByUser, error, submitting, unfinished } = message;
const isSubmitting = useRecoilValue(store.isSubmitting);
const setLatestMessage = useSetRecoilState(store.latestMessage);
const [abortScroll, setAbort] = useState(false);
const textEditor = useRef(null);
const last = !message?.children?.length;
const edit = message.messageId == currentEditId;
const { ask, regenerate } = useMessageHandler();
const { isSubmitting, ask, regenerate, handleContinue } = useMessageHandler();
const { switchToConversation } = store.useConversation();
const blinker = submitting && isSubmitting;
const getConversationQuery = useGetConversationByIdQuery(message.conversationId, {
@ -223,12 +222,13 @@ export default function Message({
)}
</div>
<HoverButtons
isEditting={edit}
isEditing={edit}
isSubmitting={isSubmitting}
message={message}
conversation={conversation}
enterEdit={() => enterEdit()}
regenerate={() => regenerateMessage()}
handleContinue={handleContinue}
copyToClipboard={copyToClipboard}
/>
<SubRow subclasses="switch-container">

View file

@ -1,10 +1,12 @@
import React from 'react';
type Props = {
scrollHandler: React.MouseEventHandler<HTMLButtonElement>;
};
export default function ScrollToBottom({ scrollHandler }) {
export default function ScrollToBottom({ scrollHandler }: Props) {
return (
<button
onClick={scrollHandler}
className="absolute bottom-[124px] right-6 z-[62] cursor-pointer rounded-full border border-gray-200 bg-gray-50 text-gray-600 dark:border-white/10 dark:bg-white/10 dark:text-gray-200 md:bottom-[120px]"
className="absolute bottom-[124px] right-6 z-[62] cursor-pointer rounded-full border border-gray-200 bg-gray-50 text-gray-600 dark:border-white/10 dark:bg-white/10 dark:text-gray-200 md:bottom-[180px] lg:bottom-[120px]"
>
<svg
stroke="currentColor"

View file

@ -9,7 +9,7 @@ export default function Clipboard() {
viewBox="0 0 24 24"
strokeLinecap="round"
strokeLinejoin="round"
className="h-4 w-4"
className="h-4 w-4 text-gray-600 dark:text-gray-400"
height="1em"
width="1em"
xmlns="http://www.w3.org/2000/svg"

View file

@ -0,0 +1,21 @@
import { cn } from '~/utils';
export default function ContinueIcon({ className = '' }: { className?: string }) {
return (
<svg
stroke="currentColor"
fill="none"
strokeWidth="2"
viewBox="0 0 24 24"
strokeLinecap="round"
strokeLinejoin="round"
className={cn('h-3 w-3 -rotate-180 text-gray-600 dark:text-gray-400', className)}
height="1em"
width="1em"
xmlns="http://www.w3.org/2000/svg"
>
<polygon points="11 19 2 12 11 5 11 19" />
<polygon points="22 19 13 12 22 5 22 19" />
</svg>
);
}

View file

@ -1,6 +1,6 @@
import React from 'react';
import { cn } from '~/utils';
export default function Regenerate() {
export default function RegenerateIcon({ className = '' }: { className?: string }) {
return (
<svg
stroke="currentColor"
@ -9,7 +9,7 @@ export default function Regenerate() {
viewBox="0 0 24 24"
strokeLinecap="round"
strokeLinejoin="round"
className="h-4 w-4"
className={cn('h-4 w-4 text-gray-600 dark:text-gray-400', className)}
height="1em"
width="1em"
xmlns="http://www.w3.org/2000/svg"

View file

@ -1,6 +1,6 @@
import React from 'react';
import { cn } from '~/utils';
export default function StopGeneratingIcon() {
export default function StopGeneratingIcon({ className = '' }: { className?: string }) {
return (
<svg
stroke="currentColor"
@ -9,7 +9,7 @@ export default function StopGeneratingIcon() {
viewBox="0 0 24 24"
strokeLinecap="round"
strokeLinejoin="round"
className="h-3 w-3"
className={cn('h-3 w-3 text-gray-600 dark:text-gray-400', className)}
height="1em"
width="1em"
xmlns="http://www.w3.org/2000/svg"

View file

@ -10,6 +10,8 @@ export { default as CrossIcon } from './CrossIcon';
export { default as LogOutIcon } from './LogOutIcon';
export { default as MessagesSquared } from './MessagesSquared';
export { default as StopGeneratingIcon } from './StopGeneratingIcon';
export { default as RegenerateIcon } from './RegenerateIcon';
export { default as ContinueIcon } from './ContinueIcon';
export { default as GoogleIcon } from './GoogleIcon';
export { default as OpenIDIcon } from './OpenIDIcon';
export { default as GithubIcon } from './GithubIcon';

View file

@ -40,6 +40,7 @@ function SelectDropDown({
{({ open }) => (
<>
<Listbox.Button
data-testid="select-dropdown-button"
className={cn(
'relative flex w-full cursor-default flex-col rounded-md border border-black/10 bg-white py-2 pl-3 pr-10 text-left focus:outline-none focus:ring-0 focus:ring-offset-0 dark:border-white/20 dark:bg-gray-800 sm:text-sm',
className ?? '',

View file

@ -6,4 +6,5 @@ export { default as useDebounce } from './useDebounce';
export { default as useLocalize } from './useLocalize';
export { default as useMediaQuery } from './useMediaQuery';
export { default as useSetOptions } from './useSetOptions';
export { default as useGenerations } from './useGenerations';
export { default as useMessageHandler } from './useMessageHandler';

View file

@ -0,0 +1,55 @@
import type { TMessage } from 'librechat-data-provider';
import { useRecoilValue } from 'recoil';
import store from '~/store';
type TUseGenerations = {
endpoint?: string;
message: TMessage;
isSubmitting: boolean;
isEditing?: boolean;
};
export default function useGenerations({
endpoint,
message,
isSubmitting,
isEditing = false,
}: TUseGenerations) {
const latestMessage = useRecoilValue(store.latestMessage);
const { error, messageId, searchResult, finish_reason, isCreatedByUser } = message ?? {};
const continueSupported =
latestMessage?.messageId === messageId &&
finish_reason &&
finish_reason !== 'stop' &&
!!['azureOpenAI', 'openAI', 'gptPlugins', 'anthropic'].find((e) => e === endpoint);
const branchingSupported =
// 5/21/23: Bing is allowing editing and Message regenerating
!![
'azureOpenAI',
'openAI',
'chatGPTBrowser',
'google',
'bingAI',
'gptPlugins',
'anthropic',
].find((e) => e === endpoint);
const editEnabled =
!error &&
isCreatedByUser && // TODO: allow AI editing
!searchResult &&
!isEditing &&
branchingSupported;
const regenerateEnabled =
!isCreatedByUser && !searchResult && !isEditing && !isSubmitting && branchingSupported;
return {
continueSupported,
editEnabled,
regenerateEnabled,
};
}

View file

@ -1,219 +0,0 @@
import { v4 } from 'uuid';
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil';
import store from '~/store';
const useMessageHandler = () => {
const currentConversation = useRecoilValue(store.conversation) || {};
const setSubmission = useSetRecoilState(store.submission);
const isSubmitting = useRecoilValue(store.isSubmitting);
const endpointsConfig = useRecoilValue(store.endpointsConfig);
const { getToken } = store.useToken(currentConversation?.endpoint);
const latestMessage = useRecoilValue(store.latestMessage);
const [messages, setMessages] = useRecoilState(store.messages);
const ask = (
{ text, parentMessageId = null, conversationId = null, messageId = null },
{ isRegenerate = false } = {},
) => {
if (!!isSubmitting || text === '') {
return;
}
// determine the model to be used
const { endpoint } = currentConversation;
let endpointOption = {};
let responseSender = '';
if (endpoint === 'azureOpenAI' || endpoint === 'openAI') {
endpointOption = {
endpoint,
model:
currentConversation?.model ??
endpointsConfig[endpoint]?.availableModels?.[0] ??
'gpt-3.5-turbo',
chatGptLabel: currentConversation?.chatGptLabel ?? null,
promptPrefix: currentConversation?.promptPrefix ?? null,
temperature: currentConversation?.temperature ?? 1,
top_p: currentConversation?.top_p ?? 1,
presence_penalty: currentConversation?.presence_penalty ?? 0,
frequency_penalty: currentConversation?.frequency_penalty ?? 0,
token: endpointsConfig[endpoint]?.userProvide ? getToken() : null,
};
responseSender = endpointOption.chatGptLabel ?? 'ChatGPT';
} else if (endpoint === 'google') {
endpointOption = {
endpoint,
model:
currentConversation?.model ??
endpointsConfig[endpoint]?.availableModels?.[0] ??
'chat-bison',
modelLabel: currentConversation?.modelLabel ?? null,
promptPrefix: currentConversation?.promptPrefix ?? null,
examples: currentConversation?.examples ?? [
{ input: { content: '' }, output: { content: '' } },
],
temperature: currentConversation?.temperature ?? 0.2,
maxOutputTokens: currentConversation?.maxOutputTokens ?? 1024,
topP: currentConversation?.topP ?? 0.95,
topK: currentConversation?.topK ?? 40,
token: endpointsConfig[endpoint]?.userProvide ? getToken() : null,
};
responseSender = endpointOption.chatGptLabel ?? 'ChatGPT';
} else if (endpoint === 'bingAI') {
endpointOption = {
endpoint,
jailbreak: currentConversation?.jailbreak ?? false,
systemMessage: currentConversation?.systemMessage ?? null,
context: currentConversation?.context ?? null,
toneStyle: currentConversation?.toneStyle ?? 'creative',
jailbreakConversationId: currentConversation?.jailbreakConversationId ?? null,
conversationSignature: currentConversation?.conversationSignature ?? null,
clientId: currentConversation?.clientId ?? null,
invocationId: currentConversation?.invocationId ?? 1,
token: endpointsConfig[endpoint]?.userProvide ? getToken() : null,
};
responseSender = endpointOption.jailbreak ? 'Sydney' : 'BingAI';
} else if (endpoint === 'anthropic') {
endpointOption = {
endpoint,
model:
currentConversation?.model ??
endpointsConfig[endpoint]?.availableModels?.[0] ??
'claude-1',
modelLabel: currentConversation?.modelLabel ?? null,
promptPrefix: currentConversation?.promptPrefix ?? null,
temperature: currentConversation?.temperature ?? 1,
maxOutputTokens: currentConversation?.maxOutputTokens ?? 1024,
topP: currentConversation?.topP ?? 0.7,
topK: currentConversation?.topK ?? 5,
token: endpointsConfig[endpoint]?.userProvide ? getToken() : null,
};
responseSender = 'Anthropic';
} else if (endpoint === 'chatGPTBrowser') {
endpointOption = {
endpoint,
model:
currentConversation?.model ??
endpointsConfig[endpoint]?.availableModels?.[0] ??
'text-davinci-002-render-sha',
token: endpointsConfig[endpoint]?.userProvide ? getToken() : null,
};
responseSender = 'ChatGPT';
} else if (endpoint === 'gptPlugins') {
const agentOptions = currentConversation?.agentOptions ?? {
agent: 'functions',
skipCompletion: true,
model: 'gpt-3.5-turbo',
temperature: 0,
};
endpointOption = {
endpoint,
tools: currentConversation?.tools ?? [],
model:
currentConversation?.model ??
endpointsConfig[endpoint]?.availableModels?.[0] ??
'gpt-3.5-turbo',
chatGptLabel: currentConversation?.chatGptLabel ?? null,
promptPrefix: currentConversation?.promptPrefix ?? null,
temperature: currentConversation?.temperature ?? 0.8,
top_p: currentConversation?.top_p ?? 1,
presence_penalty: currentConversation?.presence_penalty ?? 0,
frequency_penalty: currentConversation?.frequency_penalty ?? 0,
token: endpointsConfig[endpoint]?.userProvide ? getToken() : null,
agentOptions,
};
responseSender = 'ChatGPT';
} else if (endpoint === null) {
console.error('No endpoint available');
return;
} else {
console.error(`Unknown endpoint ${endpoint}`);
return;
}
let currentMessages = messages;
// construct the query message
// this is not a real messageId, it is used as placeholder before real messageId returned
text = text.trim();
const fakeMessageId = v4();
parentMessageId =
parentMessageId || latestMessage?.messageId || '00000000-0000-0000-0000-000000000000';
conversationId = conversationId || currentConversation?.conversationId;
if (conversationId == 'search') {
console.error('cannot send any message under search view!');
return;
}
if (conversationId == 'new') {
parentMessageId = '00000000-0000-0000-0000-000000000000';
currentMessages = [];
conversationId = null;
}
const currentMsg = {
sender: 'User',
text,
current: true,
isCreatedByUser: true,
parentMessageId,
conversationId,
messageId: fakeMessageId,
};
// construct the placeholder response message
const initialResponse = {
sender: responseSender,
text: '<span className="result-streaming">█</span>',
parentMessageId: isRegenerate ? messageId : fakeMessageId,
messageId: (isRegenerate ? messageId : fakeMessageId) + '_',
conversationId,
unfinished: false,
submitting: true,
};
const submission = {
conversation: {
...currentConversation,
conversationId,
},
endpointOption,
message: {
...currentMsg,
overrideParentMessageId: isRegenerate ? messageId : null,
},
messages: currentMessages,
isRegenerate,
initialResponse,
};
console.log('User Input:', text, submission);
if (isRegenerate) {
setMessages([...currentMessages, initialResponse]);
} else {
setMessages([...currentMessages, currentMsg, initialResponse]);
}
setSubmission(submission);
};
const regenerate = ({ parentMessageId }) => {
const parentMessage = messages?.find((element) => element.messageId == parentMessageId);
if (parentMessage && parentMessage.isCreatedByUser) {
ask({ ...parentMessage }, { isRegenerate: true });
} else {
console.error(
'Failed to regenerate the message: parentMessage not found or not created by user.',
);
}
};
const stopGenerating = () => {
setSubmission(null);
};
return { ask, regenerate, stopGenerating };
};
export default useMessageHandler;

View file

@ -0,0 +1,201 @@
import { v4 } from 'uuid';
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil';
import { parseConvo, getResponseSender } from 'librechat-data-provider';
import type { TMessage } from 'librechat-data-provider';
import store from '~/store';
type TAskProps = {
text: string;
parentMessageId?: string | null;
conversationId?: string | null;
messageId?: string | null;
};
const useMessageHandler = () => {
const currentConversation = useRecoilValue(store.conversation) || { endpoint: null };
const setSubmission = useSetRecoilState(store.submission);
const isSubmitting = useRecoilValue(store.isSubmitting);
const endpointsConfig = useRecoilValue(store.endpointsConfig);
const latestMessage = useRecoilValue(store.latestMessage);
const [messages, setMessages] = useRecoilState(store.messages);
const { endpoint } = currentConversation;
const { getToken } = store.useToken(endpoint ?? '');
const ask = (
{ text, parentMessageId = null, conversationId = null, messageId = null }: TAskProps,
{ isRegenerate = false, isEdited = false } = {},
) => {
if (!!isSubmitting || text === '') {
return;
}
if (endpoint === null) {
console.error('No endpoint available');
return;
}
conversationId = conversationId ?? currentConversation?.conversationId;
if (conversationId == 'search') {
console.error('cannot send any message under search view!');
return;
}
if (isEdited && !latestMessage) {
console.error('cannot edit AI message without latestMessage!');
return;
}
const { userProvide } = endpointsConfig[endpoint] ?? {};
// set the endpoint option
const convo = parseConvo(endpoint, currentConversation);
const endpointOption = {
endpoint,
...convo,
token: userProvide ? getToken() : null,
};
const responseSender = getResponseSender(endpointOption);
let currentMessages: TMessage[] | null = messages ?? [];
// construct the query message
// this is not a real messageId, it is used as placeholder before real messageId returned
text = text.trim();
const fakeMessageId = v4();
parentMessageId =
parentMessageId || latestMessage?.messageId || '00000000-0000-0000-0000-000000000000';
if (conversationId == 'new') {
parentMessageId = '00000000-0000-0000-0000-000000000000';
currentMessages = [];
conversationId = null;
}
const currentMsg: TMessage = {
sender: 'User',
text,
current: true,
isCreatedByUser: true,
parentMessageId,
conversationId,
messageId: isEdited && messageId ? messageId : fakeMessageId,
error: false,
};
// construct the placeholder response message
const generation = latestMessage?.text ?? '';
const responseText = isEdited ? generation : '<span className="result-streaming">█</span>';
const responseMessageId = isEdited ? latestMessage?.messageId : null;
const initialResponse: TMessage = {
sender: responseSender,
text: responseText,
parentMessageId: isRegenerate ? messageId : fakeMessageId,
messageId: responseMessageId ?? `${isRegenerate ? messageId : fakeMessageId}_`,
conversationId,
unfinished: false,
submitting: true,
isCreatedByUser: false,
error: false,
};
const submission = {
conversation: {
...currentConversation,
conversationId,
},
endpointOption,
message: {
...currentMsg,
generation,
responseMessageId,
overrideParentMessageId: isRegenerate ? messageId : null,
},
messages: currentMessages,
isEdited,
isRegenerate,
initialResponse,
};
console.log('User Input:', text, submission);
if (isRegenerate) {
setMessages([
...(isEdited ? currentMessages.slice(0, -1) : currentMessages),
initialResponse,
]);
} else {
setMessages([...currentMessages, currentMsg, initialResponse]);
}
setSubmission(submission);
};
const regenerate = ({ parentMessageId }) => {
const parentMessage = messages?.find((element) => element.messageId == parentMessageId);
if (parentMessage && parentMessage.isCreatedByUser) {
ask({ ...parentMessage }, { isRegenerate: true });
} else {
console.error(
'Failed to regenerate the message: parentMessage not found or not created by user.',
);
}
};
const continueGeneration = () => {
if (!latestMessage) {
console.error('Failed to regenerate the message: latestMessage not found.');
return;
}
const parentMessage = messages?.find(
(element) => element.messageId == latestMessage.parentMessageId,
);
if (parentMessage && parentMessage.isCreatedByUser) {
ask({ ...parentMessage }, { isRegenerate: true, isEdited: true });
} else {
console.error(
'Failed to regenerate the message: parentMessage not found, or not created by user.',
);
}
};
const stopGenerating = () => {
setSubmission(null);
};
const handleStopGenerating = (e: React.MouseEvent<HTMLButtonElement>) => {
e.preventDefault();
stopGenerating();
};
const handleRegenerate = (e: React.MouseEvent<HTMLButtonElement>) => {
e.preventDefault();
const parentMessageId = latestMessage?.parentMessageId;
if (!parentMessageId) {
console.error('Failed to regenerate the message: parentMessageId not found.');
return;
}
regenerate({ parentMessageId });
};
const handleContinue = (e: React.MouseEvent<HTMLButtonElement>) => {
e.preventDefault();
continueGeneration();
};
return {
ask,
regenerate,
stopGenerating,
handleStopGenerating,
handleRegenerate,
handleContinue,
endpointsConfig,
latestMessage,
isSubmitting,
messages,
};
};
export default useMessageHandler;

View file

@ -1,8 +1,8 @@
import { useEffect } from 'react';
import { useRecoilValue, useResetRecoilState, useSetRecoilState } from 'recoil';
import { SSE, createPayload } from 'librechat-data-provider';
import store from '~/store';
import { useAuthContext } from '~/hooks/AuthContext';
import store from '~/store';
export default function MessageHandler() {
const submission = useRecoilValue(store.submission);
@ -15,11 +15,18 @@ export default function MessageHandler() {
const { refreshConversations } = store.useConversations();
const messageHandler = (data, submission) => {
const { messages, message, plugin, initialResponse, isRegenerate = false } = submission;
const {
messages,
message,
plugin,
initialResponse,
isRegenerate = false,
isEdited = false,
} = submission;
if (isRegenerate) {
setMessages([
...messages,
...(isEdited ? messages.slice(0, -1) : messages),
{
...initialResponse,
text: data,
@ -48,13 +55,12 @@ export default function MessageHandler() {
};
const cancelHandler = (data, submission) => {
const { messages, isRegenerate = false } = submission;
const { requestMessage, responseMessage, conversation } = data;
const { messages, isRegenerate = false, isEdited = false } = submission;
// update the messages
if (isRegenerate) {
setMessages([...messages, responseMessage]);
setMessages([...(isEdited ? messages.slice(0, -1) : messages), responseMessage]);
} else {
setMessages([...messages, requestMessage, responseMessage]);
}
@ -79,11 +85,17 @@ export default function MessageHandler() {
};
const createdHandler = (data, submission) => {
const { messages, message, initialResponse, isRegenerate = false } = submission;
const {
messages,
message,
initialResponse,
isRegenerate = false,
isEdited = false,
} = submission;
if (isRegenerate) {
setMessages([
...messages,
...(isEdited ? messages.slice(0, -1) : messages),
{
...initialResponse,
parentMessageId: message?.overrideParentMessageId,
@ -113,13 +125,12 @@ export default function MessageHandler() {
};
const finalHandler = (data, submission) => {
const { messages, isRegenerate = false } = submission;
const { requestMessage, responseMessage, conversation } = data;
const { messages, isRegenerate = false, isEdited = false } = submission;
// update the messages
if (isRegenerate) {
setMessages([...messages, responseMessage]);
setMessages([...(isEdited ? messages.slice(0, -1) : messages), responseMessage]);
} else {
setMessages([...messages, requestMessage, responseMessage]);
}

View file

@ -1,16 +1,16 @@
/* eslint-disable react-hooks/exhaustive-deps */
import { useEffect, useState } from 'react';
import { useSetRecoilState } from 'recoil';
import { Outlet } from 'react-router-dom';
import {
useGetEndpointsQuery,
useGetPresetsQuery,
useGetSearchEnabledQuery,
} from 'librechat-data-provider';
import MessageHandler from '../components/MessageHandler';
import { Nav, MobileNav } from '../components/Nav';
import { Outlet } from 'react-router-dom';
import { Nav, MobileNav } from '~/components/Nav';
import { useAuthContext } from '~/hooks/AuthContext';
import { useSetRecoilState } from 'recoil';
import MessageHandler from './MessageHandler';
import store from '~/store';
export default function Root() {

View file

@ -7,7 +7,13 @@ import {
useResetRecoilState,
useRecoilCallback,
} from 'recoil';
import { TConversation, TMessagesAtom, TSubmission, TPreset } from 'librechat-data-provider';
import {
TConversation,
TMessagesAtom,
TMessage,
TSubmission,
TPreset,
} from 'librechat-data-provider';
import { buildTree, getDefaultConversation } from '~/utils';
import submission from './submission';
import endpoints from './endpoints';
@ -32,7 +38,7 @@ const messagesTree = selector({
},
});
const latestMessage = atom({
const latestMessage = atom<TMessage | null>({
key: 'latestMessage',
default: null,
});

View file

@ -21,7 +21,7 @@ export default function buildTree(messages: TMessage[] | null, groupAll = false)
messages.forEach((message) => {
messageMap[message.messageId] = { ...message, children: [] };
const parentMessage = messageMap[message.parentMessageId];
const parentMessage = messageMap[message.parentMessageId ?? ''];
if (parentMessage) {
parentMessage.children.push(messageMap[message.messageId]);
} else {

Some files were not shown because too many files have changed in this diff Show more