mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-16 16:30:15 +01:00
feat(GPT/Anthropic): Continue Regenerating & Generation Buttons (#808)
* feat(useMessageHandler.js/ts): Refactor and add features to handle user messages, support multiple endpoints/models, generate placeholder responses, regeneration, and stopGeneration function
fix(conversation.ts, buildTree.ts): Import TMessage type, handle null parentMessageId
feat(schemas.ts): Update and add schemas for various AI services, add default values, optional fields, and endpoint-to-schema mapping, create parseConvo function
chore(useMessageHandler.js, schemas.ts): Remove unused imports, variables, and chatGPT enum
* wip: add generation buttons
* refactor(cleanupPreset.ts): simplify cleanupPreset function
refactor(getDefaultConversation.js): remove unused code and simplify getDefaultConversation function
feat(utils): add getDefaultConversation function
This commit adds a new utility function called `getDefaultConversation` to the `client/src/utils/getDefaultConversation.ts` file. This function is responsible for generating a default conversation object based on the provided parameters.
The `getDefaultConversation` function takes in an object with the following properties:
- `conversation`: The conversation object to be used as a base.
- `endpointsConfig`: The configuration object containing information about the available endpoints.
- `preset`: An optional preset object that can be used to override the default behavior.
The function first tries to determine the target endpoint based on the preset object. If a valid endpoint is found, it is used as the target endpoint. If not, the function tries to retrieve the last conversation setup from the local storage and uses its endpoint if it is valid. If neither the preset nor the local storage contains a valid endpoint, the function falls back to a default endpoint.
Once the target endpoint is determined,
* fix(utils): remove console.error statement in buildDefaultConversation function
fix(schemas): add default values for catch blocks in openAISchema, googleSchema, bingAISchema, anthropicSchema, chatGPTBrowserSchema, and gptPluginsSchema
* fix: endpoint not changing on change of preset from other endpoint, wip: refactor
* refactor: preset items to TSX
* refactor: convert resetConvo to TS
* refactor(getDefaultConversation.ts): move defaultEndpoints array to the top of the file for better readability
refactor(getDefaultConversation.ts): extract getDefaultEndpoint function for better code organization and reusability
* feat(svg): add ContinueIcon component
feat(svg): add RegenerateIcon component
feat(svg): add ContinueIcon and RegenerateIcon components to index.ts
* feat(Button.tsx): add onClick and className props to Button component
feat(GenerationButtons.tsx): add logic to display Regenerate or StopGenerating button based on isSubmitting and messages
feat(Regenerate.tsx): create Regenerate component with RegenerateIcon and handleRegenerate function
feat(StopGenerating.tsx): create StopGenerating component with StopGeneratingIcon and handleStopGenerating function
* fix(TextChat.jsx): reorder imports and variables for better readability
fix(TextChat.jsx): fix typo in condition for isNotAppendable variable
fix(TextChat.jsx): remove unused handleStopGenerating function
fix(ContinueIcon.tsx): remove unnecessary closing tags for polygon elements
fix(useMessageHandler.ts): add missing type annotations for handleStopGenerating and handleRegenerate functions
fix(useMessageHandler.ts): remove unused variables in return statement
* fix(getDefaultConversation.ts): refactor code to use getLocalStorageItems function
feat(getLocalStorageItems.ts): add utility function to retrieve items from local storage
* fix(OpenAIClient.js): add support for streaming result in sendCompletion method
feat(OpenAIClient.js): add finish_reason metadata to opts in sendCompletion method
feat(Message.js): add finish_reason field to Message model
feat(messageSchema.js): add finish_reason field to messageSchema
feat(openAI.js): parse chatGptLabel and promptPrefix from req.body and pass rest of the modelOptions to endpointOption
feat(openAI.js): add addMetadata function to store metadata in ask function
feat(openAI.js): add metadata to response if available
feat(schemas.ts): add finish_reason field to tMessageSchema
* feat(types.ts): add TOnClick and TGenButtonProps types for button components
feat(Continue.tsx): create Continue component for generating button
feat(GenerationButtons.tsx): update GenerationButtons component to use Continue component
feat(Regenerate.tsx): create Regenerate component for regenerating button
feat(Stop.tsx): create Stop component for stop generating button
* feat(MessageHandler.jsx): add MessageHandler component to handle messages and conversations
fix(Root.jsx): fix import paths for Nav and MessageHandler components
* feat(useMessageHandler.ts): add support for generation parameter in ask function
feat(useMessageHandler.ts): add support for isEdited parameter in ask function
feat(useMessageHandler.ts): add support for continueGeneration function
fix(createPayload.ts): replace endpoint URL when isEdited parameter is true
* chore(client): set skipLibCheck to true in tsconfig.json
* fix(useMessageHandler.ts): remove unused clientId variable
fix(schemas.ts): make clientId field in tMessageSchema nullable and optional
* wip: edit route for continue generation
* refactor(api): move handlers to root of routes dir
* fix(useMessageHandler.ts): initialize currentMessages to an empty array if messages is null
fix(useMessageHandler.ts): update initialResponse text to use responseText variable
fix(useMessageHandler.ts): update setMessages logic for isRegenerate case
fix(MessageHandler.jsx): update setMessages logic for cancelHandler, createdHandler, and finalHandler
* fix(schemas.ts): make createdAt and updatedAt fields optional and set default values using new Date().toISOString()
fix(schemas.ts): change type annotation of TMessage from infer to input
* refactor(useMessageHandler.ts): rename AskProps type to TAskProps
refactor(useMessageHandler.ts): remove generation property from ask function arguments
refactor(useMessageHandler.ts): use nullish coalescing operator (??) instead of logical OR (||)
refactor(useMessageHandler.ts): pass the responseMessageId to message prop of submission
* fix(BaseClient.js): use nullish coalescing operator (??) instead of logical OR (||) for default values
* fix(BaseClient.js): fix responseMessageId assignment in handleStartMethods method
feat(BaseClient.js): add support for isEdited flag in sendMessage method
feat(BaseClient.js): add generation to responseMessage text in sendMessage method
* fix(openAI.js): remove unused imports and commented out code
feat(openAI.js): add support for generation parameter in request body
fix(openAI.js): remove console.log statement
fix(openAI.js): remove unused variables and parameters
fix(openAI.js): update response text in case of error
fix(openAI.js): handle error and abort message in case of error
fix(handlers.js): add generation parameter to createOnProgress function
fix(useMessageHandler.ts): update responseText variable to use generation parameter
* refactor(api/middleware): move inside server dir
* refactor: add endpoint specific, modular functions to build options and initialize clients, create server/utils, move middleware, separate utils into api general utils and server specific utils
* fix(abortMiddleware.js): import getConvo and getConvoTitle functions from models
feat(abortMiddleware.js): add abortAsk function to abortController to handle aborting of requests
fix(openAI.js): import buildOptions and initializeClient functions from endpoints/openAI
refactor(openAI.js): use getAbortData function to get data for abortAsk function
* refactor: move endpoint specific logic to an endpoints dir
* refactor(PluginService.js): fix import path for encrypt and decrypt functions in PluginService.js
* feat(openAI): add new endpoint for adding a title to a conversation
- Added a new file `addTitle.js` in the `api/server/routes/endpoints/openAI` directory.
- The `addTitle.js` file exports a function `addTitle` that takes in request parameters and performs the following actions:
- If the `parentMessageId` is `'00000000-0000-0000-0000-000000000000'` and `newConvo` is true, it proceeds with the following steps:
- Calls the `titleConvo` function from the `titleConvo` module, passing in the necessary parameters.
- Calls the `saveConvo` function from the `saveConvo` module, passing in the user ID and conversation details.
- Updated the `index.js` file in the `api/server/routes/endpoints/openAI` directory to export the `addTitle` function.
- This change adds
* fix(abortMiddleware.js): remove console.log statement
refactor(gptPlugins.js): update imports and function parameters
feat(gptPlugins.js): add support for abortController and getAbortData
refactor(openAI.js): update imports and function parameters
feat(openAI.js): add support for abortController and getAbortData
fix(openAI.js): refactor code to use modularized functions and middleware
fix(buildOptions.js): refactor code to use destructuring and update variable names
* refactor(askChatGPTBrowser.js, bingAI.js, google.js): remove duplicate code for setting response headers
feat(askChatGPTBrowser.js, bingAI.js, google.js): add setHeaders middleware to set response headers
* feat(middleware): validateEndpoint, refactor buildOption to only be concerned of endpointOption
* fix(abortMiddleware.js): add 'finish_reason' property with value 'incomplete' to responseMessage object
fix(abortMessage.js): remove console.log statement for aborted message
fix(handlers.js): modify tokens assignment to handle empty generation string and trailing space
* fix(BaseClient.js): import addSpaceIfNeeded function from server/utils
fix(BaseClient.js): add space before generation in text property
fix(index.js): remove getCitations and citeText exports
feat(buildEndpointOption.js): add buildEndpointOption middleware
fix(index.js): import buildEndpointOption middleware
fix(anthropic.js): remove buildOptions function and use endpointOption from req.body
fix(gptPlugins.js): remove buildOptions function and use endpointOption from req.body
fix(openAI.js): remove buildOptions function and use endpointOption from req.body
feat(utils): add citations.js and handleText.js modules
fix(utils): fix import statements in index.js module
* refactor(gptPlugins.js): use getResponseSender function from librechat-data-provider
* feat(gptPlugins): complete 'continue generating'
* wip: anthropic continue regen
* feat(middleware): add validateRegistration middleware
A new middleware function called `validateRegistration` has been added to the list of exported middleware functions in `index.js`. This middleware is responsible for validating registration data before allowing the registration process to proceed.
* feat(Anthropic): complete continue regen
* chore: add librechat-data-provider to api/package.json
* fix(ci): backend-review will mock meilisearch, also installs data-provider as now needed
* chore(ci): remove unneeded SEARCH env var
* style(GenerationButtons): make text shorter for sake of space economy, even though this diverges from chat.openai.com
* style(GenerationButtons/ScrollToBottom): adjust visibility/position based on screen size
* chore(client): 'Editting' typo
* feat(GenerationButtons.tsx): add support for endpoint prop in GenerationButtons component
feat(OptionsBar.tsx): pass endpoint prop to GenerationButtons component
feat(useGenerations.ts): create useGenerations hook to handle generation logic
fix(schemas.ts): add searchResult field to tMessageSchema
* refactor(HoverButtons): convert to TSX and utilize new useGenerations hook
* fix(abortMiddleware): handle error with res headers set, or abortController not found, to ensure proper API error is sent to the client, chore(BaseClient): remove console log for onStart message meant for debugging
* refactor(api): remove librechat-data-provider dep for now as it complicates deployed docker build stage, re-use code in CJS, located in server/endpoints/schemas
* chore: remove console.logs from test files
* ci: add backend tests for AnthropicClient, focusing on new buildMessages logic
* refactor(FakeClient): use actual BaseClient sendMessage method for testing
* test(BaseClient.test.js): add test for loading chat history
test(BaseClient.test.js): add test for sendMessage logic with isEdited flag
* fix(buildEndpointOption.js): add support for azureOpenAI in buildFunction object
wip(endpoints.js): fetch Azure models from Azure OpenAI API if opts.azure is true
* fix(Button.tsx): add data-testid attribute to button component
fix(SelectDropDown.tsx): add data-testid attribute to Listbox.Button component
fix(messages.spec.ts): add waitForServerStream function to consolidate logic for awaiting the server response
feat(messages.spec.ts): add test for stopping and continuing message and improve browser/page context order and closing
* refactor(onProgress): speed up time to save initial message for editable routes
* chore: disable AI message editing (for now), was accidentally allowed
* refactor: ensure continue is only supported for latest message style: improve styling in dark mode and across all hover buttons/icons, including making edit icon for AI invisible (for now)
* fix: add test id to generation buttons so they never resolve to 2+ items
* chore(package.json): add 'packages/' to the list of ignored directories
chore(data-provider/package.json): bump version to 0.1.5
This commit is contained in:
parent
ae5b7d3d53
commit
afd43afb60
113 changed files with 3023 additions and 1543 deletions
10
.github/workflows/backend-review.yml
vendored
10
.github/workflows/backend-review.yml
vendored
|
|
@ -1,10 +1,5 @@
|
||||||
name: Backend Unit Tests
|
name: Backend Unit Tests
|
||||||
on:
|
on:
|
||||||
# push:
|
|
||||||
# branches:
|
|
||||||
# - main
|
|
||||||
# - dev
|
|
||||||
# - release/*
|
|
||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
|
|
@ -23,6 +18,7 @@ jobs:
|
||||||
JWT_SECRET: ${{ secrets.JWT_SECRET }}
|
JWT_SECRET: ${{ secrets.JWT_SECRET }}
|
||||||
CREDS_KEY: ${{ secrets.CREDS_KEY }}
|
CREDS_KEY: ${{ secrets.CREDS_KEY }}
|
||||||
CREDS_IV: ${{ secrets.CREDS_IV }}
|
CREDS_IV: ${{ secrets.CREDS_IV }}
|
||||||
|
NODE_ENV: ci
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- name: Use Node.js 20.x
|
- name: Use Node.js 20.x
|
||||||
|
|
@ -34,8 +30,8 @@ jobs:
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: npm ci
|
run: npm ci
|
||||||
|
|
||||||
# - name: Install Linux X64 Sharp
|
- name: Install Data Provider
|
||||||
# run: npm install --platform=linux --arch=x64 --verbose sharp
|
run: npm run build:data-provider
|
||||||
|
|
||||||
- name: Run unit tests
|
- name: Run unit tests
|
||||||
run: cd api && npm run test:ci
|
run: cd api && npm run test:ci
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
const Keyv = require('keyv');
|
|
||||||
// const { Agent, ProxyAgent } = require('undici');
|
// const { Agent, ProxyAgent } = require('undici');
|
||||||
const BaseClient = require('./BaseClient');
|
const BaseClient = require('./BaseClient');
|
||||||
const {
|
const {
|
||||||
|
|
@ -15,8 +14,6 @@ const tokenizersCache = {};
|
||||||
class AnthropicClient extends BaseClient {
|
class AnthropicClient extends BaseClient {
|
||||||
constructor(apiKey, options = {}, cacheOptions = {}) {
|
constructor(apiKey, options = {}, cacheOptions = {}) {
|
||||||
super(apiKey, options, cacheOptions);
|
super(apiKey, options, cacheOptions);
|
||||||
cacheOptions.namespace = cacheOptions.namespace || 'anthropic';
|
|
||||||
this.conversationsCache = new Keyv(cacheOptions);
|
|
||||||
this.apiKey = apiKey || process.env.ANTHROPIC_API_KEY;
|
this.apiKey = apiKey || process.env.ANTHROPIC_API_KEY;
|
||||||
this.sender = 'Anthropic';
|
this.sender = 'Anthropic';
|
||||||
this.userLabel = HUMAN_PROMPT;
|
this.userLabel = HUMAN_PROMPT;
|
||||||
|
|
@ -107,6 +104,23 @@ class AnthropicClient extends BaseClient {
|
||||||
content: message?.content ?? message.text,
|
content: message?.content ?? message.text,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
let lastAuthor = '';
|
||||||
|
let groupedMessages = [];
|
||||||
|
|
||||||
|
for (let message of formattedMessages) {
|
||||||
|
// If last author is not same as current author, add to new group
|
||||||
|
if (lastAuthor !== message.author) {
|
||||||
|
groupedMessages.push({
|
||||||
|
author: message.author,
|
||||||
|
content: [message.content],
|
||||||
|
});
|
||||||
|
lastAuthor = message.author;
|
||||||
|
// If same author, append content to the last group
|
||||||
|
} else {
|
||||||
|
groupedMessages[groupedMessages.length - 1].content.push(message.content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let identityPrefix = '';
|
let identityPrefix = '';
|
||||||
if (this.options.userLabel) {
|
if (this.options.userLabel) {
|
||||||
identityPrefix = `\nHuman's name: ${this.options.userLabel}`;
|
identityPrefix = `\nHuman's name: ${this.options.userLabel}`;
|
||||||
|
|
@ -129,8 +143,12 @@ class AnthropicClient extends BaseClient {
|
||||||
promptPrefix = `${identityPrefix}${promptPrefix}`;
|
promptPrefix = `${identityPrefix}${promptPrefix}`;
|
||||||
}
|
}
|
||||||
|
|
||||||
const promptSuffix = `${promptPrefix}${this.assistantLabel}\n`; // Prompt AI to respond.
|
// Prompt AI to respond, empty if last message was from AI
|
||||||
let currentTokenCount = this.getTokenCount(promptSuffix);
|
let isEdited = lastAuthor === this.assistantLabel;
|
||||||
|
const promptSuffix = isEdited ? '' : `${promptPrefix}${this.assistantLabel}\n`;
|
||||||
|
let currentTokenCount = isEdited
|
||||||
|
? this.getTokenCount(promptPrefix)
|
||||||
|
: this.getTokenCount(promptSuffix);
|
||||||
|
|
||||||
let promptBody = '';
|
let promptBody = '';
|
||||||
const maxTokenCount = this.maxPromptTokens;
|
const maxTokenCount = this.maxPromptTokens;
|
||||||
|
|
@ -148,10 +166,13 @@ class AnthropicClient extends BaseClient {
|
||||||
};
|
};
|
||||||
|
|
||||||
const buildPromptBody = async () => {
|
const buildPromptBody = async () => {
|
||||||
if (currentTokenCount < maxTokenCount && formattedMessages.length > 0) {
|
if (currentTokenCount < maxTokenCount && groupedMessages.length > 0) {
|
||||||
const message = formattedMessages.pop();
|
const message = groupedMessages.pop();
|
||||||
const isCreatedByUser = message.author === this.userLabel;
|
const isCreatedByUser = message.author === this.userLabel;
|
||||||
const messageString = `${message.author}\n${message.content}${this.endToken}\n`;
|
// Use promptPrefix if message is edited assistant'
|
||||||
|
const messagePrefix =
|
||||||
|
isCreatedByUser || !isEdited ? message.author : `${promptPrefix}${message.author}`;
|
||||||
|
const messageString = `${messagePrefix}\n${message.content}${this.endToken}\n`;
|
||||||
let newPromptBody = `${messageString}${promptBody}`;
|
let newPromptBody = `${messageString}${promptBody}`;
|
||||||
|
|
||||||
context.unshift(message);
|
context.unshift(message);
|
||||||
|
|
@ -182,6 +203,12 @@ class AnthropicClient extends BaseClient {
|
||||||
}
|
}
|
||||||
promptBody = newPromptBody;
|
promptBody = newPromptBody;
|
||||||
currentTokenCount = newTokenCount;
|
currentTokenCount = newTokenCount;
|
||||||
|
|
||||||
|
// Switch off isEdited after using it for the first time
|
||||||
|
if (isEdited) {
|
||||||
|
isEdited = false;
|
||||||
|
}
|
||||||
|
|
||||||
// wait for next tick to avoid blocking the event loop
|
// wait for next tick to avoid blocking the event loop
|
||||||
await new Promise((resolve) => setImmediate(resolve));
|
await new Promise((resolve) => setImmediate(resolve));
|
||||||
return buildPromptBody();
|
return buildPromptBody();
|
||||||
|
|
@ -197,7 +224,8 @@ class AnthropicClient extends BaseClient {
|
||||||
context.shift();
|
context.shift();
|
||||||
}
|
}
|
||||||
|
|
||||||
const prompt = `${promptBody}${promptSuffix}`;
|
let prompt = `${promptBody}${promptSuffix}`;
|
||||||
|
|
||||||
// Add 2 tokens for metadata after all messages have been counted.
|
// Add 2 tokens for metadata after all messages have been counted.
|
||||||
currentTokenCount += 2;
|
currentTokenCount += 2;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,12 @@ const { ChatOpenAI } = require('langchain/chat_models/openai');
|
||||||
const { loadSummarizationChain } = require('langchain/chains');
|
const { loadSummarizationChain } = require('langchain/chains');
|
||||||
const { refinePrompt } = require('./prompts/refinePrompt');
|
const { refinePrompt } = require('./prompts/refinePrompt');
|
||||||
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('../../models');
|
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('../../models');
|
||||||
|
const { addSpaceIfNeeded } = require('../../server/utils');
|
||||||
|
|
||||||
class BaseClient {
|
class BaseClient {
|
||||||
constructor(apiKey, options = {}) {
|
constructor(apiKey, options = {}) {
|
||||||
this.apiKey = apiKey;
|
this.apiKey = apiKey;
|
||||||
this.sender = options.sender || 'AI';
|
this.sender = options.sender ?? 'AI';
|
||||||
this.contextStrategy = null;
|
this.contextStrategy = null;
|
||||||
this.currentDateString = new Date().toLocaleDateString('en-us', {
|
this.currentDateString = new Date().toLocaleDateString('en-us', {
|
||||||
year: 'numeric',
|
year: 'numeric',
|
||||||
|
|
@ -51,18 +52,20 @@ class BaseClient {
|
||||||
if (opts && typeof opts === 'object') {
|
if (opts && typeof opts === 'object') {
|
||||||
this.setOptions(opts);
|
this.setOptions(opts);
|
||||||
}
|
}
|
||||||
const user = opts.user || null;
|
const user = opts.user ?? null;
|
||||||
const conversationId = opts.conversationId || crypto.randomUUID();
|
const conversationId = opts.conversationId ?? crypto.randomUUID();
|
||||||
const parentMessageId = opts.parentMessageId || '00000000-0000-0000-0000-000000000000';
|
const parentMessageId = opts.parentMessageId ?? '00000000-0000-0000-0000-000000000000';
|
||||||
const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
|
const userMessageId = opts.overrideParentMessageId ?? crypto.randomUUID();
|
||||||
const responseMessageId = crypto.randomUUID();
|
const responseMessageId = opts.responseMessageId ?? crypto.randomUUID();
|
||||||
const saveOptions = this.getSaveOptions();
|
const saveOptions = this.getSaveOptions();
|
||||||
this.abortController = opts.abortController || new AbortController();
|
const head = opts.isEdited ? responseMessageId : parentMessageId;
|
||||||
this.currentMessages = (await this.loadHistory(conversationId, parentMessageId)) ?? [];
|
this.currentMessages = (await this.loadHistory(conversationId, head)) ?? [];
|
||||||
|
this.abortController = opts.abortController ?? new AbortController();
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...opts,
|
...opts,
|
||||||
user,
|
user,
|
||||||
|
head,
|
||||||
conversationId,
|
conversationId,
|
||||||
parentMessageId,
|
parentMessageId,
|
||||||
userMessageId,
|
userMessageId,
|
||||||
|
|
@ -72,7 +75,7 @@ class BaseClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
createUserMessage({ messageId, parentMessageId, conversationId, text }) {
|
createUserMessage({ messageId, parentMessageId, conversationId, text }) {
|
||||||
const userMessage = {
|
return {
|
||||||
messageId,
|
messageId,
|
||||||
parentMessageId,
|
parentMessageId,
|
||||||
conversationId,
|
conversationId,
|
||||||
|
|
@ -80,19 +83,27 @@ class BaseClient {
|
||||||
text,
|
text,
|
||||||
isCreatedByUser: true,
|
isCreatedByUser: true,
|
||||||
};
|
};
|
||||||
return userMessage;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async handleStartMethods(message, opts) {
|
async handleStartMethods(message, opts) {
|
||||||
const { user, conversationId, parentMessageId, userMessageId, responseMessageId, saveOptions } =
|
const {
|
||||||
await this.setMessageOptions(opts);
|
user,
|
||||||
|
head,
|
||||||
const userMessage = this.createUserMessage({
|
|
||||||
messageId: userMessageId,
|
|
||||||
parentMessageId,
|
|
||||||
conversationId,
|
conversationId,
|
||||||
text: message,
|
parentMessageId,
|
||||||
});
|
userMessageId,
|
||||||
|
responseMessageId,
|
||||||
|
saveOptions,
|
||||||
|
} = await this.setMessageOptions(opts);
|
||||||
|
|
||||||
|
const userMessage = opts.isEdited
|
||||||
|
? this.currentMessages[this.currentMessages.length - 2]
|
||||||
|
: this.createUserMessage({
|
||||||
|
messageId: userMessageId,
|
||||||
|
parentMessageId,
|
||||||
|
conversationId,
|
||||||
|
text: message,
|
||||||
|
});
|
||||||
|
|
||||||
if (typeof opts?.getIds === 'function') {
|
if (typeof opts?.getIds === 'function') {
|
||||||
opts.getIds({
|
opts.getIds({
|
||||||
|
|
@ -109,6 +120,7 @@ class BaseClient {
|
||||||
return {
|
return {
|
||||||
...opts,
|
...opts,
|
||||||
user,
|
user,
|
||||||
|
head,
|
||||||
conversationId,
|
conversationId,
|
||||||
responseMessageId,
|
responseMessageId,
|
||||||
saveOptions,
|
saveOptions,
|
||||||
|
|
@ -373,7 +385,7 @@ class BaseClient {
|
||||||
|
|
||||||
if (this.options.debug) {
|
if (this.options.debug) {
|
||||||
console.debug('<-------------------------PAYLOAD/TOKEN COUNT MAP------------------------->');
|
console.debug('<-------------------------PAYLOAD/TOKEN COUNT MAP------------------------->');
|
||||||
console.debug('Payload:', payload);
|
// console.debug('Payload:', payload);
|
||||||
console.debug('Token Count Map:', tokenCountMap);
|
console.debug('Token Count Map:', tokenCountMap);
|
||||||
console.debug('Prompt Tokens', promptTokens, remainingContextTokens, this.maxContextTokens);
|
console.debug('Prompt Tokens', promptTokens, remainingContextTokens, this.maxContextTokens);
|
||||||
}
|
}
|
||||||
|
|
@ -382,13 +394,16 @@ class BaseClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
async sendMessage(message, opts = {}) {
|
async sendMessage(message, opts = {}) {
|
||||||
const { user, conversationId, responseMessageId, saveOptions, userMessage } =
|
const { user, head, isEdited, conversationId, responseMessageId, saveOptions, userMessage } =
|
||||||
await this.handleStartMethods(message, opts);
|
await this.handleStartMethods(message, opts);
|
||||||
|
|
||||||
this.user = user;
|
this.user = user;
|
||||||
// It's not necessary to push to currentMessages
|
// It's not necessary to push to currentMessages
|
||||||
// depending on subclass implementation of handling messages
|
// depending on subclass implementation of handling messages
|
||||||
this.currentMessages.push(userMessage);
|
// When this is an edit, all messages are already in currentMessages, both user and response
|
||||||
|
if (!isEdited) {
|
||||||
|
this.currentMessages.push(userMessage);
|
||||||
|
}
|
||||||
|
|
||||||
let {
|
let {
|
||||||
prompt: payload,
|
prompt: payload,
|
||||||
|
|
@ -398,13 +413,13 @@ class BaseClient {
|
||||||
this.currentMessages,
|
this.currentMessages,
|
||||||
// When the userMessage is pushed to currentMessages, the parentMessage is the userMessageId.
|
// When the userMessage is pushed to currentMessages, the parentMessage is the userMessageId.
|
||||||
// this only matters when buildMessages is utilizing the parentMessageId, and may vary on implementation
|
// this only matters when buildMessages is utilizing the parentMessageId, and may vary on implementation
|
||||||
userMessage.messageId,
|
isEdited ? head : userMessage.messageId,
|
||||||
this.getBuildMessagesOptions(opts),
|
this.getBuildMessagesOptions(opts),
|
||||||
);
|
);
|
||||||
|
|
||||||
if (this.options.debug) {
|
if (this.options.debug) {
|
||||||
console.debug('payload');
|
console.debug('payload');
|
||||||
console.debug(payload);
|
// console.debug(payload);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tokenCountMap) {
|
if (tokenCountMap) {
|
||||||
|
|
@ -423,7 +438,11 @@ class BaseClient {
|
||||||
this.handleTokenCountMap(tokenCountMap);
|
this.handleTokenCountMap(tokenCountMap);
|
||||||
}
|
}
|
||||||
|
|
||||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
if (!isEdited) {
|
||||||
|
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||||
|
}
|
||||||
|
|
||||||
|
const generation = isEdited ? this.currentMessages[this.currentMessages.length - 1].text : '';
|
||||||
const responseMessage = {
|
const responseMessage = {
|
||||||
messageId: responseMessageId,
|
messageId: responseMessageId,
|
||||||
conversationId,
|
conversationId,
|
||||||
|
|
@ -431,7 +450,7 @@ class BaseClient {
|
||||||
isCreatedByUser: false,
|
isCreatedByUser: false,
|
||||||
model: this.modelOptions.model,
|
model: this.modelOptions.model,
|
||||||
sender: this.sender,
|
sender: this.sender,
|
||||||
text: await this.sendCompletion(payload, opts),
|
text: addSpaceIfNeeded(generation) + (await this.sendCompletion(payload, opts)),
|
||||||
promptTokens,
|
promptTokens,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -453,7 +472,7 @@ class BaseClient {
|
||||||
console.debug('Loading history for conversation', conversationId, parentMessageId);
|
console.debug('Loading history for conversation', conversationId, parentMessageId);
|
||||||
}
|
}
|
||||||
|
|
||||||
const messages = (await getMessages({ conversationId })) || [];
|
const messages = (await getMessages({ conversationId })) ?? [];
|
||||||
|
|
||||||
if (messages.length === 0) {
|
if (messages.length === 0) {
|
||||||
return [];
|
return [];
|
||||||
|
|
|
||||||
|
|
@ -314,6 +314,7 @@ class OpenAIClient extends BaseClient {
|
||||||
async sendCompletion(payload, opts = {}) {
|
async sendCompletion(payload, opts = {}) {
|
||||||
let reply = '';
|
let reply = '';
|
||||||
let result = null;
|
let result = null;
|
||||||
|
let streamResult = null;
|
||||||
if (typeof opts.onProgress === 'function') {
|
if (typeof opts.onProgress === 'function') {
|
||||||
await this.getCompletion(
|
await this.getCompletion(
|
||||||
payload,
|
payload,
|
||||||
|
|
@ -321,6 +322,10 @@ class OpenAIClient extends BaseClient {
|
||||||
if (progressMessage === '[DONE]') {
|
if (progressMessage === '[DONE]') {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (progressMessage.choices) {
|
||||||
|
streamResult = progressMessage;
|
||||||
|
}
|
||||||
const token = this.isChatCompletion
|
const token = this.isChatCompletion
|
||||||
? progressMessage.choices?.[0]?.delta?.content
|
? progressMessage.choices?.[0]?.delta?.content
|
||||||
: progressMessage.choices?.[0]?.text;
|
: progressMessage.choices?.[0]?.text;
|
||||||
|
|
@ -355,6 +360,10 @@ class OpenAIClient extends BaseClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (streamResult && typeof opts.addMetadata === 'function') {
|
||||||
|
const { finish_reason } = streamResult.choices[0];
|
||||||
|
opts.addMetadata({ finish_reason });
|
||||||
|
}
|
||||||
return reply.trim();
|
return reply.trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -345,7 +345,8 @@ Only respond with your conversational reply to the following User Message:
|
||||||
}
|
}
|
||||||
|
|
||||||
async sendMessage(message, opts = {}) {
|
async sendMessage(message, opts = {}) {
|
||||||
const completionMode = this.options.tools.length === 0;
|
// If a message is edited, no tools can be used.
|
||||||
|
const completionMode = this.options.tools.length === 0 || opts.isEdited;
|
||||||
if (completionMode) {
|
if (completionMode) {
|
||||||
this.setOptions(opts);
|
this.setOptions(opts);
|
||||||
return super.sendMessage(message, opts);
|
return super.sendMessage(message, opts);
|
||||||
|
|
|
||||||
139
api/app/clients/specs/AnthropicClient.test.js
Normal file
139
api/app/clients/specs/AnthropicClient.test.js
Normal file
|
|
@ -0,0 +1,139 @@
|
||||||
|
const AnthropicClient = require('../AnthropicClient');
|
||||||
|
const HUMAN_PROMPT = '\n\nHuman:';
|
||||||
|
const AI_PROMPT = '\n\nAssistant:';
|
||||||
|
|
||||||
|
describe('AnthropicClient', () => {
|
||||||
|
let client;
|
||||||
|
const model = 'claude-2';
|
||||||
|
const parentMessageId = '1';
|
||||||
|
const messages = [
|
||||||
|
{ role: 'user', isCreatedByUser: true, text: 'Hello', messageId: parentMessageId },
|
||||||
|
{ role: 'assistant', isCreatedByUser: false, text: 'Hi', messageId: '2', parentMessageId },
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
isCreatedByUser: true,
|
||||||
|
text: 'What\'s up',
|
||||||
|
messageId: '3',
|
||||||
|
parentMessageId: '2',
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
const options = {
|
||||||
|
modelOptions: {
|
||||||
|
model,
|
||||||
|
temperature: 0.7,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
client = new AnthropicClient('test-api-key');
|
||||||
|
client.setOptions(options);
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('setOptions', () => {
|
||||||
|
it('should set the options correctly', () => {
|
||||||
|
expect(client.apiKey).toBe('test-api-key');
|
||||||
|
expect(client.modelOptions.model).toBe(model);
|
||||||
|
expect(client.modelOptions.temperature).toBe(0.7);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('getSaveOptions', () => {
|
||||||
|
it('should return the correct save options', () => {
|
||||||
|
const options = client.getSaveOptions();
|
||||||
|
expect(options).toHaveProperty('modelLabel');
|
||||||
|
expect(options).toHaveProperty('promptPrefix');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('buildMessages', () => {
|
||||||
|
it('should handle promptPrefix from options when promptPrefix argument is not provided', async () => {
|
||||||
|
client.options.promptPrefix = 'Test Prefix from options';
|
||||||
|
const result = await client.buildMessages(messages, parentMessageId);
|
||||||
|
const { prompt } = result;
|
||||||
|
expect(prompt).toContain('Test Prefix from options');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should build messages correctly for chat completion', async () => {
|
||||||
|
const result = await client.buildMessages(messages, '2');
|
||||||
|
expect(result).toHaveProperty('prompt');
|
||||||
|
expect(result.prompt).toContain(HUMAN_PROMPT);
|
||||||
|
expect(result.prompt).toContain('Hello');
|
||||||
|
expect(result.prompt).toContain(AI_PROMPT);
|
||||||
|
expect(result.prompt).toContain('Hi');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should group messages by the same author', async () => {
|
||||||
|
const groupedMessages = messages.map((m) => ({ ...m, isCreatedByUser: true, role: 'user' }));
|
||||||
|
const result = await client.buildMessages(groupedMessages, '3');
|
||||||
|
expect(result.context).toHaveLength(1);
|
||||||
|
|
||||||
|
// Check that HUMAN_PROMPT appears only once in the prompt
|
||||||
|
const matches = result.prompt.match(new RegExp(HUMAN_PROMPT, 'g'));
|
||||||
|
expect(matches).toHaveLength(1);
|
||||||
|
|
||||||
|
groupedMessages.push({
|
||||||
|
role: 'assistant',
|
||||||
|
isCreatedByUser: false,
|
||||||
|
text: 'I heard you the first time',
|
||||||
|
messageId: '4',
|
||||||
|
parentMessageId: '3',
|
||||||
|
});
|
||||||
|
|
||||||
|
const result2 = await client.buildMessages(groupedMessages, '4');
|
||||||
|
expect(result2.context).toHaveLength(2);
|
||||||
|
|
||||||
|
// Check that HUMAN_PROMPT appears only once in the prompt
|
||||||
|
const human_matches = result2.prompt.match(new RegExp(HUMAN_PROMPT, 'g'));
|
||||||
|
const ai_matches = result2.prompt.match(new RegExp(AI_PROMPT, 'g'));
|
||||||
|
expect(human_matches).toHaveLength(1);
|
||||||
|
expect(ai_matches).toHaveLength(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle isEdited condition', async () => {
|
||||||
|
const editedMessages = [
|
||||||
|
{ role: 'user', isCreatedByUser: true, text: 'Hello', messageId: '1' },
|
||||||
|
{ role: 'assistant', isCreatedByUser: false, text: 'Hi', messageId: '2', parentMessageId },
|
||||||
|
];
|
||||||
|
|
||||||
|
const trimmedLabel = AI_PROMPT.trim();
|
||||||
|
const result = await client.buildMessages(editedMessages, '2');
|
||||||
|
expect(result.prompt.trim().endsWith(trimmedLabel)).toBeFalsy();
|
||||||
|
|
||||||
|
// Add a human message at the end to test the opposite
|
||||||
|
editedMessages.push({
|
||||||
|
role: 'user',
|
||||||
|
isCreatedByUser: true,
|
||||||
|
text: 'Hi again',
|
||||||
|
messageId: '3',
|
||||||
|
parentMessageId: '2',
|
||||||
|
});
|
||||||
|
const result2 = await client.buildMessages(editedMessages, '3');
|
||||||
|
expect(result2.prompt.trim().endsWith(trimmedLabel)).toBeTruthy();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should build messages correctly with a promptPrefix', async () => {
|
||||||
|
const promptPrefix = 'Test Prefix';
|
||||||
|
client.options.promptPrefix = promptPrefix;
|
||||||
|
const result = await client.buildMessages(messages, parentMessageId);
|
||||||
|
const { prompt } = result;
|
||||||
|
expect(prompt).toBeDefined();
|
||||||
|
expect(prompt).toContain(promptPrefix);
|
||||||
|
const textAfterPrefix = prompt.split(promptPrefix)[1];
|
||||||
|
expect(textAfterPrefix).toContain(AI_PROMPT);
|
||||||
|
|
||||||
|
const editedMessages = messages.slice(0, -1);
|
||||||
|
const result2 = await client.buildMessages(editedMessages, parentMessageId);
|
||||||
|
const textAfterPrefix2 = result2.prompt.split(promptPrefix)[1];
|
||||||
|
expect(textAfterPrefix2).toContain(AI_PROMPT);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle identityPrefix from options', async () => {
|
||||||
|
client.options.userLabel = 'John';
|
||||||
|
client.options.modelLabel = 'Claude-2';
|
||||||
|
const result = await client.buildMessages(messages, parentMessageId);
|
||||||
|
const { prompt } = result;
|
||||||
|
expect(prompt).toContain('Human\'s name: John');
|
||||||
|
expect(prompt).toContain('You are Claude-2');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -45,6 +45,18 @@ const fakeMessages = [];
|
||||||
const userMessage = 'Hello, ChatGPT!';
|
const userMessage = 'Hello, ChatGPT!';
|
||||||
const apiKey = 'fake-api-key';
|
const apiKey = 'fake-api-key';
|
||||||
|
|
||||||
|
const messageHistory = [
|
||||||
|
{ role: 'user', isCreatedByUser: true, text: 'Hello', messageId: '1' },
|
||||||
|
{ role: 'assistant', isCreatedByUser: false, text: 'Hi', messageId: '2', parentMessageId: '1' },
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
isCreatedByUser: true,
|
||||||
|
text: 'What\'s up',
|
||||||
|
messageId: '3',
|
||||||
|
parentMessageId: '2',
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
describe('BaseClient', () => {
|
describe('BaseClient', () => {
|
||||||
let TestClient;
|
let TestClient;
|
||||||
const options = {
|
const options = {
|
||||||
|
|
@ -277,9 +289,54 @@ describe('BaseClient', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
test('should return chat history', async () => {
|
test('should return chat history', async () => {
|
||||||
const chatMessages = await TestClient.loadHistory(conversationId, parentMessageId);
|
TestClient = initializeFakeClient(apiKey, options, messageHistory);
|
||||||
expect(TestClient.currentMessages).toHaveLength(4);
|
const chatMessages = await TestClient.loadHistory(conversationId, '2');
|
||||||
expect(chatMessages[0].text).toEqual(userMessage);
|
expect(TestClient.currentMessages).toHaveLength(2);
|
||||||
|
expect(chatMessages[0].text).toEqual('Hello');
|
||||||
|
|
||||||
|
const chatMessages2 = await TestClient.loadHistory(conversationId, '3');
|
||||||
|
expect(TestClient.currentMessages).toHaveLength(3);
|
||||||
|
expect(chatMessages2[chatMessages2.length - 1].text).toEqual('What\'s up');
|
||||||
|
});
|
||||||
|
|
||||||
|
/* Most of the new sendMessage logic revolving around edited/continued AI messages
|
||||||
|
* can be summarized by the following test. The condition will load the entire history up to
|
||||||
|
* the message that is being edited, which will trigger the AI API to 'continue' the response.
|
||||||
|
* The 'userMessage' is only passed by convention and is not necessary for the generation.
|
||||||
|
*/
|
||||||
|
it('should not push userMessage to currentMessages when isEdited is true and vice versa', async () => {
|
||||||
|
const overrideParentMessageId = 'user-message-id';
|
||||||
|
const responseMessageId = 'response-message-id';
|
||||||
|
const newHistory = messageHistory.slice();
|
||||||
|
newHistory.push({
|
||||||
|
role: 'assistant',
|
||||||
|
isCreatedByUser: false,
|
||||||
|
text: 'test message',
|
||||||
|
messageId: responseMessageId,
|
||||||
|
parentMessageId: '3',
|
||||||
|
});
|
||||||
|
|
||||||
|
TestClient = initializeFakeClient(apiKey, options, newHistory);
|
||||||
|
const sendMessageOptions = {
|
||||||
|
isEdited: true,
|
||||||
|
overrideParentMessageId,
|
||||||
|
parentMessageId: '3',
|
||||||
|
responseMessageId,
|
||||||
|
};
|
||||||
|
|
||||||
|
await TestClient.sendMessage('test message', sendMessageOptions);
|
||||||
|
const currentMessages = TestClient.currentMessages;
|
||||||
|
expect(currentMessages[currentMessages.length - 1].messageId).not.toEqual(
|
||||||
|
overrideParentMessageId,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Test the opposite case
|
||||||
|
sendMessageOptions.isEdited = false;
|
||||||
|
await TestClient.sendMessage('test message', sendMessageOptions);
|
||||||
|
const currentMessages2 = TestClient.currentMessages;
|
||||||
|
expect(currentMessages2[currentMessages2.length - 1].messageId).toEqual(
|
||||||
|
overrideParentMessageId,
|
||||||
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
test('setOptions is called with the correct arguments', async () => {
|
test('setOptions is called with the correct arguments', async () => {
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
const crypto = require('crypto');
|
|
||||||
const BaseClient = require('../BaseClient');
|
const BaseClient = require('../BaseClient');
|
||||||
const { maxTokensMap } = require('../../../utils');
|
const { maxTokensMap } = require('../../../utils');
|
||||||
|
|
||||||
|
|
@ -87,86 +86,6 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
|
||||||
return 'Mock response text';
|
return 'Mock response text';
|
||||||
});
|
});
|
||||||
|
|
||||||
TestClient.sendMessage = jest.fn().mockImplementation(async (message, opts = {}) => {
|
|
||||||
if (opts && typeof opts === 'object') {
|
|
||||||
TestClient.setOptions(opts);
|
|
||||||
}
|
|
||||||
|
|
||||||
const user = opts.user || null;
|
|
||||||
const conversationId = opts.conversationId || crypto.randomUUID();
|
|
||||||
const parentMessageId = opts.parentMessageId || '00000000-0000-0000-0000-000000000000';
|
|
||||||
const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
|
|
||||||
const saveOptions = TestClient.getSaveOptions();
|
|
||||||
|
|
||||||
this.pastMessages = await TestClient.loadHistory(
|
|
||||||
conversationId,
|
|
||||||
TestClient.options?.parentMessageId,
|
|
||||||
);
|
|
||||||
|
|
||||||
const userMessage = {
|
|
||||||
text: message,
|
|
||||||
sender: TestClient.sender,
|
|
||||||
isCreatedByUser: true,
|
|
||||||
messageId: userMessageId,
|
|
||||||
parentMessageId,
|
|
||||||
conversationId,
|
|
||||||
};
|
|
||||||
|
|
||||||
const response = {
|
|
||||||
sender: TestClient.sender,
|
|
||||||
text: 'Hello, User!',
|
|
||||||
isCreatedByUser: false,
|
|
||||||
messageId: crypto.randomUUID(),
|
|
||||||
parentMessageId: userMessage.messageId,
|
|
||||||
conversationId,
|
|
||||||
};
|
|
||||||
|
|
||||||
fakeMessages.push(userMessage);
|
|
||||||
fakeMessages.push(response);
|
|
||||||
|
|
||||||
if (typeof opts.getIds === 'function') {
|
|
||||||
opts.getIds({
|
|
||||||
userMessage,
|
|
||||||
conversationId,
|
|
||||||
responseMessageId: response.messageId,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (typeof opts.onStart === 'function') {
|
|
||||||
opts.onStart(userMessage);
|
|
||||||
}
|
|
||||||
|
|
||||||
let { prompt: payload, tokenCountMap } = await TestClient.buildMessages(
|
|
||||||
this.currentMessages,
|
|
||||||
userMessage.messageId,
|
|
||||||
TestClient.getBuildMessagesOptions(opts),
|
|
||||||
);
|
|
||||||
|
|
||||||
if (tokenCountMap) {
|
|
||||||
payload = payload.map((message, i) => {
|
|
||||||
const { tokenCount, ...messageWithoutTokenCount } = message;
|
|
||||||
// userMessage is always the last one in the payload
|
|
||||||
if (i === payload.length - 1) {
|
|
||||||
userMessage.tokenCount = message.tokenCount;
|
|
||||||
console.debug(
|
|
||||||
`Token count for user message: ${tokenCount}`,
|
|
||||||
`Instruction Tokens: ${tokenCountMap.instructions || 'N/A'}`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return messageWithoutTokenCount;
|
|
||||||
});
|
|
||||||
TestClient.handleTokenCountMap(tokenCountMap);
|
|
||||||
}
|
|
||||||
|
|
||||||
await TestClient.saveMessageToDatabase(userMessage, saveOptions, user);
|
|
||||||
response.text = await TestClient.sendCompletion(payload, opts);
|
|
||||||
if (tokenCountMap && TestClient.getTokenCountForResponse) {
|
|
||||||
response.tokenCount = TestClient.getTokenCountForResponse(response);
|
|
||||||
}
|
|
||||||
await TestClient.saveMessageToDatabase(response, saveOptions, user);
|
|
||||||
return response;
|
|
||||||
});
|
|
||||||
|
|
||||||
TestClient.buildMessages = jest.fn(async (messages, parentMessageId) => {
|
TestClient.buildMessages = jest.fn(async (messages, parentMessageId) => {
|
||||||
const orderedMessages = TestClient.constructor.getMessagesForConversation(
|
const orderedMessages = TestClient.constructor.getMessagesForConversation(
|
||||||
messages,
|
messages,
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
const OpenAIClient = require('../OpenAIClient');
|
const OpenAIClient = require('../OpenAIClient');
|
||||||
|
|
||||||
|
jest.mock('meilisearch');
|
||||||
|
|
||||||
describe('OpenAIClient', () => {
|
describe('OpenAIClient', () => {
|
||||||
let client, client2;
|
let client, client2;
|
||||||
const model = 'gpt-4';
|
const model = 'gpt-4';
|
||||||
|
|
@ -25,6 +27,9 @@ describe('OpenAIClient', () => {
|
||||||
content: 'Refined answer',
|
content: 'Refined answer',
|
||||||
tokenCount: 30,
|
tokenCount: 30,
|
||||||
});
|
});
|
||||||
|
client.buildPrompt = jest
|
||||||
|
.fn()
|
||||||
|
.mockResolvedValue({ prompt: messages.map((m) => m.text).join('\n') });
|
||||||
client.constructor.freeAndResetAllEncoders();
|
client.constructor.freeAndResetAllEncoders();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -111,7 +111,6 @@ describe('PluginsClient', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
const response = await TestAgent.sendMessage(userMessage);
|
const response = await TestAgent.sendMessage(userMessage);
|
||||||
console.log(response);
|
|
||||||
parentMessageId = response.messageId;
|
parentMessageId = response.messageId;
|
||||||
conversationId = response.conversationId;
|
conversationId = response.conversationId;
|
||||||
expect(response).toEqual(expectedResult);
|
expect(response).toEqual(expectedResult);
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@ async function addOpenAPISpecs(availableTools) {
|
||||||
}
|
}
|
||||||
return availableTools;
|
return availableTools;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log('addOpenAPISpecs error', error);
|
|
||||||
return availableTools;
|
return availableTools;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,6 @@ describe('Tool Handlers', () => {
|
||||||
it('returns valid tools given input tools and user authentication', async () => {
|
it('returns valid tools given input tools and user authentication', async () => {
|
||||||
const validTools = await validateTools(fakeUser._id, initialTools);
|
const validTools = await validateTools(fakeUser._id, initialTools);
|
||||||
expect(validTools).toBeDefined();
|
expect(validTools).toBeDefined();
|
||||||
console.log('validateTools: validTools', validTools);
|
|
||||||
expect(validTools.some((tool) => tool === pluginKey)).toBeTruthy();
|
expect(validTools.some((tool) => tool === pluginKey)).toBeTruthy();
|
||||||
expect(validTools.length).toBeGreaterThan(0);
|
expect(validTools.length).toBeGreaterThan(0);
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -3,15 +3,11 @@ const { askBing } = require('./bingai');
|
||||||
const clients = require('./clients');
|
const clients = require('./clients');
|
||||||
const titleConvo = require('./titleConvo');
|
const titleConvo = require('./titleConvo');
|
||||||
const titleConvoBing = require('./titleConvoBing');
|
const titleConvoBing = require('./titleConvoBing');
|
||||||
const getCitations = require('../lib/parse/getCitations');
|
|
||||||
const citeText = require('../lib/parse/citeText');
|
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
browserClient,
|
browserClient,
|
||||||
askBing,
|
askBing,
|
||||||
titleConvo,
|
titleConvo,
|
||||||
titleConvoBing,
|
titleConvoBing,
|
||||||
getCitations,
|
|
||||||
citeText,
|
|
||||||
...clients,
|
...clients,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -1,18 +0,0 @@
|
||||||
// const regex = / \[\d+\..*?\]\(.*?\)/g;
|
|
||||||
const regex = / \[.*?]\(.*?\)/g;
|
|
||||||
|
|
||||||
const getCitations = (res) => {
|
|
||||||
const adaptiveCards = res.details.adaptiveCards;
|
|
||||||
const textBlocks = adaptiveCards && adaptiveCards[0].body;
|
|
||||||
if (!textBlocks) {
|
|
||||||
return '';
|
|
||||||
}
|
|
||||||
let links = textBlocks[textBlocks.length - 1]?.text.match(regex);
|
|
||||||
if (links?.length === 0 || !links) {
|
|
||||||
return '';
|
|
||||||
}
|
|
||||||
links = links.map((link) => link.trim());
|
|
||||||
return links.join('\n - ');
|
|
||||||
};
|
|
||||||
|
|
||||||
module.exports = getCitations;
|
|
||||||
|
|
@ -14,6 +14,7 @@ module.exports = {
|
||||||
error,
|
error,
|
||||||
unfinished,
|
unfinished,
|
||||||
cancelled,
|
cancelled,
|
||||||
|
finish_reason = null,
|
||||||
tokenCount = null,
|
tokenCount = null,
|
||||||
plugin = null,
|
plugin = null,
|
||||||
model = null,
|
model = null,
|
||||||
|
|
@ -29,6 +30,7 @@ module.exports = {
|
||||||
sender,
|
sender,
|
||||||
text,
|
text,
|
||||||
isCreatedByUser,
|
isCreatedByUser,
|
||||||
|
finish_reason,
|
||||||
error,
|
error,
|
||||||
unfinished,
|
unfinished,
|
||||||
cancelled,
|
cancelled,
|
||||||
|
|
|
||||||
|
|
@ -67,6 +67,9 @@ const messageSchema = mongoose.Schema(
|
||||||
type: Boolean,
|
type: Boolean,
|
||||||
default: false,
|
default: false,
|
||||||
},
|
},
|
||||||
|
finish_reason: {
|
||||||
|
type: String,
|
||||||
|
},
|
||||||
_meiliIndex: {
|
_meiliIndex: {
|
||||||
type: Boolean,
|
type: Boolean,
|
||||||
required: false,
|
required: false,
|
||||||
|
|
|
||||||
|
|
@ -84,6 +84,7 @@ config.validate(); // Validate the config
|
||||||
app.use('/api/user', routes.user);
|
app.use('/api/user', routes.user);
|
||||||
app.use('/api/search', routes.search);
|
app.use('/api/search', routes.search);
|
||||||
app.use('/api/ask', routes.ask);
|
app.use('/api/ask', routes.ask);
|
||||||
|
app.use('/api/edit', routes.edit);
|
||||||
app.use('/api/messages', routes.messages);
|
app.use('/api/messages', routes.messages);
|
||||||
app.use('/api/convos', routes.convos);
|
app.use('/api/convos', routes.convos);
|
||||||
app.use('/api/presets', routes.presets);
|
app.use('/api/presets', routes.presets);
|
||||||
|
|
|
||||||
2
api/server/middleware/abortControllers.js
Normal file
2
api/server/middleware/abortControllers.js
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
// abortControllers.js
|
||||||
|
module.exports = new Map();
|
||||||
106
api/server/middleware/abortMiddleware.js
Normal file
106
api/server/middleware/abortMiddleware.js
Normal file
|
|
@ -0,0 +1,106 @@
|
||||||
|
const { saveMessage, getConvo, getConvoTitle } = require('../../models');
|
||||||
|
const { sendMessage, handleError } = require('../utils');
|
||||||
|
const abortControllers = require('./abortControllers');
|
||||||
|
|
||||||
|
async function abortMessage(req, res) {
|
||||||
|
const { abortKey } = req.body;
|
||||||
|
|
||||||
|
if (!abortControllers.has(abortKey) && !res.headersSent) {
|
||||||
|
return res.status(404).send('Request not found');
|
||||||
|
}
|
||||||
|
|
||||||
|
const { abortController } = abortControllers.get(abortKey);
|
||||||
|
const ret = await abortController.abortCompletion();
|
||||||
|
console.log('Aborted request', abortKey);
|
||||||
|
abortControllers.delete(abortKey);
|
||||||
|
res.send(JSON.stringify(ret));
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleAbort = () => {
|
||||||
|
return async (req, res) => {
|
||||||
|
try {
|
||||||
|
return await abortMessage(req, res);
|
||||||
|
} catch (err) {
|
||||||
|
console.error(err);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const createAbortController = (res, req, endpointOption, getAbortData) => {
|
||||||
|
const abortController = new AbortController();
|
||||||
|
const onStart = (userMessage) => {
|
||||||
|
sendMessage(res, { message: userMessage, created: true });
|
||||||
|
abortControllers.set(userMessage.conversationId, { abortController, ...endpointOption });
|
||||||
|
|
||||||
|
res.on('finish', function () {
|
||||||
|
abortControllers.delete(userMessage.conversationId);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
abortController.abortCompletion = async function () {
|
||||||
|
abortController.abort();
|
||||||
|
const { conversationId, userMessage, ...responseData } = getAbortData();
|
||||||
|
|
||||||
|
const responseMessage = {
|
||||||
|
...responseData,
|
||||||
|
finish_reason: 'incomplete',
|
||||||
|
model: endpointOption.modelOptions.model,
|
||||||
|
unfinished: false,
|
||||||
|
cancelled: true,
|
||||||
|
error: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
saveMessage(responseMessage);
|
||||||
|
|
||||||
|
return {
|
||||||
|
title: await getConvoTitle(req.user.id, conversationId),
|
||||||
|
final: true,
|
||||||
|
conversation: await getConvo(req.user.id, conversationId),
|
||||||
|
requestMessage: userMessage,
|
||||||
|
responseMessage: responseMessage,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
return { abortController, onStart };
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleAbortError = async (res, req, error, data) => {
|
||||||
|
console.error(error);
|
||||||
|
const { sender, conversationId, messageId, parentMessageId, partialText } = data;
|
||||||
|
|
||||||
|
const respondWithError = async () => {
|
||||||
|
const errorMessage = {
|
||||||
|
sender,
|
||||||
|
messageId,
|
||||||
|
conversationId,
|
||||||
|
parentMessageId,
|
||||||
|
unfinished: false,
|
||||||
|
cancelled: false,
|
||||||
|
error: true,
|
||||||
|
text: error.message,
|
||||||
|
};
|
||||||
|
if (abortControllers.has(conversationId)) {
|
||||||
|
const { abortController } = abortControllers.get(conversationId);
|
||||||
|
abortController.abort();
|
||||||
|
abortControllers.delete(conversationId);
|
||||||
|
}
|
||||||
|
await saveMessage(errorMessage);
|
||||||
|
handleError(res, errorMessage);
|
||||||
|
};
|
||||||
|
|
||||||
|
if (partialText?.length > 2) {
|
||||||
|
try {
|
||||||
|
return await abortMessage(req, res);
|
||||||
|
} catch (err) {
|
||||||
|
return respondWithError();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return respondWithError();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
handleAbort,
|
||||||
|
createAbortController,
|
||||||
|
handleAbortError,
|
||||||
|
};
|
||||||
20
api/server/middleware/buildEndpointOption.js
Normal file
20
api/server/middleware/buildEndpointOption.js
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
const openAI = require('../routes/endpoints/openAI');
|
||||||
|
const gptPlugins = require('../routes/endpoints/gptPlugins');
|
||||||
|
const anthropic = require('../routes/endpoints/anthropic');
|
||||||
|
const { parseConvo } = require('../routes/endpoints/schemas');
|
||||||
|
|
||||||
|
const buildFunction = {
|
||||||
|
openAI: openAI.buildOptions,
|
||||||
|
azureOpenAI: openAI.buildOptions,
|
||||||
|
gptPlugins: gptPlugins.buildOptions,
|
||||||
|
anthropic: anthropic.buildOptions,
|
||||||
|
};
|
||||||
|
|
||||||
|
function buildEndpointOption(req, res, next) {
|
||||||
|
const { endpoint } = req.body;
|
||||||
|
const parsedBody = parseConvo(endpoint, req.body);
|
||||||
|
req.body.endpointOption = buildFunction[endpoint](endpoint, parsedBody);
|
||||||
|
next();
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = buildEndpointOption;
|
||||||
17
api/server/middleware/index.js
Normal file
17
api/server/middleware/index.js
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
const abortMiddleware = require('./abortMiddleware');
|
||||||
|
const setHeaders = require('./setHeaders');
|
||||||
|
const requireJwtAuth = require('./requireJwtAuth');
|
||||||
|
const requireLocalAuth = require('./requireLocalAuth');
|
||||||
|
const validateEndpoint = require('./validateEndpoint');
|
||||||
|
const buildEndpointOption = require('./buildEndpointOption');
|
||||||
|
const validateRegistration = require('./validateRegistration');
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
...abortMiddleware,
|
||||||
|
setHeaders,
|
||||||
|
requireJwtAuth,
|
||||||
|
requireLocalAuth,
|
||||||
|
validateEndpoint,
|
||||||
|
buildEndpointOption,
|
||||||
|
validateRegistration,
|
||||||
|
};
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
const passport = require('passport');
|
const passport = require('passport');
|
||||||
const DebugControl = require('../utils/debug.js');
|
const DebugControl = require('../../utils/debug.js');
|
||||||
|
|
||||||
function log({ title, parameters }) {
|
function log({ title, parameters }) {
|
||||||
DebugControl.log.functionName(title);
|
DebugControl.log.functionName(title);
|
||||||
12
api/server/middleware/setHeaders.js
Normal file
12
api/server/middleware/setHeaders.js
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
function setHeaders(req, res, next) {
|
||||||
|
res.writeHead(200, {
|
||||||
|
Connection: 'keep-alive',
|
||||||
|
'Content-Type': 'text/event-stream',
|
||||||
|
'Cache-Control': 'no-cache, no-transform',
|
||||||
|
'Access-Control-Allow-Origin': '*',
|
||||||
|
'X-Accel-Buffering': 'no',
|
||||||
|
});
|
||||||
|
next();
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = setHeaders;
|
||||||
19
api/server/middleware/validateEndpoint.js
Normal file
19
api/server/middleware/validateEndpoint.js
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
const { handleError } = require('../utils');
|
||||||
|
|
||||||
|
function validateEndpoint(req, res, next) {
|
||||||
|
const { endpoint } = req.body;
|
||||||
|
|
||||||
|
if (!req.body.text || req.body.text.length === 0) {
|
||||||
|
return handleError(res, { text: 'Prompt empty or too short' });
|
||||||
|
}
|
||||||
|
|
||||||
|
const pathEndpoint = req.baseUrl.split('/')[3];
|
||||||
|
|
||||||
|
if (endpoint !== pathEndpoint) {
|
||||||
|
return handleError(res, { text: 'Illegal request: Endpoint mismatch' });
|
||||||
|
}
|
||||||
|
|
||||||
|
next();
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = validateEndpoint;
|
||||||
|
|
@ -1,72 +1,43 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
const crypto = require('crypto');
|
const { getResponseSender } = require('../endpoints/schemas');
|
||||||
const { titleConvo, AnthropicClient } = require('../../../app');
|
const { initializeClient } = require('../endpoints/anthropic');
|
||||||
const requireJwtAuth = require('../../../middleware/requireJwtAuth');
|
const {
|
||||||
const { abortMessage } = require('../../../utils');
|
handleAbort,
|
||||||
|
createAbortController,
|
||||||
|
handleAbortError,
|
||||||
|
setHeaders,
|
||||||
|
requireJwtAuth,
|
||||||
|
validateEndpoint,
|
||||||
|
buildEndpointOption,
|
||||||
|
} = require('../../middleware');
|
||||||
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
|
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
|
||||||
const { handleError, sendMessage, createOnProgress } = require('./handlers');
|
const { sendMessage, createOnProgress } = require('../../utils');
|
||||||
|
|
||||||
const abortControllers = new Map();
|
router.post('/abort', requireJwtAuth, handleAbort());
|
||||||
|
|
||||||
router.post('/abort', requireJwtAuth, async (req, res) => {
|
router.post(
|
||||||
try {
|
'/',
|
||||||
return await abortMessage(req, res, abortControllers);
|
requireJwtAuth,
|
||||||
} catch (err) {
|
validateEndpoint,
|
||||||
console.error(err);
|
buildEndpointOption,
|
||||||
}
|
setHeaders,
|
||||||
});
|
async (req, res) => {
|
||||||
|
let {
|
||||||
|
text,
|
||||||
|
endpointOption,
|
||||||
|
conversationId,
|
||||||
|
parentMessageId = null,
|
||||||
|
overrideParentMessageId = null,
|
||||||
|
} = req.body;
|
||||||
|
console.log('ask log');
|
||||||
|
console.dir({ text, conversationId, endpointOption }, { depth: null });
|
||||||
|
let userMessage;
|
||||||
|
let userMessageId;
|
||||||
|
let responseMessageId;
|
||||||
|
let lastSavedTimestamp = 0;
|
||||||
|
let saveDelay = 100;
|
||||||
|
|
||||||
router.post('/', requireJwtAuth, async (req, res) => {
|
|
||||||
const { endpoint, text, parentMessageId, conversationId: oldConversationId } = req.body;
|
|
||||||
if (text.length === 0) {
|
|
||||||
return handleError(res, { text: 'Prompt empty or too short' });
|
|
||||||
}
|
|
||||||
if (endpoint !== 'anthropic') {
|
|
||||||
return handleError(res, { text: 'Illegal request' });
|
|
||||||
}
|
|
||||||
|
|
||||||
const endpointOption = {
|
|
||||||
promptPrefix: req.body?.promptPrefix ?? null,
|
|
||||||
modelLabel: req.body?.modelLabel ?? null,
|
|
||||||
token: req.body?.token ?? null,
|
|
||||||
modelOptions: {
|
|
||||||
model: req.body?.model ?? 'claude-1',
|
|
||||||
temperature: req.body?.temperature ?? 1,
|
|
||||||
maxOutputTokens: req.body?.maxOutputTokens ?? 1024,
|
|
||||||
topP: req.body?.topP ?? 0.7,
|
|
||||||
topK: req.body?.topK ?? 5,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
const conversationId = oldConversationId || crypto.randomUUID();
|
|
||||||
|
|
||||||
return await ask({
|
|
||||||
text,
|
|
||||||
endpointOption,
|
|
||||||
conversationId,
|
|
||||||
parentMessageId,
|
|
||||||
req,
|
|
||||||
res,
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
const ask = async ({ text, endpointOption, parentMessageId = null, conversationId, req, res }) => {
|
|
||||||
res.writeHead(200, {
|
|
||||||
Connection: 'keep-alive',
|
|
||||||
'Content-Type': 'text/event-stream',
|
|
||||||
'Cache-Control': 'no-cache, no-transform',
|
|
||||||
'Access-Control-Allow-Origin': '*',
|
|
||||||
'X-Accel-Buffering': 'no',
|
|
||||||
});
|
|
||||||
|
|
||||||
let userMessage;
|
|
||||||
let userMessageId;
|
|
||||||
let responseMessageId;
|
|
||||||
let lastSavedTimestamp = 0;
|
|
||||||
const { overrideParentMessageId = null } = req.body;
|
|
||||||
|
|
||||||
try {
|
|
||||||
const getIds = (data) => {
|
const getIds = (data) => {
|
||||||
userMessage = data.userMessage;
|
userMessage = data.userMessage;
|
||||||
userMessageId = data.userMessage.messageId;
|
userMessageId = data.userMessage.messageId;
|
||||||
|
|
@ -79,116 +50,95 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
|
||||||
const { onProgress: progressCallback, getPartialText } = createOnProgress({
|
const { onProgress: progressCallback, getPartialText } = createOnProgress({
|
||||||
onProgress: ({ text: partialText }) => {
|
onProgress: ({ text: partialText }) => {
|
||||||
const currentTimestamp = Date.now();
|
const currentTimestamp = Date.now();
|
||||||
if (currentTimestamp - lastSavedTimestamp > 500) {
|
|
||||||
|
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
|
||||||
lastSavedTimestamp = currentTimestamp;
|
lastSavedTimestamp = currentTimestamp;
|
||||||
saveMessage({
|
saveMessage({
|
||||||
messageId: responseMessageId,
|
messageId: responseMessageId,
|
||||||
sender: 'Anthropic',
|
sender: getResponseSender(endpointOption),
|
||||||
conversationId,
|
conversationId,
|
||||||
parentMessageId: overrideParentMessageId || userMessageId,
|
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||||
text: partialText,
|
text: partialText,
|
||||||
unfinished: true,
|
unfinished: true,
|
||||||
cancelled: false,
|
cancelled: false,
|
||||||
error: false,
|
error: false,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (saveDelay < 500) {
|
||||||
|
saveDelay = 500;
|
||||||
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
try {
|
||||||
const abortController = new AbortController();
|
const getAbortData = () => ({
|
||||||
abortController.abortAsk = async function () {
|
|
||||||
this.abort();
|
|
||||||
|
|
||||||
const responseMessage = {
|
|
||||||
messageId: responseMessageId,
|
|
||||||
sender: 'Anthropic',
|
|
||||||
conversationId,
|
conversationId,
|
||||||
parentMessageId: overrideParentMessageId || userMessageId,
|
messageId: responseMessageId,
|
||||||
|
sender: getResponseSender(endpointOption),
|
||||||
|
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||||
text: getPartialText(),
|
text: getPartialText(),
|
||||||
model: endpointOption.modelOptions.model,
|
userMessage,
|
||||||
unfinished: false,
|
});
|
||||||
cancelled: true,
|
|
||||||
error: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
saveMessage(responseMessage);
|
const { abortController, onStart } = createAbortController(
|
||||||
|
res,
|
||||||
|
req,
|
||||||
|
endpointOption,
|
||||||
|
getAbortData,
|
||||||
|
);
|
||||||
|
|
||||||
return {
|
const { client } = initializeClient(req, endpointOption);
|
||||||
|
|
||||||
|
let response = await client.sendMessage(text, {
|
||||||
|
getIds,
|
||||||
|
debug: false,
|
||||||
|
user: req.user.id,
|
||||||
|
conversationId,
|
||||||
|
parentMessageId,
|
||||||
|
overrideParentMessageId,
|
||||||
|
...endpointOption,
|
||||||
|
onProgress: progressCallback.call(null, {
|
||||||
|
res,
|
||||||
|
text,
|
||||||
|
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||||
|
}),
|
||||||
|
onStart,
|
||||||
|
abortController,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (overrideParentMessageId) {
|
||||||
|
response.parentMessageId = overrideParentMessageId;
|
||||||
|
}
|
||||||
|
|
||||||
|
await saveConvo(req.user.id, {
|
||||||
|
...endpointOption,
|
||||||
|
...endpointOption.modelOptions,
|
||||||
|
conversationId,
|
||||||
|
endpoint: 'anthropic',
|
||||||
|
});
|
||||||
|
|
||||||
|
await saveMessage(response);
|
||||||
|
sendMessage(res, {
|
||||||
title: await getConvoTitle(req.user.id, conversationId),
|
title: await getConvoTitle(req.user.id, conversationId),
|
||||||
final: true,
|
final: true,
|
||||||
conversation: await getConvo(req.user.id, conversationId),
|
conversation: await getConvo(req.user.id, conversationId),
|
||||||
requestMessage: userMessage,
|
requestMessage: userMessage,
|
||||||
responseMessage: responseMessage,
|
responseMessage: response,
|
||||||
};
|
});
|
||||||
};
|
res.end();
|
||||||
|
|
||||||
const onStart = (userMessage) => {
|
// TODO: add anthropic titling
|
||||||
sendMessage(res, { message: userMessage, created: true });
|
} catch (error) {
|
||||||
abortControllers.set(userMessage.conversationId, { abortController, ...endpointOption });
|
const partialText = getPartialText();
|
||||||
};
|
handleAbortError(res, req, error, {
|
||||||
|
partialText,
|
||||||
const client = new AnthropicClient(endpointOption.token);
|
|
||||||
|
|
||||||
let response = await client.sendMessage(text, {
|
|
||||||
getIds,
|
|
||||||
debug: false,
|
|
||||||
user: req.user.id,
|
|
||||||
conversationId,
|
|
||||||
parentMessageId,
|
|
||||||
overrideParentMessageId,
|
|
||||||
...endpointOption,
|
|
||||||
onProgress: progressCallback.call(null, {
|
|
||||||
res,
|
|
||||||
text,
|
|
||||||
parentMessageId: overrideParentMessageId || userMessageId,
|
|
||||||
}),
|
|
||||||
onStart,
|
|
||||||
abortController,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (overrideParentMessageId) {
|
|
||||||
response.parentMessageId = overrideParentMessageId;
|
|
||||||
}
|
|
||||||
|
|
||||||
await saveConvo(req.user.id, {
|
|
||||||
...endpointOption,
|
|
||||||
...endpointOption.modelOptions,
|
|
||||||
conversationId,
|
|
||||||
endpoint: 'anthropic',
|
|
||||||
});
|
|
||||||
|
|
||||||
await saveMessage(response);
|
|
||||||
sendMessage(res, {
|
|
||||||
title: await getConvoTitle(req.user.id, conversationId),
|
|
||||||
final: true,
|
|
||||||
conversation: await getConvo(req.user.id, conversationId),
|
|
||||||
requestMessage: userMessage,
|
|
||||||
responseMessage: response,
|
|
||||||
});
|
|
||||||
res.end();
|
|
||||||
|
|
||||||
if (parentMessageId == '00000000-0000-0000-0000-000000000000') {
|
|
||||||
const title = await titleConvo({ text, response });
|
|
||||||
await saveConvo(req.user.id, {
|
|
||||||
conversationId,
|
conversationId,
|
||||||
title,
|
sender: getResponseSender(endpointOption),
|
||||||
|
messageId: responseMessageId,
|
||||||
|
parentMessageId: userMessageId,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} catch (error) {
|
},
|
||||||
console.error(error);
|
);
|
||||||
const errorMessage = {
|
|
||||||
messageId: responseMessageId,
|
|
||||||
sender: 'Anthropic',
|
|
||||||
conversationId,
|
|
||||||
parentMessageId,
|
|
||||||
unfinished: false,
|
|
||||||
cancelled: false,
|
|
||||||
error: true,
|
|
||||||
text: error.message,
|
|
||||||
};
|
|
||||||
await saveMessage(errorMessage);
|
|
||||||
handleError(res, errorMessage);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
module.exports = router;
|
module.exports = router;
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const crypto = require('crypto');
|
const crypto = require('crypto');
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
// const { getChatGPTBrowserModels } = require('../endpoints');
|
|
||||||
const { browserClient } = require('../../../app/');
|
const { browserClient } = require('../../../app/');
|
||||||
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
|
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
|
||||||
const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers');
|
const { handleError, sendMessage, createOnProgress, handleText } = require('../../utils');
|
||||||
const requireJwtAuth = require('../../../middleware/requireJwtAuth');
|
const { requireJwtAuth, setHeaders } = require('../../middleware');
|
||||||
|
|
||||||
router.post('/', requireJwtAuth, async (req, res) => {
|
router.post('/', requireJwtAuth, setHeaders, async (req, res) => {
|
||||||
const {
|
const {
|
||||||
endpoint,
|
endpoint,
|
||||||
text,
|
text,
|
||||||
|
|
@ -86,15 +85,6 @@ const ask = async ({
|
||||||
}) => {
|
}) => {
|
||||||
let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage;
|
let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage;
|
||||||
const userId = req.user.id;
|
const userId = req.user.id;
|
||||||
|
|
||||||
res.writeHead(200, {
|
|
||||||
Connection: 'keep-alive',
|
|
||||||
'Content-Type': 'text/event-stream',
|
|
||||||
'Cache-Control': 'no-cache, no-transform',
|
|
||||||
'Access-Control-Allow-Origin': '*',
|
|
||||||
'X-Accel-Buffering': 'no',
|
|
||||||
});
|
|
||||||
|
|
||||||
let responseMessageId = crypto.randomUUID();
|
let responseMessageId = crypto.randomUUID();
|
||||||
let getPartialMessage = null;
|
let getPartialMessage = null;
|
||||||
try {
|
try {
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,10 @@ const crypto = require('crypto');
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
const { titleConvoBing, askBing } = require('../../../app');
|
const { titleConvoBing, askBing } = require('../../../app');
|
||||||
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
|
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
|
||||||
const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers');
|
const { handleError, sendMessage, createOnProgress, handleText } = require('../../utils');
|
||||||
const requireJwtAuth = require('../../../middleware/requireJwtAuth');
|
const { requireJwtAuth, setHeaders } = require('../../middleware');
|
||||||
|
|
||||||
router.post('/', requireJwtAuth, async (req, res) => {
|
router.post('/', requireJwtAuth, setHeaders, async (req, res) => {
|
||||||
const {
|
const {
|
||||||
endpoint,
|
endpoint,
|
||||||
text,
|
text,
|
||||||
|
|
@ -103,14 +103,6 @@ const ask = async ({
|
||||||
|
|
||||||
let responseMessageId = crypto.randomUUID();
|
let responseMessageId = crypto.randomUUID();
|
||||||
|
|
||||||
res.writeHead(200, {
|
|
||||||
Connection: 'keep-alive',
|
|
||||||
'Content-Type': 'text/event-stream',
|
|
||||||
'Cache-Control': 'no-cache, no-transform',
|
|
||||||
'Access-Control-Allow-Origin': '*',
|
|
||||||
'X-Accel-Buffering': 'no',
|
|
||||||
});
|
|
||||||
|
|
||||||
if (preSendRequest) {
|
if (preSendRequest) {
|
||||||
sendMessage(res, { message: userMessage, created: true });
|
sendMessage(res, { message: userMessage, created: true });
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,11 @@ const express = require('express');
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
const crypto = require('crypto');
|
const crypto = require('crypto');
|
||||||
const { titleConvo, GoogleClient } = require('../../../app');
|
const { titleConvo, GoogleClient } = require('../../../app');
|
||||||
// const GoogleClient = require('../../../app/google/GoogleClient');
|
|
||||||
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
|
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
|
||||||
const { handleError, sendMessage, createOnProgress } = require('./handlers');
|
const { handleError, sendMessage, createOnProgress } = require('../../utils');
|
||||||
const requireJwtAuth = require('../../../middleware/requireJwtAuth');
|
const { requireJwtAuth, setHeaders } = require('../../middleware');
|
||||||
|
|
||||||
router.post('/', requireJwtAuth, async (req, res) => {
|
router.post('/', requireJwtAuth, setHeaders, async (req, res) => {
|
||||||
const { endpoint, text, parentMessageId, conversationId: oldConversationId } = req.body;
|
const { endpoint, text, parentMessageId, conversationId: oldConversationId } = req.body;
|
||||||
if (text.length === 0) {
|
if (text.length === 0) {
|
||||||
return handleError(res, { text: 'Prompt empty or too short' });
|
return handleError(res, { text: 'Prompt empty or too short' });
|
||||||
|
|
@ -50,13 +49,6 @@ router.post('/', requireJwtAuth, async (req, res) => {
|
||||||
});
|
});
|
||||||
|
|
||||||
const ask = async ({ text, endpointOption, parentMessageId = null, conversationId, req, res }) => {
|
const ask = async ({ text, endpointOption, parentMessageId = null, conversationId, req, res }) => {
|
||||||
res.writeHead(200, {
|
|
||||||
Connection: 'keep-alive',
|
|
||||||
'Content-Type': 'text/event-stream',
|
|
||||||
'Cache-Control': 'no-cache, no-transform',
|
|
||||||
'Access-Control-Allow-Origin': '*',
|
|
||||||
'X-Accel-Buffering': 'no',
|
|
||||||
});
|
|
||||||
let userMessage;
|
let userMessage;
|
||||||
let userMessageId;
|
let userMessageId;
|
||||||
let responseMessageId;
|
let responseMessageId;
|
||||||
|
|
|
||||||
|
|
@ -1,112 +1,56 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
const { titleConvo, validateTools, PluginsClient } = require('../../../app');
|
const { getResponseSender } = require('../endpoints/schemas');
|
||||||
const { abortMessage, getAzureCredentials } = require('../../../utils');
|
const { validateTools } = require('../../../app');
|
||||||
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
|
const { addTitle } = require('../endpoints/openAI');
|
||||||
|
const { initializeClient } = require('../endpoints/gptPlugins');
|
||||||
|
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
|
||||||
|
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('../../utils');
|
||||||
const {
|
const {
|
||||||
handleError,
|
handleAbort,
|
||||||
sendMessage,
|
createAbortController,
|
||||||
createOnProgress,
|
handleAbortError,
|
||||||
formatSteps,
|
setHeaders,
|
||||||
formatAction,
|
requireJwtAuth,
|
||||||
} = require('./handlers');
|
validateEndpoint,
|
||||||
const requireJwtAuth = require('../../../middleware/requireJwtAuth');
|
buildEndpointOption,
|
||||||
|
} = require('../../middleware');
|
||||||
|
|
||||||
const abortControllers = new Map();
|
router.post('/abort', requireJwtAuth, handleAbort());
|
||||||
|
|
||||||
router.post('/abort', requireJwtAuth, async (req, res) => {
|
router.post(
|
||||||
try {
|
'/',
|
||||||
return await abortMessage(req, res, abortControllers);
|
requireJwtAuth,
|
||||||
} catch (err) {
|
validateEndpoint,
|
||||||
console.error(err);
|
buildEndpointOption,
|
||||||
}
|
setHeaders,
|
||||||
});
|
async (req, res) => {
|
||||||
|
let {
|
||||||
|
text,
|
||||||
|
endpointOption,
|
||||||
|
conversationId,
|
||||||
|
parentMessageId = null,
|
||||||
|
overrideParentMessageId = null,
|
||||||
|
} = req.body;
|
||||||
|
console.log('ask log');
|
||||||
|
console.dir({ text, conversationId, endpointOption }, { depth: null });
|
||||||
|
let metadata;
|
||||||
|
let userMessage;
|
||||||
|
let userMessageId;
|
||||||
|
let responseMessageId;
|
||||||
|
let lastSavedTimestamp = 0;
|
||||||
|
let saveDelay = 100;
|
||||||
|
const newConvo = !conversationId;
|
||||||
|
const user = req.user.id;
|
||||||
|
|
||||||
router.post('/', requireJwtAuth, async (req, res) => {
|
const plugin = {
|
||||||
const { endpoint, text, parentMessageId, conversationId } = req.body;
|
loading: true,
|
||||||
if (text.length === 0) {
|
inputs: [],
|
||||||
return handleError(res, { text: 'Prompt empty or too short' });
|
latest: null,
|
||||||
}
|
outputs: null,
|
||||||
if (endpoint !== 'gptPlugins') {
|
};
|
||||||
return handleError(res, { text: 'Illegal request' });
|
|
||||||
}
|
|
||||||
|
|
||||||
const agentOptions = req.body?.agentOptions ?? {
|
const addMetadata = (data) => (metadata = data);
|
||||||
agent: 'functions',
|
|
||||||
skipCompletion: true,
|
|
||||||
model: 'gpt-3.5-turbo',
|
|
||||||
temperature: 0,
|
|
||||||
// top_p: 1,
|
|
||||||
// presence_penalty: 0,
|
|
||||||
// frequency_penalty: 0
|
|
||||||
};
|
|
||||||
|
|
||||||
const tools = req.body?.tools.map((tool) => tool.pluginKey) ?? [];
|
|
||||||
// build endpoint option
|
|
||||||
const endpointOption = {
|
|
||||||
chatGptLabel: tools.length === 0 ? req.body?.chatGptLabel ?? null : null,
|
|
||||||
promptPrefix: tools.length === 0 ? req.body?.promptPrefix ?? null : null,
|
|
||||||
tools,
|
|
||||||
modelOptions: {
|
|
||||||
model: req.body?.model ?? 'gpt-4',
|
|
||||||
temperature: req.body?.temperature ?? 0,
|
|
||||||
top_p: req.body?.top_p ?? 1,
|
|
||||||
presence_penalty: req.body?.presence_penalty ?? 0,
|
|
||||||
frequency_penalty: req.body?.frequency_penalty ?? 0,
|
|
||||||
},
|
|
||||||
agentOptions: {
|
|
||||||
...agentOptions,
|
|
||||||
// agent: 'functions'
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
console.log('ask log');
|
|
||||||
console.dir({ text, conversationId, endpointOption }, { depth: null });
|
|
||||||
|
|
||||||
// eslint-disable-next-line no-use-before-define
|
|
||||||
return await ask({
|
|
||||||
text,
|
|
||||||
endpoint,
|
|
||||||
endpointOption,
|
|
||||||
conversationId,
|
|
||||||
parentMessageId,
|
|
||||||
req,
|
|
||||||
res,
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
const ask = async ({
|
|
||||||
text,
|
|
||||||
endpoint,
|
|
||||||
endpointOption,
|
|
||||||
parentMessageId = null,
|
|
||||||
conversationId,
|
|
||||||
req,
|
|
||||||
res,
|
|
||||||
}) => {
|
|
||||||
res.writeHead(200, {
|
|
||||||
Connection: 'keep-alive',
|
|
||||||
'Content-Type': 'text/event-stream',
|
|
||||||
'Cache-Control': 'no-cache, no-transform',
|
|
||||||
'Access-Control-Allow-Origin': '*',
|
|
||||||
'X-Accel-Buffering': 'no',
|
|
||||||
});
|
|
||||||
let userMessage;
|
|
||||||
let userMessageId;
|
|
||||||
let responseMessageId;
|
|
||||||
let lastSavedTimestamp = 0;
|
|
||||||
const newConvo = !conversationId;
|
|
||||||
const { overrideParentMessageId = null } = req.body;
|
|
||||||
const user = req.user.id;
|
|
||||||
|
|
||||||
const plugin = {
|
|
||||||
loading: true,
|
|
||||||
inputs: [],
|
|
||||||
latest: null,
|
|
||||||
outputs: null,
|
|
||||||
};
|
|
||||||
|
|
||||||
try {
|
|
||||||
const getIds = (data) => {
|
const getIds = (data) => {
|
||||||
userMessage = data.userMessage;
|
userMessage = data.userMessage;
|
||||||
userMessageId = userMessage.messageId;
|
userMessageId = userMessage.messageId;
|
||||||
|
|
@ -128,11 +72,11 @@ const ask = async ({
|
||||||
plugin.loading = false;
|
plugin.loading = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (currentTimestamp - lastSavedTimestamp > 500) {
|
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
|
||||||
lastSavedTimestamp = currentTimestamp;
|
lastSavedTimestamp = currentTimestamp;
|
||||||
saveMessage({
|
saveMessage({
|
||||||
messageId: responseMessageId,
|
messageId: responseMessageId,
|
||||||
sender: 'ChatGPT',
|
sender: getResponseSender(endpointOption),
|
||||||
conversationId,
|
conversationId,
|
||||||
parentMessageId: overrideParentMessageId || userMessageId,
|
parentMessageId: overrideParentMessageId || userMessageId,
|
||||||
text: partialText,
|
text: partialText,
|
||||||
|
|
@ -142,63 +86,13 @@ const ask = async ({
|
||||||
error: false,
|
error: false,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (saveDelay < 500) {
|
||||||
|
saveDelay = 500;
|
||||||
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const abortController = new AbortController();
|
|
||||||
abortController.abortAsk = async function () {
|
|
||||||
this.abort();
|
|
||||||
|
|
||||||
const responseMessage = {
|
|
||||||
messageId: responseMessageId,
|
|
||||||
sender: endpointOption?.chatGptLabel || 'ChatGPT',
|
|
||||||
conversationId,
|
|
||||||
parentMessageId: overrideParentMessageId || userMessageId,
|
|
||||||
text: getPartialText(),
|
|
||||||
plugin: { ...plugin, loading: false },
|
|
||||||
model: endpointOption.modelOptions.model,
|
|
||||||
unfinished: false,
|
|
||||||
cancelled: true,
|
|
||||||
error: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
saveMessage(responseMessage);
|
|
||||||
|
|
||||||
return {
|
|
||||||
title: await getConvoTitle(req.user.id, conversationId),
|
|
||||||
final: true,
|
|
||||||
conversation: await getConvo(req.user.id, conversationId),
|
|
||||||
requestMessage: userMessage,
|
|
||||||
responseMessage: responseMessage,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
const onStart = (userMessage) => {
|
|
||||||
sendMessage(res, { message: userMessage, created: true });
|
|
||||||
abortControllers.set(userMessage.conversationId, { abortController, ...endpointOption });
|
|
||||||
};
|
|
||||||
|
|
||||||
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
|
||||||
const clientOptions = {
|
|
||||||
debug: true,
|
|
||||||
endpoint,
|
|
||||||
reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null,
|
|
||||||
proxy: process.env.PROXY || null,
|
|
||||||
...endpointOption,
|
|
||||||
};
|
|
||||||
|
|
||||||
let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY;
|
|
||||||
if (process.env.PLUGINS_USE_AZURE) {
|
|
||||||
clientOptions.azure = getAzureCredentials();
|
|
||||||
openAIApiKey = clientOptions.azure.azureOpenAIApiKey;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (openAIApiKey && openAIApiKey.includes('azure') && !clientOptions.azure) {
|
|
||||||
clientOptions.azure = JSON.parse(req.body?.token) ?? getAzureCredentials();
|
|
||||||
openAIApiKey = clientOptions.azure.azureOpenAIApiKey;
|
|
||||||
}
|
|
||||||
const chatAgent = new PluginsClient(openAIApiKey, clientOptions);
|
|
||||||
|
|
||||||
const onAgentAction = (action, start = false) => {
|
const onAgentAction = (action, start = false) => {
|
||||||
const formattedAction = formatAction(action);
|
const formattedAction = formatAction(action);
|
||||||
plugin.inputs.push(formattedAction);
|
plugin.inputs.push(formattedAction);
|
||||||
|
|
@ -219,70 +113,86 @@ const ask = async ({
|
||||||
// console.log('CHAIN END', plugin.outputs);
|
// console.log('CHAIN END', plugin.outputs);
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = await chatAgent.sendMessage(text, {
|
const getAbortData = () => ({
|
||||||
getIds,
|
sender: getResponseSender(endpointOption),
|
||||||
user,
|
|
||||||
parentMessageId,
|
|
||||||
conversationId,
|
conversationId,
|
||||||
overrideParentMessageId,
|
messageId: responseMessageId,
|
||||||
onAgentAction,
|
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||||
onChainEnd,
|
text: getPartialText(),
|
||||||
onStart,
|
plugin: { ...plugin, loading: false },
|
||||||
...endpointOption,
|
userMessage,
|
||||||
onProgress: progressCallback.call(null, {
|
|
||||||
res,
|
|
||||||
text,
|
|
||||||
plugin,
|
|
||||||
parentMessageId: overrideParentMessageId || userMessageId,
|
|
||||||
}),
|
|
||||||
abortController,
|
|
||||||
});
|
});
|
||||||
|
const { abortController, onStart } = createAbortController(
|
||||||
|
res,
|
||||||
|
req,
|
||||||
|
endpointOption,
|
||||||
|
getAbortData,
|
||||||
|
);
|
||||||
|
|
||||||
if (overrideParentMessageId) {
|
try {
|
||||||
response.parentMessageId = overrideParentMessageId;
|
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
||||||
}
|
const { client, azure, openAIApiKey } = initializeClient(req, endpointOption);
|
||||||
|
|
||||||
console.log('CLIENT RESPONSE');
|
let response = await client.sendMessage(text, {
|
||||||
console.dir(response, { depth: null });
|
user,
|
||||||
response.plugin = { ...plugin, loading: false };
|
conversationId,
|
||||||
await saveMessage(response);
|
parentMessageId,
|
||||||
|
overrideParentMessageId,
|
||||||
|
getIds,
|
||||||
|
onAgentAction,
|
||||||
|
onChainEnd,
|
||||||
|
onStart,
|
||||||
|
addMetadata,
|
||||||
|
...endpointOption,
|
||||||
|
onProgress: progressCallback.call(null, {
|
||||||
|
res,
|
||||||
|
text,
|
||||||
|
plugin,
|
||||||
|
parentMessageId: overrideParentMessageId || userMessageId,
|
||||||
|
}),
|
||||||
|
abortController,
|
||||||
|
});
|
||||||
|
|
||||||
sendMessage(res, {
|
if (overrideParentMessageId) {
|
||||||
title: await getConvoTitle(req.user.id, conversationId),
|
response.parentMessageId = overrideParentMessageId;
|
||||||
final: true,
|
}
|
||||||
conversation: await getConvo(req.user.id, conversationId),
|
|
||||||
requestMessage: userMessage,
|
|
||||||
responseMessage: response,
|
|
||||||
});
|
|
||||||
res.end();
|
|
||||||
|
|
||||||
if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) {
|
if (metadata) {
|
||||||
const title = await titleConvo({
|
response = { ...response, ...metadata };
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log('CLIENT RESPONSE');
|
||||||
|
console.dir(response, { depth: null });
|
||||||
|
response.plugin = { ...plugin, loading: false };
|
||||||
|
await saveMessage(response);
|
||||||
|
|
||||||
|
sendMessage(res, {
|
||||||
|
title: await getConvoTitle(req.user.id, conversationId),
|
||||||
|
final: true,
|
||||||
|
conversation: await getConvo(req.user.id, conversationId),
|
||||||
|
requestMessage: userMessage,
|
||||||
|
responseMessage: response,
|
||||||
|
});
|
||||||
|
res.end();
|
||||||
|
addTitle(req, {
|
||||||
text,
|
text,
|
||||||
|
newConvo,
|
||||||
response,
|
response,
|
||||||
openAIApiKey,
|
openAIApiKey,
|
||||||
azure: !!clientOptions.azure,
|
parentMessageId,
|
||||||
|
azure: !!azure,
|
||||||
});
|
});
|
||||||
await saveConvo(req.user.id, {
|
} catch (error) {
|
||||||
conversationId: conversationId,
|
const partialText = getPartialText();
|
||||||
title,
|
handleAbortError(res, req, error, {
|
||||||
|
partialText,
|
||||||
|
conversationId,
|
||||||
|
sender: getResponseSender(endpointOption),
|
||||||
|
messageId: responseMessageId,
|
||||||
|
parentMessageId: userMessageId,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} catch (error) {
|
},
|
||||||
console.error(error);
|
);
|
||||||
const errorMessage = {
|
|
||||||
messageId: responseMessageId,
|
|
||||||
sender: 'ChatGPT',
|
|
||||||
conversationId,
|
|
||||||
parentMessageId: userMessageId,
|
|
||||||
unfinished: false,
|
|
||||||
cancelled: false,
|
|
||||||
error: true,
|
|
||||||
text: error.message,
|
|
||||||
};
|
|
||||||
await saveMessage(errorMessage);
|
|
||||||
handleError(res, errorMessage);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
module.exports = router;
|
module.exports = router;
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,5 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
// const askAzureOpenAI = require('./askAzureOpenAI';)
|
|
||||||
// const askOpenAI = require('./askOpenAI');
|
|
||||||
const openAI = require('./openAI');
|
const openAI = require('./openAI');
|
||||||
const google = require('./google');
|
const google = require('./google');
|
||||||
const bingAI = require('./bingAI');
|
const bingAI = require('./bingAI');
|
||||||
|
|
@ -9,7 +7,6 @@ const gptPlugins = require('./gptPlugins');
|
||||||
const askChatGPTBrowser = require('./askChatGPTBrowser');
|
const askChatGPTBrowser = require('./askChatGPTBrowser');
|
||||||
const anthropic = require('./anthropic');
|
const anthropic = require('./anthropic');
|
||||||
|
|
||||||
// router.use('/azureOpenAI', askAzureOpenAI);
|
|
||||||
router.use(['/azureOpenAI', '/openAI'], openAI);
|
router.use(['/azureOpenAI', '/openAI'], openAI);
|
||||||
router.use('/google', google);
|
router.use('/google', google);
|
||||||
router.use('/bingAI', bingAI);
|
router.use('/bingAI', bingAI);
|
||||||
|
|
|
||||||
|
|
@ -1,231 +1,160 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
const { titleConvo, OpenAIClient } = require('../../../app');
|
const { getResponseSender } = require('../endpoints/schemas');
|
||||||
const { getAzureCredentials, abortMessage } = require('../../../utils');
|
const { sendMessage, createOnProgress } = require('../../utils');
|
||||||
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
|
const { addTitle, initializeClient } = require('../endpoints/openAI');
|
||||||
const { handleError, sendMessage, createOnProgress } = require('./handlers');
|
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
|
||||||
const requireJwtAuth = require('../../../middleware/requireJwtAuth');
|
const {
|
||||||
|
handleAbort,
|
||||||
|
createAbortController,
|
||||||
|
handleAbortError,
|
||||||
|
setHeaders,
|
||||||
|
requireJwtAuth,
|
||||||
|
validateEndpoint,
|
||||||
|
buildEndpointOption,
|
||||||
|
} = require('../../middleware');
|
||||||
|
|
||||||
const abortControllers = new Map();
|
router.post('/abort', requireJwtAuth, handleAbort());
|
||||||
|
|
||||||
router.post('/abort', requireJwtAuth, async (req, res) => {
|
router.post(
|
||||||
try {
|
'/',
|
||||||
return await abortMessage(req, res, abortControllers);
|
requireJwtAuth,
|
||||||
} catch (err) {
|
validateEndpoint,
|
||||||
console.error(err);
|
buildEndpointOption,
|
||||||
}
|
setHeaders,
|
||||||
});
|
async (req, res) => {
|
||||||
|
let {
|
||||||
|
text,
|
||||||
|
endpointOption,
|
||||||
|
conversationId,
|
||||||
|
parentMessageId = null,
|
||||||
|
overrideParentMessageId = null,
|
||||||
|
} = req.body;
|
||||||
|
console.log('ask log');
|
||||||
|
console.dir({ text, conversationId, endpointOption }, { depth: null });
|
||||||
|
let metadata;
|
||||||
|
let userMessage;
|
||||||
|
let userMessageId;
|
||||||
|
let responseMessageId;
|
||||||
|
let lastSavedTimestamp = 0;
|
||||||
|
let saveDelay = 100;
|
||||||
|
const newConvo = !conversationId;
|
||||||
|
const user = req.user.id;
|
||||||
|
|
||||||
router.post('/', requireJwtAuth, async (req, res) => {
|
const addMetadata = (data) => (metadata = data);
|
||||||
const { endpoint, text, parentMessageId, conversationId } = req.body;
|
|
||||||
if (text.length === 0) {
|
|
||||||
return handleError(res, { text: 'Prompt empty or too short' });
|
|
||||||
}
|
|
||||||
const isOpenAI = endpoint === 'openAI' || endpoint === 'azureOpenAI';
|
|
||||||
if (!isOpenAI) {
|
|
||||||
return handleError(res, { text: 'Illegal request' });
|
|
||||||
}
|
|
||||||
|
|
||||||
// build endpoint option
|
const getIds = (data) => {
|
||||||
const endpointOption = {
|
userMessage = data.userMessage;
|
||||||
chatGptLabel: req.body?.chatGptLabel ?? null,
|
userMessageId = userMessage.messageId;
|
||||||
promptPrefix: req.body?.promptPrefix ?? null,
|
responseMessageId = data.responseMessageId;
|
||||||
modelOptions: {
|
if (!conversationId) {
|
||||||
model: req.body?.model ?? 'gpt-3.5-turbo',
|
conversationId = data.conversationId;
|
||||||
temperature: req.body?.temperature ?? 1,
|
|
||||||
top_p: req.body?.top_p ?? 1,
|
|
||||||
presence_penalty: req.body?.presence_penalty ?? 0,
|
|
||||||
frequency_penalty: req.body?.frequency_penalty ?? 0,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
console.log('ask log');
|
|
||||||
console.dir({ text, conversationId, endpointOption }, { depth: null });
|
|
||||||
|
|
||||||
// eslint-disable-next-line no-use-before-define
|
|
||||||
return await ask({
|
|
||||||
text,
|
|
||||||
endpointOption,
|
|
||||||
conversationId,
|
|
||||||
parentMessageId,
|
|
||||||
endpoint,
|
|
||||||
req,
|
|
||||||
res,
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
const ask = async ({
|
|
||||||
text,
|
|
||||||
endpointOption,
|
|
||||||
parentMessageId = null,
|
|
||||||
endpoint,
|
|
||||||
conversationId,
|
|
||||||
req,
|
|
||||||
res,
|
|
||||||
}) => {
|
|
||||||
res.writeHead(200, {
|
|
||||||
Connection: 'keep-alive',
|
|
||||||
'Content-Type': 'text/event-stream',
|
|
||||||
'Cache-Control': 'no-cache, no-transform',
|
|
||||||
'Access-Control-Allow-Origin': '*',
|
|
||||||
'X-Accel-Buffering': 'no',
|
|
||||||
});
|
|
||||||
let userMessage;
|
|
||||||
let userMessageId;
|
|
||||||
let responseMessageId;
|
|
||||||
let lastSavedTimestamp = 0;
|
|
||||||
const newConvo = !conversationId;
|
|
||||||
const { overrideParentMessageId = null } = req.body;
|
|
||||||
const user = req.user.id;
|
|
||||||
|
|
||||||
const getIds = (data) => {
|
|
||||||
userMessage = data.userMessage;
|
|
||||||
userMessageId = userMessage.messageId;
|
|
||||||
responseMessageId = data.responseMessageId;
|
|
||||||
if (!conversationId) {
|
|
||||||
conversationId = data.conversationId;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const { onProgress: progressCallback, getPartialText } = createOnProgress({
|
|
||||||
onProgress: ({ text: partialText }) => {
|
|
||||||
const currentTimestamp = Date.now();
|
|
||||||
|
|
||||||
if (currentTimestamp - lastSavedTimestamp > 500) {
|
|
||||||
lastSavedTimestamp = currentTimestamp;
|
|
||||||
saveMessage({
|
|
||||||
messageId: responseMessageId,
|
|
||||||
sender: 'ChatGPT',
|
|
||||||
conversationId,
|
|
||||||
parentMessageId: overrideParentMessageId || userMessageId,
|
|
||||||
text: partialText,
|
|
||||||
model: endpointOption.modelOptions.model,
|
|
||||||
unfinished: true,
|
|
||||||
cancelled: false,
|
|
||||||
error: false,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
},
|
};
|
||||||
});
|
|
||||||
|
|
||||||
const abortController = new AbortController();
|
const { onProgress: progressCallback, getPartialText } = createOnProgress({
|
||||||
abortController.abortAsk = async function () {
|
onProgress: ({ text: partialText }) => {
|
||||||
this.abort();
|
const currentTimestamp = Date.now();
|
||||||
|
|
||||||
const responseMessage = {
|
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
|
||||||
|
lastSavedTimestamp = currentTimestamp;
|
||||||
|
saveMessage({
|
||||||
|
messageId: responseMessageId,
|
||||||
|
sender: getResponseSender(endpointOption),
|
||||||
|
conversationId,
|
||||||
|
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||||
|
text: partialText,
|
||||||
|
model: endpointOption.modelOptions.model,
|
||||||
|
unfinished: true,
|
||||||
|
cancelled: false,
|
||||||
|
error: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (saveDelay < 500) {
|
||||||
|
saveDelay = 500;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const getAbortData = () => ({
|
||||||
|
sender: getResponseSender(endpointOption),
|
||||||
|
conversationId,
|
||||||
messageId: responseMessageId,
|
messageId: responseMessageId,
|
||||||
sender: endpointOption?.chatGptLabel || 'ChatGPT',
|
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||||
conversationId,
|
|
||||||
parentMessageId: overrideParentMessageId || userMessageId,
|
|
||||||
text: getPartialText(),
|
text: getPartialText(),
|
||||||
model: endpointOption.modelOptions.model,
|
userMessage,
|
||||||
unfinished: false,
|
|
||||||
cancelled: true,
|
|
||||||
error: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
saveMessage(responseMessage);
|
|
||||||
|
|
||||||
return {
|
|
||||||
title: await getConvoTitle(req.user.id, conversationId),
|
|
||||||
final: true,
|
|
||||||
conversation: await getConvo(req.user.id, conversationId),
|
|
||||||
requestMessage: userMessage,
|
|
||||||
responseMessage: responseMessage,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
const onStart = (userMessage) => {
|
|
||||||
sendMessage(res, { message: userMessage, created: true });
|
|
||||||
abortControllers.set(userMessage.conversationId, { abortController, ...endpointOption });
|
|
||||||
};
|
|
||||||
|
|
||||||
try {
|
|
||||||
const clientOptions = {
|
|
||||||
// debug: true,
|
|
||||||
// contextStrategy: 'refine',
|
|
||||||
reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null,
|
|
||||||
proxy: process.env.PROXY || null,
|
|
||||||
endpoint,
|
|
||||||
...endpointOption,
|
|
||||||
};
|
|
||||||
|
|
||||||
let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY;
|
|
||||||
|
|
||||||
if (process.env.AZURE_API_KEY && endpoint === 'azureOpenAI') {
|
|
||||||
clientOptions.azure = JSON.parse(req.body?.token) ?? getAzureCredentials();
|
|
||||||
openAIApiKey = clientOptions.azure.azureOpenAIApiKey;
|
|
||||||
}
|
|
||||||
|
|
||||||
const client = new OpenAIClient(openAIApiKey, clientOptions);
|
|
||||||
|
|
||||||
let response = await client.sendMessage(text, {
|
|
||||||
user,
|
|
||||||
parentMessageId,
|
|
||||||
conversationId,
|
|
||||||
overrideParentMessageId,
|
|
||||||
getIds,
|
|
||||||
onStart,
|
|
||||||
onProgress: progressCallback.call(null, {
|
|
||||||
res,
|
|
||||||
text,
|
|
||||||
parentMessageId: overrideParentMessageId || userMessageId,
|
|
||||||
}),
|
|
||||||
abortController,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
if (overrideParentMessageId) {
|
const { abortController, onStart } = createAbortController(
|
||||||
response.parentMessageId = overrideParentMessageId;
|
res,
|
||||||
}
|
req,
|
||||||
|
endpointOption,
|
||||||
console.log(
|
getAbortData,
|
||||||
'promptTokens, completionTokens:',
|
|
||||||
response.promptTokens,
|
|
||||||
response.completionTokens,
|
|
||||||
);
|
);
|
||||||
await saveMessage(response);
|
|
||||||
|
|
||||||
sendMessage(res, {
|
try {
|
||||||
title: await getConvoTitle(req.user.id, conversationId),
|
const { client, openAIApiKey } = initializeClient(req, endpointOption);
|
||||||
final: true,
|
|
||||||
conversation: await getConvo(req.user.id, conversationId),
|
|
||||||
requestMessage: userMessage,
|
|
||||||
responseMessage: response,
|
|
||||||
});
|
|
||||||
res.end();
|
|
||||||
|
|
||||||
if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) {
|
let response = await client.sendMessage(text, {
|
||||||
const title = await titleConvo({
|
user,
|
||||||
|
parentMessageId,
|
||||||
|
conversationId,
|
||||||
|
overrideParentMessageId,
|
||||||
|
getIds,
|
||||||
|
onStart,
|
||||||
|
addMetadata,
|
||||||
|
abortController,
|
||||||
|
onProgress: progressCallback.call(null, {
|
||||||
|
res,
|
||||||
|
text,
|
||||||
|
parentMessageId: overrideParentMessageId || userMessageId,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (overrideParentMessageId) {
|
||||||
|
response.parentMessageId = overrideParentMessageId;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (metadata) {
|
||||||
|
response = { ...response, ...metadata };
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(
|
||||||
|
'promptTokens, completionTokens:',
|
||||||
|
response.promptTokens,
|
||||||
|
response.completionTokens,
|
||||||
|
);
|
||||||
|
await saveMessage(response);
|
||||||
|
|
||||||
|
sendMessage(res, {
|
||||||
|
title: await getConvoTitle(req.user.id, conversationId),
|
||||||
|
final: true,
|
||||||
|
conversation: await getConvo(req.user.id, conversationId),
|
||||||
|
requestMessage: userMessage,
|
||||||
|
responseMessage: response,
|
||||||
|
});
|
||||||
|
res.end();
|
||||||
|
|
||||||
|
addTitle(req, {
|
||||||
text,
|
text,
|
||||||
|
newConvo,
|
||||||
response,
|
response,
|
||||||
openAIApiKey,
|
openAIApiKey,
|
||||||
azure: endpoint === 'azureOpenAI',
|
parentMessageId,
|
||||||
|
azure: endpointOption.endpoint === 'azureOpenAI',
|
||||||
});
|
});
|
||||||
await saveConvo(req.user.id, {
|
} catch (error) {
|
||||||
|
const partialText = getPartialText();
|
||||||
|
handleAbortError(res, req, error, {
|
||||||
|
partialText,
|
||||||
conversationId,
|
conversationId,
|
||||||
title,
|
sender: getResponseSender(endpointOption),
|
||||||
});
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.error(error);
|
|
||||||
const partialText = getPartialText();
|
|
||||||
if (partialText?.length > 2) {
|
|
||||||
return await abortMessage(req, res, abortControllers);
|
|
||||||
} else {
|
|
||||||
const errorMessage = {
|
|
||||||
messageId: responseMessageId,
|
messageId: responseMessageId,
|
||||||
sender: 'ChatGPT',
|
|
||||||
conversationId,
|
|
||||||
parentMessageId: userMessageId,
|
parentMessageId: userMessageId,
|
||||||
unfinished: false,
|
});
|
||||||
cancelled: false,
|
|
||||||
error: true,
|
|
||||||
text: error.message,
|
|
||||||
};
|
|
||||||
await saveMessage(errorMessage);
|
|
||||||
handleError(res, errorMessage);
|
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
};
|
);
|
||||||
|
|
||||||
module.exports = router;
|
module.exports = router;
|
||||||
|
|
|
||||||
|
|
@ -7,9 +7,7 @@ const {
|
||||||
} = require('../controllers/AuthController');
|
} = require('../controllers/AuthController');
|
||||||
const { loginController } = require('../controllers/auth/LoginController');
|
const { loginController } = require('../controllers/auth/LoginController');
|
||||||
const { logoutController } = require('../controllers/auth/LogoutController');
|
const { logoutController } = require('../controllers/auth/LogoutController');
|
||||||
const requireJwtAuth = require('../../middleware/requireJwtAuth');
|
const { requireJwtAuth, requireLocalAuth, validateRegistration } = require('../middleware');
|
||||||
const requireLocalAuth = require('../../middleware/requireLocalAuth');
|
|
||||||
const validateRegistration = require('../../middleware/validateRegistration');
|
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ const express = require('express');
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
const { getConvo, saveConvo } = require('../../models');
|
const { getConvo, saveConvo } = require('../../models');
|
||||||
const { getConvosByPage, deleteConvos } = require('../../models/Conversation');
|
const { getConvosByPage, deleteConvos } = require('../../models/Conversation');
|
||||||
const requireJwtAuth = require('../../middleware/requireJwtAuth');
|
const requireJwtAuth = require('../middleware/requireJwtAuth');
|
||||||
|
|
||||||
router.get('/', requireJwtAuth, async (req, res) => {
|
router.get('/', requireJwtAuth, async (req, res) => {
|
||||||
const pageNumber = req.query.pageNumber || 1;
|
const pageNumber = req.query.pageNumber || 1;
|
||||||
|
|
|
||||||
139
api/server/routes/edit/anthropic.js
Normal file
139
api/server/routes/edit/anthropic.js
Normal file
|
|
@ -0,0 +1,139 @@
|
||||||
|
const express = require('express');
|
||||||
|
const router = express.Router();
|
||||||
|
const { getResponseSender } = require('../endpoints/schemas');
|
||||||
|
const { initializeClient } = require('../endpoints/anthropic');
|
||||||
|
const {
|
||||||
|
handleAbort,
|
||||||
|
createAbortController,
|
||||||
|
handleAbortError,
|
||||||
|
setHeaders,
|
||||||
|
requireJwtAuth,
|
||||||
|
validateEndpoint,
|
||||||
|
buildEndpointOption,
|
||||||
|
} = require('../../middleware');
|
||||||
|
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
|
||||||
|
const { sendMessage, createOnProgress } = require('../../utils');
|
||||||
|
|
||||||
|
router.post('/abort', requireJwtAuth, handleAbort());
|
||||||
|
|
||||||
|
router.post(
|
||||||
|
'/',
|
||||||
|
requireJwtAuth,
|
||||||
|
validateEndpoint,
|
||||||
|
buildEndpointOption,
|
||||||
|
setHeaders,
|
||||||
|
async (req, res) => {
|
||||||
|
let {
|
||||||
|
text,
|
||||||
|
generation,
|
||||||
|
endpointOption,
|
||||||
|
conversationId,
|
||||||
|
responseMessageId,
|
||||||
|
parentMessageId = null,
|
||||||
|
overrideParentMessageId = null,
|
||||||
|
} = req.body;
|
||||||
|
console.log('edit log');
|
||||||
|
console.dir({ text, conversationId, endpointOption }, { depth: null });
|
||||||
|
let metadata;
|
||||||
|
let userMessage;
|
||||||
|
let lastSavedTimestamp = 0;
|
||||||
|
let saveDelay = 100;
|
||||||
|
const userMessageId = parentMessageId;
|
||||||
|
|
||||||
|
const addMetadata = (data) => (metadata = data);
|
||||||
|
const getIds = (data) => (userMessage = data.userMessage);
|
||||||
|
|
||||||
|
const { onProgress: progressCallback, getPartialText } = createOnProgress({
|
||||||
|
generation,
|
||||||
|
onProgress: ({ text: partialText }) => {
|
||||||
|
const currentTimestamp = Date.now();
|
||||||
|
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
|
||||||
|
lastSavedTimestamp = currentTimestamp;
|
||||||
|
saveMessage({
|
||||||
|
messageId: responseMessageId,
|
||||||
|
sender: getResponseSender(endpointOption),
|
||||||
|
conversationId,
|
||||||
|
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||||
|
text: partialText,
|
||||||
|
unfinished: true,
|
||||||
|
cancelled: false,
|
||||||
|
error: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (saveDelay < 500) {
|
||||||
|
saveDelay = 500;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
try {
|
||||||
|
const getAbortData = () => ({
|
||||||
|
conversationId,
|
||||||
|
messageId: responseMessageId,
|
||||||
|
sender: getResponseSender(endpointOption),
|
||||||
|
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||||
|
text: getPartialText(),
|
||||||
|
userMessage,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { abortController, onStart } = createAbortController(
|
||||||
|
res,
|
||||||
|
req,
|
||||||
|
endpointOption,
|
||||||
|
getAbortData,
|
||||||
|
);
|
||||||
|
|
||||||
|
const { client } = initializeClient(req, endpointOption);
|
||||||
|
|
||||||
|
let response = await client.sendMessage(text, {
|
||||||
|
user: req.user.id,
|
||||||
|
isEdited: true,
|
||||||
|
conversationId,
|
||||||
|
parentMessageId,
|
||||||
|
responseMessageId,
|
||||||
|
overrideParentMessageId,
|
||||||
|
...endpointOption,
|
||||||
|
onProgress: progressCallback.call(null, {
|
||||||
|
res,
|
||||||
|
text,
|
||||||
|
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||||
|
}),
|
||||||
|
getIds,
|
||||||
|
onStart,
|
||||||
|
addMetadata,
|
||||||
|
abortController,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (metadata) {
|
||||||
|
response = { ...response, ...metadata };
|
||||||
|
}
|
||||||
|
|
||||||
|
if (overrideParentMessageId) {
|
||||||
|
response.parentMessageId = overrideParentMessageId;
|
||||||
|
}
|
||||||
|
|
||||||
|
await saveMessage(response);
|
||||||
|
sendMessage(res, {
|
||||||
|
title: await getConvoTitle(req.user.id, conversationId),
|
||||||
|
final: true,
|
||||||
|
conversation: await getConvo(req.user.id, conversationId),
|
||||||
|
requestMessage: userMessage,
|
||||||
|
responseMessage: response,
|
||||||
|
});
|
||||||
|
res.end();
|
||||||
|
|
||||||
|
// TODO: add anthropic titling
|
||||||
|
} catch (error) {
|
||||||
|
const partialText = getPartialText();
|
||||||
|
handleAbortError(res, req, error, {
|
||||||
|
partialText,
|
||||||
|
conversationId,
|
||||||
|
sender: getResponseSender(endpointOption),
|
||||||
|
messageId: responseMessageId,
|
||||||
|
parentMessageId: userMessageId,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
module.exports = router;
|
||||||
185
api/server/routes/edit/gptPlugins.js
Normal file
185
api/server/routes/edit/gptPlugins.js
Normal file
|
|
@ -0,0 +1,185 @@
|
||||||
|
const express = require('express');
|
||||||
|
const router = express.Router();
|
||||||
|
const { getResponseSender } = require('../endpoints/schemas');
|
||||||
|
const { validateTools } = require('../../../app');
|
||||||
|
const { initializeClient } = require('../endpoints/gptPlugins');
|
||||||
|
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
|
||||||
|
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('../../utils');
|
||||||
|
const {
|
||||||
|
handleAbort,
|
||||||
|
createAbortController,
|
||||||
|
handleAbortError,
|
||||||
|
setHeaders,
|
||||||
|
requireJwtAuth,
|
||||||
|
validateEndpoint,
|
||||||
|
buildEndpointOption,
|
||||||
|
} = require('../../middleware');
|
||||||
|
|
||||||
|
router.post('/abort', requireJwtAuth, handleAbort());
|
||||||
|
|
||||||
|
router.post(
|
||||||
|
'/',
|
||||||
|
requireJwtAuth,
|
||||||
|
validateEndpoint,
|
||||||
|
buildEndpointOption,
|
||||||
|
setHeaders,
|
||||||
|
async (req, res) => {
|
||||||
|
let {
|
||||||
|
text,
|
||||||
|
generation,
|
||||||
|
endpointOption,
|
||||||
|
conversationId,
|
||||||
|
responseMessageId,
|
||||||
|
parentMessageId = null,
|
||||||
|
overrideParentMessageId = null,
|
||||||
|
} = req.body;
|
||||||
|
console.log('edit log');
|
||||||
|
console.dir({ text, conversationId, endpointOption }, { depth: null });
|
||||||
|
let metadata;
|
||||||
|
let userMessage;
|
||||||
|
let lastSavedTimestamp = 0;
|
||||||
|
let saveDelay = 100;
|
||||||
|
const userMessageId = parentMessageId;
|
||||||
|
const user = req.user.id;
|
||||||
|
|
||||||
|
const plugin = {
|
||||||
|
loading: true,
|
||||||
|
inputs: [],
|
||||||
|
latest: null,
|
||||||
|
outputs: null,
|
||||||
|
};
|
||||||
|
|
||||||
|
const addMetadata = (data) => (metadata = data);
|
||||||
|
const getIds = (data) => (userMessage = data.userMessage);
|
||||||
|
|
||||||
|
const {
|
||||||
|
onProgress: progressCallback,
|
||||||
|
sendIntermediateMessage,
|
||||||
|
getPartialText,
|
||||||
|
} = createOnProgress({
|
||||||
|
generation,
|
||||||
|
onProgress: ({ text: partialText }) => {
|
||||||
|
const currentTimestamp = Date.now();
|
||||||
|
|
||||||
|
if (plugin.loading === true) {
|
||||||
|
plugin.loading = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
|
||||||
|
lastSavedTimestamp = currentTimestamp;
|
||||||
|
saveMessage({
|
||||||
|
messageId: responseMessageId,
|
||||||
|
sender: getResponseSender(endpointOption),
|
||||||
|
conversationId,
|
||||||
|
parentMessageId: overrideParentMessageId || userMessageId,
|
||||||
|
text: partialText,
|
||||||
|
model: endpointOption.modelOptions.model,
|
||||||
|
unfinished: true,
|
||||||
|
cancelled: false,
|
||||||
|
error: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (saveDelay < 500) {
|
||||||
|
saveDelay = 500;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const onAgentAction = (action, start = false) => {
|
||||||
|
const formattedAction = formatAction(action);
|
||||||
|
plugin.inputs.push(formattedAction);
|
||||||
|
plugin.latest = formattedAction.plugin;
|
||||||
|
if (!start) {
|
||||||
|
saveMessage(userMessage);
|
||||||
|
}
|
||||||
|
sendIntermediateMessage(res, { plugin });
|
||||||
|
// console.log('PLUGIN ACTION', formattedAction);
|
||||||
|
};
|
||||||
|
|
||||||
|
const onChainEnd = (data) => {
|
||||||
|
let { intermediateSteps: steps } = data;
|
||||||
|
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
|
||||||
|
plugin.loading = false;
|
||||||
|
saveMessage(userMessage);
|
||||||
|
sendIntermediateMessage(res, { plugin });
|
||||||
|
// console.log('CHAIN END', plugin.outputs);
|
||||||
|
};
|
||||||
|
|
||||||
|
const getAbortData = () => ({
|
||||||
|
sender: getResponseSender(endpointOption),
|
||||||
|
conversationId,
|
||||||
|
messageId: responseMessageId,
|
||||||
|
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||||
|
text: getPartialText(),
|
||||||
|
plugin: { ...plugin, loading: false },
|
||||||
|
userMessage,
|
||||||
|
});
|
||||||
|
const { abortController, onStart } = createAbortController(
|
||||||
|
res,
|
||||||
|
req,
|
||||||
|
endpointOption,
|
||||||
|
getAbortData,
|
||||||
|
);
|
||||||
|
|
||||||
|
try {
|
||||||
|
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
||||||
|
const { client } = initializeClient(req, endpointOption);
|
||||||
|
|
||||||
|
let response = await client.sendMessage(text, {
|
||||||
|
user,
|
||||||
|
isEdited: true,
|
||||||
|
conversationId,
|
||||||
|
parentMessageId,
|
||||||
|
responseMessageId,
|
||||||
|
overrideParentMessageId,
|
||||||
|
getIds,
|
||||||
|
onAgentAction,
|
||||||
|
onChainEnd,
|
||||||
|
onStart,
|
||||||
|
addMetadata,
|
||||||
|
...endpointOption,
|
||||||
|
onProgress: progressCallback.call(null, {
|
||||||
|
res,
|
||||||
|
text,
|
||||||
|
plugin,
|
||||||
|
parentMessageId: overrideParentMessageId || userMessageId,
|
||||||
|
}),
|
||||||
|
abortController,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (overrideParentMessageId) {
|
||||||
|
response.parentMessageId = overrideParentMessageId;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (metadata) {
|
||||||
|
response = { ...response, ...metadata };
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log('CLIENT RESPONSE');
|
||||||
|
console.dir(response, { depth: null });
|
||||||
|
response.plugin = { ...plugin, loading: false };
|
||||||
|
await saveMessage(response);
|
||||||
|
|
||||||
|
sendMessage(res, {
|
||||||
|
title: await getConvoTitle(req.user.id, conversationId),
|
||||||
|
final: true,
|
||||||
|
conversation: await getConvo(req.user.id, conversationId),
|
||||||
|
requestMessage: userMessage,
|
||||||
|
responseMessage: response,
|
||||||
|
});
|
||||||
|
res.end();
|
||||||
|
} catch (error) {
|
||||||
|
const partialText = getPartialText();
|
||||||
|
handleAbortError(res, req, error, {
|
||||||
|
partialText,
|
||||||
|
conversationId,
|
||||||
|
sender: getResponseSender(endpointOption),
|
||||||
|
messageId: responseMessageId,
|
||||||
|
parentMessageId: userMessageId,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
module.exports = router;
|
||||||
13
api/server/routes/edit/index.js
Normal file
13
api/server/routes/edit/index.js
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
const express = require('express');
|
||||||
|
const router = express.Router();
|
||||||
|
const openAI = require('./openAI');
|
||||||
|
const gptPlugins = require('./gptPlugins');
|
||||||
|
const anthropic = require('./anthropic');
|
||||||
|
// const google = require('./google');
|
||||||
|
|
||||||
|
router.use(['/azureOpenAI', '/openAI'], openAI);
|
||||||
|
router.use('/gptPlugins', gptPlugins);
|
||||||
|
router.use('/anthropic', anthropic);
|
||||||
|
// router.use('/google', google);
|
||||||
|
|
||||||
|
module.exports = router;
|
||||||
141
api/server/routes/edit/openAI.js
Normal file
141
api/server/routes/edit/openAI.js
Normal file
|
|
@ -0,0 +1,141 @@
|
||||||
|
const express = require('express');
|
||||||
|
const router = express.Router();
|
||||||
|
const { getResponseSender } = require('../endpoints/schemas');
|
||||||
|
const { initializeClient } = require('../endpoints/openAI');
|
||||||
|
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
|
||||||
|
const { sendMessage, createOnProgress } = require('../../utils');
|
||||||
|
const {
|
||||||
|
handleAbort,
|
||||||
|
createAbortController,
|
||||||
|
handleAbortError,
|
||||||
|
setHeaders,
|
||||||
|
requireJwtAuth,
|
||||||
|
validateEndpoint,
|
||||||
|
buildEndpointOption,
|
||||||
|
} = require('../../middleware');
|
||||||
|
|
||||||
|
router.post('/abort', requireJwtAuth, handleAbort());
|
||||||
|
|
||||||
|
router.post(
|
||||||
|
'/',
|
||||||
|
requireJwtAuth,
|
||||||
|
validateEndpoint,
|
||||||
|
buildEndpointOption,
|
||||||
|
setHeaders,
|
||||||
|
async (req, res) => {
|
||||||
|
let {
|
||||||
|
text,
|
||||||
|
generation,
|
||||||
|
endpointOption,
|
||||||
|
conversationId,
|
||||||
|
responseMessageId,
|
||||||
|
parentMessageId = null,
|
||||||
|
overrideParentMessageId = null,
|
||||||
|
} = req.body;
|
||||||
|
console.log('edit log');
|
||||||
|
console.dir({ text, conversationId, endpointOption }, { depth: null });
|
||||||
|
let metadata;
|
||||||
|
let userMessage;
|
||||||
|
let lastSavedTimestamp = 0;
|
||||||
|
let saveDelay = 100;
|
||||||
|
const userMessageId = parentMessageId;
|
||||||
|
|
||||||
|
const addMetadata = (data) => (metadata = data);
|
||||||
|
const getIds = (data) => (userMessage = data.userMessage);
|
||||||
|
|
||||||
|
const { onProgress: progressCallback, getPartialText } = createOnProgress({
|
||||||
|
generation,
|
||||||
|
onProgress: ({ text: partialText }) => {
|
||||||
|
const currentTimestamp = Date.now();
|
||||||
|
|
||||||
|
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
|
||||||
|
lastSavedTimestamp = currentTimestamp;
|
||||||
|
saveMessage({
|
||||||
|
messageId: responseMessageId,
|
||||||
|
sender: getResponseSender(endpointOption),
|
||||||
|
conversationId,
|
||||||
|
parentMessageId: overrideParentMessageId || userMessageId,
|
||||||
|
text: partialText,
|
||||||
|
model: endpointOption.modelOptions.model,
|
||||||
|
unfinished: true,
|
||||||
|
cancelled: false,
|
||||||
|
error: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (saveDelay < 500) {
|
||||||
|
saveDelay = 500;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const getAbortData = () => ({
|
||||||
|
sender: getResponseSender(endpointOption),
|
||||||
|
conversationId,
|
||||||
|
messageId: responseMessageId,
|
||||||
|
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||||
|
text: getPartialText(),
|
||||||
|
userMessage,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { abortController, onStart } = createAbortController(
|
||||||
|
res,
|
||||||
|
req,
|
||||||
|
endpointOption,
|
||||||
|
getAbortData,
|
||||||
|
);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const { client } = initializeClient(req, endpointOption);
|
||||||
|
|
||||||
|
let response = await client.sendMessage(text, {
|
||||||
|
user: req.user.id,
|
||||||
|
isEdited: true,
|
||||||
|
conversationId,
|
||||||
|
parentMessageId,
|
||||||
|
responseMessageId,
|
||||||
|
overrideParentMessageId,
|
||||||
|
getIds,
|
||||||
|
onStart,
|
||||||
|
addMetadata,
|
||||||
|
abortController,
|
||||||
|
onProgress: progressCallback.call(null, {
|
||||||
|
res,
|
||||||
|
text,
|
||||||
|
parentMessageId: overrideParentMessageId || userMessageId,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (metadata) {
|
||||||
|
response = { ...response, ...metadata };
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(
|
||||||
|
'promptTokens, completionTokens:',
|
||||||
|
response.promptTokens,
|
||||||
|
response.completionTokens,
|
||||||
|
);
|
||||||
|
await saveMessage(response);
|
||||||
|
|
||||||
|
sendMessage(res, {
|
||||||
|
title: await getConvoTitle(req.user.id, conversationId),
|
||||||
|
final: true,
|
||||||
|
conversation: await getConvo(req.user.id, conversationId),
|
||||||
|
requestMessage: userMessage,
|
||||||
|
responseMessage: response,
|
||||||
|
});
|
||||||
|
res.end();
|
||||||
|
} catch (error) {
|
||||||
|
const partialText = getPartialText();
|
||||||
|
handleAbortError(res, req, error, {
|
||||||
|
partialText,
|
||||||
|
conversationId,
|
||||||
|
sender: getResponseSender(endpointOption),
|
||||||
|
messageId: responseMessageId,
|
||||||
|
parentMessageId: userMessageId,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
module.exports = router;
|
||||||
|
|
@ -3,37 +3,45 @@ const express = require('express');
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
const { availableTools } = require('../../app/clients/tools');
|
const { availableTools } = require('../../app/clients/tools');
|
||||||
const { addOpenAPISpecs } = require('../../app/clients/tools/util/addOpenAPISpecs');
|
const { addOpenAPISpecs } = require('../../app/clients/tools/util/addOpenAPISpecs');
|
||||||
|
// const { getAzureCredentials, genAzureChatCompletion } = require('../../utils/');
|
||||||
|
|
||||||
const openAIApiKey = process.env.OPENAI_API_KEY;
|
const openAIApiKey = process.env.OPENAI_API_KEY;
|
||||||
const azureOpenAIApiKey = process.env.AZURE_API_KEY;
|
const azureOpenAIApiKey = process.env.AZURE_API_KEY;
|
||||||
|
const useAzurePlugins = !!process.env.PLUGINS_USE_AZURE;
|
||||||
const userProvidedOpenAI = openAIApiKey
|
const userProvidedOpenAI = openAIApiKey
|
||||||
? openAIApiKey === 'user_provided'
|
? openAIApiKey === 'user_provided'
|
||||||
: azureOpenAIApiKey === 'user_provided';
|
: azureOpenAIApiKey === 'user_provided';
|
||||||
|
|
||||||
const fetchOpenAIModels = async (opts = { azure: false, plugins: false }, _models = []) => {
|
const fetchOpenAIModels = async (opts = { azure: false, plugins: false }, _models = []) => {
|
||||||
let models = _models.slice() ?? [];
|
let models = _models.slice() ?? [];
|
||||||
|
let apiKey = openAIApiKey;
|
||||||
|
let basePath = 'https://api.openai.com/v1';
|
||||||
if (opts.azure) {
|
if (opts.azure) {
|
||||||
/* TODO: Add Azure models from api/models */
|
|
||||||
return models;
|
return models;
|
||||||
|
// const azure = getAzureCredentials();
|
||||||
|
// basePath = (genAzureChatCompletion(azure))
|
||||||
|
// .split('/deployments')[0]
|
||||||
|
// .concat(`/models?api-version=${azure.azureOpenAIApiVersion}`);
|
||||||
|
// apiKey = azureOpenAIApiKey;
|
||||||
}
|
}
|
||||||
|
|
||||||
let basePath = 'https://api.openai.com/v1/';
|
|
||||||
const reverseProxyUrl = process.env.OPENAI_REVERSE_PROXY;
|
const reverseProxyUrl = process.env.OPENAI_REVERSE_PROXY;
|
||||||
if (reverseProxyUrl) {
|
if (reverseProxyUrl) {
|
||||||
basePath = reverseProxyUrl.match(/.*v1/)[0];
|
basePath = reverseProxyUrl.match(/.*v1/)[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (basePath.includes('v1')) {
|
if (basePath.includes('v1') || opts.azure) {
|
||||||
try {
|
try {
|
||||||
const res = await axios.get(`${basePath}/models`, {
|
const res = await axios.get(`${basePath}${opts.azure ? '' : '/models'}`, {
|
||||||
headers: {
|
headers: {
|
||||||
Authorization: `Bearer ${openAIApiKey}`,
|
Authorization: `Bearer ${apiKey}`,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
models = res.data.data.map((item) => item.id);
|
models = res.data.data.map((item) => item.id);
|
||||||
|
// console.log(`Fetched ${models.length} models from ${opts.azure ? 'Azure ' : ''}OpenAI API`);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error(err);
|
console.log(`Failed to fetch models from ${opts.azure ? 'Azure ' : ''}OpenAI API`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -149,7 +157,7 @@ router.get('/', async function (req, res) {
|
||||||
const gptPlugins =
|
const gptPlugins =
|
||||||
openAIApiKey || azureOpenAIApiKey
|
openAIApiKey || azureOpenAIApiKey
|
||||||
? {
|
? {
|
||||||
availableModels: await getOpenAIModels({ plugins: true }),
|
availableModels: await getOpenAIModels({ azure: useAzurePlugins, plugins: true }),
|
||||||
plugins,
|
plugins,
|
||||||
availableAgents: ['classic', 'functions'],
|
availableAgents: ['classic', 'functions'],
|
||||||
userProvide: userProvidedOpenAI,
|
userProvide: userProvidedOpenAI,
|
||||||
|
|
|
||||||
15
api/server/routes/endpoints/anthropic/buildOptions.js
Normal file
15
api/server/routes/endpoints/anthropic/buildOptions.js
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
const buildOptions = (endpoint, parsedBody) => {
|
||||||
|
const { modelLabel, promptPrefix, ...rest } = parsedBody;
|
||||||
|
const endpointOption = {
|
||||||
|
endpoint,
|
||||||
|
modelLabel,
|
||||||
|
promptPrefix,
|
||||||
|
modelOptions: {
|
||||||
|
...rest,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
return endpointOption;
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = buildOptions;
|
||||||
8
api/server/routes/endpoints/anthropic/index.js
Normal file
8
api/server/routes/endpoints/anthropic/index.js
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
const buildOptions = require('./buildOptions');
|
||||||
|
const initializeClient = require('./initializeClient');
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
// addTitle, // todo
|
||||||
|
buildOptions,
|
||||||
|
initializeClient,
|
||||||
|
};
|
||||||
12
api/server/routes/endpoints/anthropic/initializeClient.js
Normal file
12
api/server/routes/endpoints/anthropic/initializeClient.js
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
const { AnthropicClient } = require('../../../../app');
|
||||||
|
|
||||||
|
const initializeClient = (req) => {
|
||||||
|
let anthropicApiKey = req.body?.token ?? process.env.ANTHROPIC_API_KEY;
|
||||||
|
const client = new AnthropicClient(anthropicApiKey);
|
||||||
|
return {
|
||||||
|
client,
|
||||||
|
anthropicApiKey,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = initializeClient;
|
||||||
31
api/server/routes/endpoints/gptPlugins/buildOptions.js
Normal file
31
api/server/routes/endpoints/gptPlugins/buildOptions.js
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
const buildOptions = (endpoint, parsedBody) => {
|
||||||
|
const {
|
||||||
|
chatGptLabel,
|
||||||
|
promptPrefix,
|
||||||
|
agentOptions,
|
||||||
|
tools,
|
||||||
|
model,
|
||||||
|
temperature,
|
||||||
|
top_p,
|
||||||
|
presence_penalty,
|
||||||
|
frequency_penalty,
|
||||||
|
} = parsedBody;
|
||||||
|
const endpointOption = {
|
||||||
|
endpoint,
|
||||||
|
tools: tools.map((tool) => tool.pluginKey) ?? [],
|
||||||
|
chatGptLabel,
|
||||||
|
promptPrefix,
|
||||||
|
agentOptions,
|
||||||
|
modelOptions: {
|
||||||
|
model,
|
||||||
|
temperature,
|
||||||
|
top_p,
|
||||||
|
presence_penalty,
|
||||||
|
frequency_penalty,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
return endpointOption;
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = buildOptions;
|
||||||
7
api/server/routes/endpoints/gptPlugins/index.js
Normal file
7
api/server/routes/endpoints/gptPlugins/index.js
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
const buildOptions = require('./buildOptions');
|
||||||
|
const initializeClient = require('./initializeClient');
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
buildOptions,
|
||||||
|
initializeClient,
|
||||||
|
};
|
||||||
30
api/server/routes/endpoints/gptPlugins/initializeClient.js
Normal file
30
api/server/routes/endpoints/gptPlugins/initializeClient.js
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
const { PluginsClient } = require('../../../../app');
|
||||||
|
const { getAzureCredentials } = require('../../../../utils');
|
||||||
|
|
||||||
|
const initializeClient = (req, endpointOption) => {
|
||||||
|
const clientOptions = {
|
||||||
|
debug: true,
|
||||||
|
reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null,
|
||||||
|
proxy: process.env.PROXY || null,
|
||||||
|
...endpointOption,
|
||||||
|
};
|
||||||
|
|
||||||
|
let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY;
|
||||||
|
if (process.env.PLUGINS_USE_AZURE) {
|
||||||
|
clientOptions.azure = getAzureCredentials();
|
||||||
|
openAIApiKey = clientOptions.azure.azureOpenAIApiKey;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (openAIApiKey && openAIApiKey.includes('azure') && !clientOptions.azure) {
|
||||||
|
clientOptions.azure = JSON.parse(req.body?.token) ?? getAzureCredentials();
|
||||||
|
openAIApiKey = clientOptions.azure.azureOpenAIApiKey;
|
||||||
|
}
|
||||||
|
const client = new PluginsClient(openAIApiKey, clientOptions);
|
||||||
|
return {
|
||||||
|
client,
|
||||||
|
azure: clientOptions.azure,
|
||||||
|
openAIApiKey,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = initializeClient;
|
||||||
22
api/server/routes/endpoints/openAI/addTitle.js
Normal file
22
api/server/routes/endpoints/openAI/addTitle.js
Normal file
|
|
@ -0,0 +1,22 @@
|
||||||
|
const { titleConvo } = require('../../../../app');
|
||||||
|
const { saveConvo } = require('../../../../models');
|
||||||
|
|
||||||
|
const addTitle = async (
|
||||||
|
req,
|
||||||
|
{ text, azure, response, newConvo, parentMessageId, openAIApiKey },
|
||||||
|
) => {
|
||||||
|
if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) {
|
||||||
|
const title = await titleConvo({
|
||||||
|
text,
|
||||||
|
azure,
|
||||||
|
response,
|
||||||
|
openAIApiKey,
|
||||||
|
});
|
||||||
|
await saveConvo(req.user.id, {
|
||||||
|
conversationId: response.conversationId,
|
||||||
|
title,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = addTitle;
|
||||||
15
api/server/routes/endpoints/openAI/buildOptions.js
Normal file
15
api/server/routes/endpoints/openAI/buildOptions.js
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
const buildOptions = (endpoint, parsedBody) => {
|
||||||
|
const { chatGptLabel, promptPrefix, ...rest } = parsedBody;
|
||||||
|
const endpointOption = {
|
||||||
|
endpoint,
|
||||||
|
chatGptLabel,
|
||||||
|
promptPrefix,
|
||||||
|
modelOptions: {
|
||||||
|
...rest,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
return endpointOption;
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = buildOptions;
|
||||||
9
api/server/routes/endpoints/openAI/index.js
Normal file
9
api/server/routes/endpoints/openAI/index.js
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
const addTitle = require('./addTitle');
|
||||||
|
const buildOptions = require('./buildOptions');
|
||||||
|
const initializeClient = require('./initializeClient');
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
addTitle,
|
||||||
|
buildOptions,
|
||||||
|
initializeClient,
|
||||||
|
};
|
||||||
27
api/server/routes/endpoints/openAI/initializeClient.js
Normal file
27
api/server/routes/endpoints/openAI/initializeClient.js
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
const { OpenAIClient } = require('../../../../app');
|
||||||
|
const { getAzureCredentials } = require('../../../../utils');
|
||||||
|
|
||||||
|
const initializeClient = (req, endpointOption) => {
|
||||||
|
const clientOptions = {
|
||||||
|
// debug: true,
|
||||||
|
// contextStrategy: 'refine',
|
||||||
|
reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null,
|
||||||
|
proxy: process.env.PROXY || null,
|
||||||
|
...endpointOption,
|
||||||
|
};
|
||||||
|
|
||||||
|
let openAIApiKey = req.body?.token ?? process.env.OPENAI_API_KEY;
|
||||||
|
|
||||||
|
if (process.env.AZURE_API_KEY && endpointOption.endpoint === 'azureOpenAI') {
|
||||||
|
clientOptions.azure = JSON.parse(req.body?.token) ?? getAzureCredentials();
|
||||||
|
openAIApiKey = clientOptions.azure.azureOpenAIApiKey;
|
||||||
|
}
|
||||||
|
|
||||||
|
const client = new OpenAIClient(openAIApiKey, clientOptions);
|
||||||
|
return {
|
||||||
|
client,
|
||||||
|
openAIApiKey,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = initializeClient;
|
||||||
369
api/server/routes/endpoints/schemas.js
Normal file
369
api/server/routes/endpoints/schemas.js
Normal file
|
|
@ -0,0 +1,369 @@
|
||||||
|
const { z } = require('zod');
|
||||||
|
|
||||||
|
const EModelEndpoint = {
|
||||||
|
azureOpenAI: 'azureOpenAI',
|
||||||
|
openAI: 'openAI',
|
||||||
|
bingAI: 'bingAI',
|
||||||
|
chatGPTBrowser: 'chatGPTBrowser',
|
||||||
|
google: 'google',
|
||||||
|
gptPlugins: 'gptPlugins',
|
||||||
|
anthropic: 'anthropic',
|
||||||
|
};
|
||||||
|
|
||||||
|
const eModelEndpointSchema = z.nativeEnum(EModelEndpoint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
const tMessageSchema = z.object({
|
||||||
|
messageId: z.string(),
|
||||||
|
clientId: z.string().nullable().optional(),
|
||||||
|
conversationId: z.string().nullable(),
|
||||||
|
parentMessageId: z.string().nullable(),
|
||||||
|
sender: z.string(),
|
||||||
|
text: z.string(),
|
||||||
|
isCreatedByUser: z.boolean(),
|
||||||
|
error: z.boolean(),
|
||||||
|
createdAt: z
|
||||||
|
.string()
|
||||||
|
.optional()
|
||||||
|
.default(() => new Date().toISOString()),
|
||||||
|
updatedAt: z
|
||||||
|
.string()
|
||||||
|
.optional()
|
||||||
|
.default(() => new Date().toISOString()),
|
||||||
|
current: z.boolean().optional(),
|
||||||
|
unfinished: z.boolean().optional(),
|
||||||
|
submitting: z.boolean().optional(),
|
||||||
|
searchResult: z.boolean().optional(),
|
||||||
|
finish_reason: z.string().optional(),
|
||||||
|
});
|
||||||
|
|
||||||
|
const tPresetSchema = tConversationSchema
|
||||||
|
.omit({
|
||||||
|
conversationId: true,
|
||||||
|
createdAt: true,
|
||||||
|
updatedAt: true,
|
||||||
|
title: true,
|
||||||
|
})
|
||||||
|
.merge(
|
||||||
|
z.object({
|
||||||
|
conversationId: z.string().optional(),
|
||||||
|
presetId: z.string().nullable().optional(),
|
||||||
|
title: z.string().nullable().optional(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
*/
|
||||||
|
|
||||||
|
const tPluginAuthConfigSchema = z.object({
|
||||||
|
authField: z.string(),
|
||||||
|
label: z.string(),
|
||||||
|
description: z.string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
const tPluginSchema = z.object({
|
||||||
|
name: z.string(),
|
||||||
|
pluginKey: z.string(),
|
||||||
|
description: z.string(),
|
||||||
|
icon: z.string(),
|
||||||
|
authConfig: z.array(tPluginAuthConfigSchema),
|
||||||
|
authenticated: z.boolean().optional(),
|
||||||
|
isButton: z.boolean().optional(),
|
||||||
|
});
|
||||||
|
|
||||||
|
const tExampleSchema = z.object({
|
||||||
|
input: z.object({
|
||||||
|
content: z.string(),
|
||||||
|
}),
|
||||||
|
output: z.object({
|
||||||
|
content: z.string(),
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const tAgentOptionsSchema = z.object({
|
||||||
|
agent: z.string(),
|
||||||
|
skipCompletion: z.boolean(),
|
||||||
|
model: z.string(),
|
||||||
|
temperature: z.number(),
|
||||||
|
});
|
||||||
|
|
||||||
|
const tConversationSchema = z.object({
|
||||||
|
conversationId: z.string().nullable(),
|
||||||
|
title: z.string(),
|
||||||
|
user: z.string().optional(),
|
||||||
|
endpoint: eModelEndpointSchema.nullable(),
|
||||||
|
suggestions: z.array(z.string()).optional(),
|
||||||
|
messages: z.array(z.string()).optional(),
|
||||||
|
tools: z.array(tPluginSchema).optional(),
|
||||||
|
createdAt: z.string(),
|
||||||
|
updatedAt: z.string(),
|
||||||
|
systemMessage: z.string().nullable().optional(),
|
||||||
|
modelLabel: z.string().nullable().optional(),
|
||||||
|
examples: z.array(tExampleSchema).optional(),
|
||||||
|
chatGptLabel: z.string().nullable().optional(),
|
||||||
|
userLabel: z.string().optional(),
|
||||||
|
model: z.string().nullable().optional(),
|
||||||
|
promptPrefix: z.string().nullable().optional(),
|
||||||
|
temperature: z.number().optional(),
|
||||||
|
topP: z.number().optional(),
|
||||||
|
topK: z.number().optional(),
|
||||||
|
context: z.string().nullable().optional(),
|
||||||
|
top_p: z.number().optional(),
|
||||||
|
frequency_penalty: z.number().optional(),
|
||||||
|
presence_penalty: z.number().optional(),
|
||||||
|
jailbreak: z.boolean().optional(),
|
||||||
|
jailbreakConversationId: z.string().nullable().optional(),
|
||||||
|
conversationSignature: z.string().nullable().optional(),
|
||||||
|
parentMessageId: z.string().optional(),
|
||||||
|
clientId: z.string().nullable().optional(),
|
||||||
|
invocationId: z.number().nullable().optional(),
|
||||||
|
toneStyle: z.string().nullable().optional(),
|
||||||
|
maxOutputTokens: z.number().optional(),
|
||||||
|
agentOptions: tAgentOptionsSchema.nullable().optional(),
|
||||||
|
});
|
||||||
|
|
||||||
|
const openAISchema = tConversationSchema
|
||||||
|
.pick({
|
||||||
|
model: true,
|
||||||
|
chatGptLabel: true,
|
||||||
|
promptPrefix: true,
|
||||||
|
temperature: true,
|
||||||
|
top_p: true,
|
||||||
|
presence_penalty: true,
|
||||||
|
frequency_penalty: true,
|
||||||
|
})
|
||||||
|
.transform((obj) => ({
|
||||||
|
...obj,
|
||||||
|
model: obj.model ?? 'gpt-3.5-turbo',
|
||||||
|
chatGptLabel: obj.chatGptLabel ?? null,
|
||||||
|
promptPrefix: obj.promptPrefix ?? null,
|
||||||
|
temperature: obj.temperature ?? 1,
|
||||||
|
top_p: obj.top_p ?? 1,
|
||||||
|
presence_penalty: obj.presence_penalty ?? 0,
|
||||||
|
frequency_penalty: obj.frequency_penalty ?? 0,
|
||||||
|
}))
|
||||||
|
.catch(() => ({
|
||||||
|
model: 'gpt-3.5-turbo',
|
||||||
|
chatGptLabel: null,
|
||||||
|
promptPrefix: null,
|
||||||
|
temperature: 1,
|
||||||
|
top_p: 1,
|
||||||
|
presence_penalty: 0,
|
||||||
|
frequency_penalty: 0,
|
||||||
|
}));
|
||||||
|
|
||||||
|
const googleSchema = tConversationSchema
|
||||||
|
.pick({
|
||||||
|
model: true,
|
||||||
|
modelLabel: true,
|
||||||
|
promptPrefix: true,
|
||||||
|
examples: true,
|
||||||
|
temperature: true,
|
||||||
|
maxOutputTokens: true,
|
||||||
|
topP: true,
|
||||||
|
topK: true,
|
||||||
|
})
|
||||||
|
.transform((obj) => ({
|
||||||
|
...obj,
|
||||||
|
model: obj.model ?? 'chat-bison',
|
||||||
|
modelLabel: obj.modelLabel ?? null,
|
||||||
|
promptPrefix: obj.promptPrefix ?? null,
|
||||||
|
temperature: obj.temperature ?? 0.2,
|
||||||
|
maxOutputTokens: obj.maxOutputTokens ?? 1024,
|
||||||
|
topP: obj.topP ?? 0.95,
|
||||||
|
topK: obj.topK ?? 40,
|
||||||
|
}))
|
||||||
|
.catch(() => ({
|
||||||
|
model: 'chat-bison',
|
||||||
|
modelLabel: null,
|
||||||
|
promptPrefix: null,
|
||||||
|
temperature: 0.2,
|
||||||
|
maxOutputTokens: 1024,
|
||||||
|
topP: 0.95,
|
||||||
|
topK: 40,
|
||||||
|
}));
|
||||||
|
|
||||||
|
const bingAISchema = tConversationSchema
|
||||||
|
.pick({
|
||||||
|
jailbreak: true,
|
||||||
|
systemMessage: true,
|
||||||
|
context: true,
|
||||||
|
toneStyle: true,
|
||||||
|
jailbreakConversationId: true,
|
||||||
|
conversationSignature: true,
|
||||||
|
clientId: true,
|
||||||
|
invocationId: true,
|
||||||
|
})
|
||||||
|
.transform((obj) => ({
|
||||||
|
...obj,
|
||||||
|
model: '',
|
||||||
|
jailbreak: obj.jailbreak ?? false,
|
||||||
|
systemMessage: obj.systemMessage ?? null,
|
||||||
|
context: obj.context ?? null,
|
||||||
|
toneStyle: obj.toneStyle ?? 'creative',
|
||||||
|
jailbreakConversationId: obj.jailbreakConversationId ?? null,
|
||||||
|
conversationSignature: obj.conversationSignature ?? null,
|
||||||
|
clientId: obj.clientId ?? null,
|
||||||
|
invocationId: obj.invocationId ?? 1,
|
||||||
|
}))
|
||||||
|
.catch(() => ({
|
||||||
|
model: '',
|
||||||
|
jailbreak: false,
|
||||||
|
systemMessage: null,
|
||||||
|
context: null,
|
||||||
|
toneStyle: 'creative',
|
||||||
|
jailbreakConversationId: null,
|
||||||
|
conversationSignature: null,
|
||||||
|
clientId: null,
|
||||||
|
invocationId: 1,
|
||||||
|
}));
|
||||||
|
|
||||||
|
const anthropicSchema = tConversationSchema
|
||||||
|
.pick({
|
||||||
|
model: true,
|
||||||
|
modelLabel: true,
|
||||||
|
promptPrefix: true,
|
||||||
|
temperature: true,
|
||||||
|
maxOutputTokens: true,
|
||||||
|
topP: true,
|
||||||
|
topK: true,
|
||||||
|
})
|
||||||
|
.transform((obj) => ({
|
||||||
|
...obj,
|
||||||
|
model: obj.model ?? 'claude-1',
|
||||||
|
modelLabel: obj.modelLabel ?? null,
|
||||||
|
promptPrefix: obj.promptPrefix ?? null,
|
||||||
|
temperature: obj.temperature ?? 1,
|
||||||
|
maxOutputTokens: obj.maxOutputTokens ?? 1024,
|
||||||
|
topP: obj.topP ?? 0.7,
|
||||||
|
topK: obj.topK ?? 5,
|
||||||
|
}))
|
||||||
|
.catch(() => ({
|
||||||
|
model: 'claude-1',
|
||||||
|
modelLabel: null,
|
||||||
|
promptPrefix: null,
|
||||||
|
temperature: 1,
|
||||||
|
maxOutputTokens: 1024,
|
||||||
|
topP: 0.7,
|
||||||
|
topK: 5,
|
||||||
|
}));
|
||||||
|
|
||||||
|
const chatGPTBrowserSchema = tConversationSchema
|
||||||
|
.pick({
|
||||||
|
model: true,
|
||||||
|
})
|
||||||
|
.transform((obj) => ({
|
||||||
|
...obj,
|
||||||
|
model: obj.model ?? 'text-davinci-002-render-sha',
|
||||||
|
}))
|
||||||
|
.catch(() => ({
|
||||||
|
model: 'text-davinci-002-render-sha',
|
||||||
|
}));
|
||||||
|
|
||||||
|
const gptPluginsSchema = tConversationSchema
|
||||||
|
.pick({
|
||||||
|
model: true,
|
||||||
|
chatGptLabel: true,
|
||||||
|
promptPrefix: true,
|
||||||
|
temperature: true,
|
||||||
|
top_p: true,
|
||||||
|
presence_penalty: true,
|
||||||
|
frequency_penalty: true,
|
||||||
|
tools: true,
|
||||||
|
agentOptions: true,
|
||||||
|
})
|
||||||
|
.transform((obj) => ({
|
||||||
|
...obj,
|
||||||
|
model: obj.model ?? 'gpt-3.5-turbo',
|
||||||
|
chatGptLabel: obj.chatGptLabel ?? null,
|
||||||
|
promptPrefix: obj.promptPrefix ?? null,
|
||||||
|
temperature: obj.temperature ?? 0.8,
|
||||||
|
top_p: obj.top_p ?? 1,
|
||||||
|
presence_penalty: obj.presence_penalty ?? 0,
|
||||||
|
frequency_penalty: obj.frequency_penalty ?? 0,
|
||||||
|
tools: obj.tools ?? [],
|
||||||
|
agentOptions: obj.agentOptions ?? {
|
||||||
|
agent: 'functions',
|
||||||
|
skipCompletion: true,
|
||||||
|
model: 'gpt-3.5-turbo',
|
||||||
|
temperature: 0,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
.catch(() => ({
|
||||||
|
model: 'gpt-3.5-turbo',
|
||||||
|
chatGptLabel: null,
|
||||||
|
promptPrefix: null,
|
||||||
|
temperature: 0.8,
|
||||||
|
top_p: 1,
|
||||||
|
presence_penalty: 0,
|
||||||
|
frequency_penalty: 0,
|
||||||
|
tools: [],
|
||||||
|
agentOptions: {
|
||||||
|
agent: 'functions',
|
||||||
|
skipCompletion: true,
|
||||||
|
model: 'gpt-3.5-turbo',
|
||||||
|
temperature: 0,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
const endpointSchemas = {
|
||||||
|
openAI: openAISchema,
|
||||||
|
azureOpenAI: openAISchema,
|
||||||
|
google: googleSchema,
|
||||||
|
bingAI: bingAISchema,
|
||||||
|
anthropic: anthropicSchema,
|
||||||
|
chatGPTBrowser: chatGPTBrowserSchema,
|
||||||
|
gptPlugins: gptPluginsSchema,
|
||||||
|
};
|
||||||
|
|
||||||
|
function getFirstDefinedValue(possibleValues) {
|
||||||
|
let returnValue;
|
||||||
|
for (const value of possibleValues) {
|
||||||
|
if (value) {
|
||||||
|
returnValue = value;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return returnValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const parseConvo = (endpoint, conversation, possibleValues) => {
|
||||||
|
const schema = endpointSchemas[endpoint];
|
||||||
|
|
||||||
|
if (!schema) {
|
||||||
|
throw new Error(`Unknown endpoint: ${endpoint}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const convo = schema.parse(conversation);
|
||||||
|
|
||||||
|
if (possibleValues && convo) {
|
||||||
|
convo.model = getFirstDefinedValue(possibleValues.model) ?? convo.model;
|
||||||
|
}
|
||||||
|
|
||||||
|
return convo;
|
||||||
|
};
|
||||||
|
|
||||||
|
const getResponseSender = (endpointOption) => {
|
||||||
|
const { endpoint, chatGptLabel, modelLabel, jailbreak } = endpointOption;
|
||||||
|
|
||||||
|
if (['openAI', 'azureOpenAI', 'gptPlugins', 'chatGPTBrowser'].includes(endpoint)) {
|
||||||
|
return chatGptLabel ?? 'ChatGPT';
|
||||||
|
}
|
||||||
|
|
||||||
|
if (endpoint === 'bingAI') {
|
||||||
|
return jailbreak ? 'Sydney' : 'BingAI';
|
||||||
|
}
|
||||||
|
|
||||||
|
if (endpoint === 'anthropic') {
|
||||||
|
return modelLabel ?? 'Anthropic';
|
||||||
|
}
|
||||||
|
|
||||||
|
if (endpoint === 'google') {
|
||||||
|
return modelLabel ?? 'PaLM2';
|
||||||
|
}
|
||||||
|
|
||||||
|
return '';
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
parseConvo,
|
||||||
|
getResponseSender,
|
||||||
|
};
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
const ask = require('./ask');
|
const ask = require('./ask');
|
||||||
|
const edit = require('./edit');
|
||||||
const messages = require('./messages');
|
const messages = require('./messages');
|
||||||
const convos = require('./convos');
|
const convos = require('./convos');
|
||||||
const presets = require('./presets');
|
const presets = require('./presets');
|
||||||
|
|
@ -15,6 +16,7 @@ const config = require('./config');
|
||||||
module.exports = {
|
module.exports = {
|
||||||
search,
|
search,
|
||||||
ask,
|
ask,
|
||||||
|
edit,
|
||||||
messages,
|
messages,
|
||||||
convos,
|
convos,
|
||||||
presets,
|
presets,
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
const { getMessages } = require('../../models/Message');
|
const { getMessages } = require('../../models/Message');
|
||||||
const requireJwtAuth = require('../../middleware/requireJwtAuth');
|
const requireJwtAuth = require('../middleware/requireJwtAuth');
|
||||||
|
|
||||||
router.get('/:conversationId', requireJwtAuth, async (req, res) => {
|
router.get('/:conversationId', requireJwtAuth, async (req, res) => {
|
||||||
const { conversationId } = req.params;
|
const { conversationId } = req.params;
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const { getAvailablePluginsController } = require('../controllers/PluginController');
|
const { getAvailablePluginsController } = require('../controllers/PluginController');
|
||||||
const requireJwtAuth = require('../../middleware/requireJwtAuth');
|
const requireJwtAuth = require('../middleware/requireJwtAuth');
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ const express = require('express');
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
const { getPresets, savePreset, deletePresets } = require('../../models');
|
const { getPresets, savePreset, deletePresets } = require('../../models');
|
||||||
const crypto = require('crypto');
|
const crypto = require('crypto');
|
||||||
const requireJwtAuth = require('../../middleware/requireJwtAuth');
|
const requireJwtAuth = require('../middleware/requireJwtAuth');
|
||||||
|
|
||||||
router.get('/', requireJwtAuth, async (req, res) => {
|
router.get('/', requireJwtAuth, async (req, res) => {
|
||||||
const presets = (await getPresets(req.user.id)).map((preset) => {
|
const presets = (await getPresets(req.user.id)).map((preset) => {
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ const { Message } = require('../../models/Message');
|
||||||
const { Conversation, getConvosQueried } = require('../../models/Conversation');
|
const { Conversation, getConvosQueried } = require('../../models/Conversation');
|
||||||
const { reduceHits } = require('../../lib/utils/reduceHits');
|
const { reduceHits } = require('../../lib/utils/reduceHits');
|
||||||
const { cleanUpPrimaryKeyValue } = require('../../lib/utils/misc');
|
const { cleanUpPrimaryKeyValue } = require('../../lib/utils/misc');
|
||||||
const requireJwtAuth = require('../../middleware/requireJwtAuth');
|
const requireJwtAuth = require('../middleware/requireJwtAuth');
|
||||||
|
|
||||||
const cache = new Map();
|
const cache = new Map();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ const { Tiktoken } = require('@dqbd/tiktoken/lite');
|
||||||
const { load } = require('@dqbd/tiktoken/load');
|
const { load } = require('@dqbd/tiktoken/load');
|
||||||
const registry = require('@dqbd/tiktoken/registry.json');
|
const registry = require('@dqbd/tiktoken/registry.json');
|
||||||
const models = require('@dqbd/tiktoken/model_to_encoding.json');
|
const models = require('@dqbd/tiktoken/model_to_encoding.json');
|
||||||
const requireJwtAuth = require('../../middleware/requireJwtAuth');
|
const requireJwtAuth = require('../middleware/requireJwtAuth');
|
||||||
|
|
||||||
router.post('/', requireJwtAuth, async (req, res) => {
|
router.post('/', requireJwtAuth, async (req, res) => {
|
||||||
try {
|
try {
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const requireJwtAuth = require('../../middleware/requireJwtAuth');
|
const requireJwtAuth = require('../middleware/requireJwtAuth');
|
||||||
const { getUserController, updateUserPluginsController } = require('../controllers/UserController');
|
const { getUserController, updateUserPluginsController } = require('../controllers/UserController');
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
const User = require('../../models/User');
|
|
||||||
const Token = require('../../models/schema/tokenSchema');
|
|
||||||
const crypto = require('crypto');
|
const crypto = require('crypto');
|
||||||
const bcrypt = require('bcryptjs');
|
const bcrypt = require('bcryptjs');
|
||||||
|
const User = require('../../models/User');
|
||||||
|
const Token = require('../../models/schema/tokenSchema');
|
||||||
const { registerSchema } = require('../../strategies/validators');
|
const { registerSchema } = require('../../strategies/validators');
|
||||||
const { sendEmail } = require('../../utils');
|
|
||||||
const config = require('../../../config/loader');
|
const config = require('../../../config/loader');
|
||||||
|
const { sendEmail } = require('../utils');
|
||||||
const domains = config.domains;
|
const domains = config.domains;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
const PluginAuth = require('../../models/schema/pluginAuthSchema');
|
const PluginAuth = require('../../models/schema/pluginAuthSchema');
|
||||||
const { encrypt, decrypt } = require('../../utils/');
|
const { encrypt, decrypt } = require('../utils/');
|
||||||
|
|
||||||
const getUserPluginAuthValue = async (user, authField) => {
|
const getUserPluginAuthValue = async (user, authField) => {
|
||||||
try {
|
try {
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,19 @@
|
||||||
const citationRegex = /\[\^\d+?\^\]/g;
|
const citationRegex = /\[\^\d+?\^\]/g;
|
||||||
|
const regex = / \[.*?]\(.*?\)/g;
|
||||||
|
|
||||||
|
const getCitations = (res) => {
|
||||||
|
const adaptiveCards = res.details.adaptiveCards;
|
||||||
|
const textBlocks = adaptiveCards && adaptiveCards[0].body;
|
||||||
|
if (!textBlocks) {
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
let links = textBlocks[textBlocks.length - 1]?.text.match(regex);
|
||||||
|
if (links?.length === 0 || !links) {
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
links = links.map((link) => link.trim());
|
||||||
|
return links.join('\n - ');
|
||||||
|
};
|
||||||
|
|
||||||
const citeText = (res, noLinks = false) => {
|
const citeText = (res, noLinks = false) => {
|
||||||
let result = res.text || res;
|
let result = res.text || res;
|
||||||
|
|
@ -32,4 +47,4 @@ const citeText = (res, noLinks = false) => {
|
||||||
return result;
|
return result;
|
||||||
};
|
};
|
||||||
|
|
||||||
module.exports = citeText;
|
module.exports = { getCitations, citeText };
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
const _ = require('lodash');
|
const _ = require('lodash');
|
||||||
const citationRegex = /\[\^\d+?\^]/g;
|
const citationRegex = /\[\^\d+?\^]/g;
|
||||||
const { getCitations, citeText } = require('../../../app');
|
const { getCitations, citeText } = require('./citations');
|
||||||
const cursor = '<span className="result-streaming">█</span>';
|
const cursor = '<span className="result-streaming">█</span>';
|
||||||
|
|
||||||
|
const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? text + ' ' : text);
|
||||||
|
|
||||||
const handleError = (res, message) => {
|
const handleError = (res, message) => {
|
||||||
res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`);
|
res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`);
|
||||||
res.end();
|
res.end();
|
||||||
|
|
@ -15,12 +17,12 @@ const sendMessage = (res, message, event = 'message') => {
|
||||||
res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`);
|
res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`);
|
||||||
};
|
};
|
||||||
|
|
||||||
const createOnProgress = ({ onProgress: _onProgress }) => {
|
const createOnProgress = ({ generation = '', onProgress: _onProgress }) => {
|
||||||
let i = 0;
|
let i = 0;
|
||||||
let code = '';
|
let code = '';
|
||||||
let tokens = '';
|
|
||||||
let precode = '';
|
let precode = '';
|
||||||
let codeBlock = false;
|
let codeBlock = false;
|
||||||
|
let tokens = addSpaceIfNeeded(generation);
|
||||||
|
|
||||||
const progressCallback = async (partial, { res, text, plugin, bing = false, ...rest }) => {
|
const progressCallback = async (partial, { res, text, plugin, bing = false, ...rest }) => {
|
||||||
let chunk = partial === text ? '' : partial;
|
let chunk = partial === text ? '' : partial;
|
||||||
|
|
@ -155,4 +157,5 @@ module.exports = {
|
||||||
handleText,
|
handleText,
|
||||||
formatSteps,
|
formatSteps,
|
||||||
formatAction,
|
formatAction,
|
||||||
|
addSpaceIfNeeded,
|
||||||
};
|
};
|
||||||
11
api/server/utils/index.js
Normal file
11
api/server/utils/index.js
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
const cryptoUtils = require('./crypto');
|
||||||
|
const handleText = require('./handleText');
|
||||||
|
const citations = require('./citations');
|
||||||
|
const sendEmail = require('./sendEmail');
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
...cryptoUtils,
|
||||||
|
...handleText,
|
||||||
|
...citations,
|
||||||
|
sendEmail,
|
||||||
|
};
|
||||||
|
|
@ -1,18 +0,0 @@
|
||||||
async function abortMessage(req, res, abortControllers) {
|
|
||||||
const { abortKey } = req.body;
|
|
||||||
console.log('req.body', req.body);
|
|
||||||
if (!abortControllers.has(abortKey)) {
|
|
||||||
return res.status(404).send('Request not found');
|
|
||||||
}
|
|
||||||
|
|
||||||
const { abortController } = abortControllers.get(abortKey);
|
|
||||||
|
|
||||||
abortControllers.delete(abortKey);
|
|
||||||
const ret = await abortController.abortAsk();
|
|
||||||
console.log('Aborted request', abortKey);
|
|
||||||
console.log('Aborted message:', ret);
|
|
||||||
|
|
||||||
res.send(JSON.stringify(ret));
|
|
||||||
}
|
|
||||||
|
|
||||||
module.exports = abortMessage;
|
|
||||||
|
|
@ -1,16 +1,10 @@
|
||||||
const azureUtils = require('./azureUtils');
|
const azureUtils = require('./azureUtils');
|
||||||
const cryptoUtils = require('./crypto');
|
|
||||||
const { tiktokenModels, maxTokensMap } = require('./tokens');
|
const { tiktokenModels, maxTokensMap } = require('./tokens');
|
||||||
const sendEmail = require('./sendEmail');
|
|
||||||
const abortMessage = require('./abortMessage');
|
|
||||||
const findMessageContent = require('./findMessageContent');
|
const findMessageContent = require('./findMessageContent');
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
...cryptoUtils,
|
|
||||||
...azureUtils,
|
...azureUtils,
|
||||||
maxTokensMap,
|
maxTokensMap,
|
||||||
tiktokenModels,
|
tiktokenModels,
|
||||||
sendEmail,
|
|
||||||
abortMessage,
|
|
||||||
findMessageContent,
|
findMessageContent,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -48,3 +48,17 @@ export type TSetOptionsPayload = {
|
||||||
checkPluginSelection: (value: string) => boolean;
|
checkPluginSelection: (value: string) => boolean;
|
||||||
setTools: (newValue: string) => void;
|
setTools: (newValue: string) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type TPresetItemProps = {
|
||||||
|
preset: TPreset;
|
||||||
|
value: TPreset;
|
||||||
|
onSelect: (preset: TPreset) => void;
|
||||||
|
onChangePreset: (preset: TPreset) => void;
|
||||||
|
onDeletePreset: (preset: TPreset) => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type TOnClick = (e: React.MouseEvent<HTMLButtonElement>) => void;
|
||||||
|
|
||||||
|
export type TGenButtonProps = {
|
||||||
|
onClick: TOnClick;
|
||||||
|
};
|
||||||
|
|
|
||||||
|
|
@ -103,10 +103,18 @@ export default function NewConversationMenu() {
|
||||||
};
|
};
|
||||||
|
|
||||||
// set the current model
|
// set the current model
|
||||||
|
const isModular = modularEndpoints.has(endpoint);
|
||||||
const onSelectPreset = (newPreset) => {
|
const onSelectPreset = (newPreset) => {
|
||||||
setMenuOpen(false);
|
setMenuOpen(false);
|
||||||
|
if (!newPreset) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (modularEndpoints.has(endpoint) && modularEndpoints.has(newPreset?.endpoint)) {
|
if (
|
||||||
|
isModular &&
|
||||||
|
modularEndpoints.has(newPreset?.endpoint) &&
|
||||||
|
endpoint === newPreset?.endpoint
|
||||||
|
) {
|
||||||
const currentConvo = getDefaultConversation({
|
const currentConvo = getDefaultConversation({
|
||||||
conversation,
|
conversation,
|
||||||
endpointsConfig,
|
endpointsConfig,
|
||||||
|
|
@ -118,10 +126,6 @@ export default function NewConversationMenu() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!newPreset) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
newConversation({}, newPreset);
|
newConversation({}, newPreset);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,14 @@
|
||||||
|
import type { TPresetItemProps } from '~/common';
|
||||||
|
import type { TPreset } from 'librechat-data-provider';
|
||||||
import { DropdownMenuRadioItem, EditIcon, TrashIcon } from '~/components';
|
import { DropdownMenuRadioItem, EditIcon, TrashIcon } from '~/components';
|
||||||
import { getIcon } from '~/components/Endpoints';
|
import { getIcon } from '~/components/Endpoints';
|
||||||
|
|
||||||
export default function PresetItem({ preset = {}, value, onChangePreset, onDeletePreset }) {
|
export default function PresetItem({
|
||||||
|
preset = {} as TPreset,
|
||||||
|
value,
|
||||||
|
onChangePreset,
|
||||||
|
onDeletePreset,
|
||||||
|
}: TPresetItemProps) {
|
||||||
const { endpoint } = preset;
|
const { endpoint } = preset;
|
||||||
|
|
||||||
const icon = getIcon({
|
const icon = getIcon({
|
||||||
|
|
@ -14,9 +21,9 @@ export default function PresetItem({ preset = {}, value, onChangePreset, onDelet
|
||||||
|
|
||||||
const getPresetTitle = () => {
|
const getPresetTitle = () => {
|
||||||
let _title = `${endpoint}`;
|
let _title = `${endpoint}`;
|
||||||
|
const { chatGptLabel, modelLabel, model, jailbreak, toneStyle } = preset;
|
||||||
|
|
||||||
if (endpoint === 'azureOpenAI' || endpoint === 'openAI') {
|
if (endpoint === 'azureOpenAI' || endpoint === 'openAI') {
|
||||||
const { chatGptLabel, model } = preset;
|
|
||||||
if (model) {
|
if (model) {
|
||||||
_title += `: ${model}`;
|
_title += `: ${model}`;
|
||||||
}
|
}
|
||||||
|
|
@ -24,7 +31,6 @@ export default function PresetItem({ preset = {}, value, onChangePreset, onDelet
|
||||||
_title += ` as ${chatGptLabel}`;
|
_title += ` as ${chatGptLabel}`;
|
||||||
}
|
}
|
||||||
} else if (endpoint === 'google') {
|
} else if (endpoint === 'google') {
|
||||||
const { modelLabel, model } = preset;
|
|
||||||
if (model) {
|
if (model) {
|
||||||
_title += `: ${model}`;
|
_title += `: ${model}`;
|
||||||
}
|
}
|
||||||
|
|
@ -32,7 +38,6 @@ export default function PresetItem({ preset = {}, value, onChangePreset, onDelet
|
||||||
_title += ` as ${modelLabel}`;
|
_title += ` as ${modelLabel}`;
|
||||||
}
|
}
|
||||||
} else if (endpoint === 'bingAI') {
|
} else if (endpoint === 'bingAI') {
|
||||||
const { jailbreak, toneStyle } = preset;
|
|
||||||
if (toneStyle) {
|
if (toneStyle) {
|
||||||
_title += `: ${toneStyle}`;
|
_title += `: ${toneStyle}`;
|
||||||
}
|
}
|
||||||
|
|
@ -40,12 +45,10 @@ export default function PresetItem({ preset = {}, value, onChangePreset, onDelet
|
||||||
_title += ' as Sydney';
|
_title += ' as Sydney';
|
||||||
}
|
}
|
||||||
} else if (endpoint === 'chatGPTBrowser') {
|
} else if (endpoint === 'chatGPTBrowser') {
|
||||||
const { model } = preset;
|
|
||||||
if (model) {
|
if (model) {
|
||||||
_title += `: ${model}`;
|
_title += `: ${model}`;
|
||||||
}
|
}
|
||||||
} else if (endpoint === 'gptPlugins') {
|
} else if (endpoint === 'gptPlugins') {
|
||||||
const { model } = preset;
|
|
||||||
if (model) {
|
if (model) {
|
||||||
_title += `: ${model}`;
|
_title += `: ${model}`;
|
||||||
}
|
}
|
||||||
|
|
@ -60,6 +63,7 @@ export default function PresetItem({ preset = {}, value, onChangePreset, onDelet
|
||||||
// regular model
|
// regular model
|
||||||
return (
|
return (
|
||||||
<DropdownMenuRadioItem
|
<DropdownMenuRadioItem
|
||||||
|
/* @ts-ignore, value can be an object as well */
|
||||||
value={value}
|
value={value}
|
||||||
className="group flex h-10 max-h-[44px] flex-row justify-between dark:font-semibold dark:text-gray-100 dark:hover:bg-gray-800 sm:h-auto"
|
className="group flex h-10 max-h-[44px] flex-row justify-between dark:font-semibold dark:text-gray-100 dark:hover:bg-gray-800 sm:h-auto"
|
||||||
>
|
>
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
import React from 'react';
|
import React from 'react';
|
||||||
import PresetItem from './PresetItem';
|
import PresetItem from './PresetItem';
|
||||||
|
import type { TPreset } from 'librechat-data-provider';
|
||||||
|
|
||||||
export default function PresetItems({ presets, onSelect, onChangePreset, onDeletePreset }) {
|
export default function PresetItems({ presets, onSelect, onChangePreset, onDeletePreset }) {
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
{presets.map((preset) => (
|
{presets.map((preset: TPreset) => (
|
||||||
<PresetItem
|
<PresetItem
|
||||||
key={preset?.presetId ?? Math.random()}
|
key={preset?.presetId ?? Math.random()}
|
||||||
value={preset}
|
value={preset}
|
||||||
27
client/src/components/Input/Generations/Button.tsx
Normal file
27
client/src/components/Input/Generations/Button.tsx
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
import { cn, removeFocusOutlines } from '~/utils/';
|
||||||
|
|
||||||
|
export default function Button({
|
||||||
|
type = 'regenerate',
|
||||||
|
children,
|
||||||
|
onClick,
|
||||||
|
className = '',
|
||||||
|
}: {
|
||||||
|
type?: 'regenerate' | 'continue' | 'stop';
|
||||||
|
children: React.ReactNode;
|
||||||
|
onClick: (e: React.MouseEvent<HTMLButtonElement>) => void;
|
||||||
|
className?: string;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<button
|
||||||
|
data-testid={`${type}-generation-button`}
|
||||||
|
className={cn(
|
||||||
|
'custom-btn btn-neutral relative -z-0 whitespace-nowrap border-0 md:border',
|
||||||
|
removeFocusOutlines,
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
onClick={onClick}
|
||||||
|
>
|
||||||
|
<div className="flex w-full items-center justify-center gap-2">{children}</div>
|
||||||
|
</button>
|
||||||
|
);
|
||||||
|
}
|
||||||
12
client/src/components/Input/Generations/Continue.tsx
Normal file
12
client/src/components/Input/Generations/Continue.tsx
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
import type { TGenButtonProps } from '~/common';
|
||||||
|
import { ContinueIcon } from '~/components/svg';
|
||||||
|
import Button from './Button';
|
||||||
|
|
||||||
|
export default function Continue({ onClick }: TGenButtonProps) {
|
||||||
|
return (
|
||||||
|
<Button type="continue" onClick={onClick}>
|
||||||
|
<ContinueIcon className="text-gray-600/90" />
|
||||||
|
Continue
|
||||||
|
</Button>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
import type { TMessage } from 'librechat-data-provider';
|
||||||
|
import { useMessageHandler, useMediaQuery, useGenerations } from '~/hooks';
|
||||||
|
import { cn } from '~/utils';
|
||||||
|
import Regenerate from './Regenerate';
|
||||||
|
import Continue from './Continue';
|
||||||
|
import Stop from './Stop';
|
||||||
|
|
||||||
|
type GenerationButtonsProps = {
|
||||||
|
endpoint: string;
|
||||||
|
showPopover: boolean;
|
||||||
|
opacityClass: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default function GenerationButtons({
|
||||||
|
endpoint,
|
||||||
|
showPopover,
|
||||||
|
opacityClass,
|
||||||
|
}: GenerationButtonsProps) {
|
||||||
|
const {
|
||||||
|
messages,
|
||||||
|
isSubmitting,
|
||||||
|
latestMessage,
|
||||||
|
handleContinue,
|
||||||
|
handleRegenerate,
|
||||||
|
handleStopGenerating,
|
||||||
|
} = useMessageHandler();
|
||||||
|
const isSmallScreen = useMediaQuery('(max-width: 768px)');
|
||||||
|
const { continueSupported, regenerateEnabled } = useGenerations({
|
||||||
|
endpoint,
|
||||||
|
message: latestMessage as TMessage,
|
||||||
|
isSubmitting,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (isSmallScreen) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
let button: React.ReactNode = null;
|
||||||
|
|
||||||
|
if (isSubmitting) {
|
||||||
|
button = <Stop onClick={handleStopGenerating} />;
|
||||||
|
} else if (continueSupported) {
|
||||||
|
button = <Continue onClick={handleContinue} />;
|
||||||
|
} else if (messages && messages.length > 0 && regenerateEnabled) {
|
||||||
|
button = <Regenerate onClick={handleRegenerate} />;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="absolute bottom-4 right-0 z-[62]">
|
||||||
|
<div className="grow" />
|
||||||
|
<div className="flex items-center md:items-end">
|
||||||
|
<div
|
||||||
|
className={cn('option-buttons', showPopover ? '' : opacityClass)}
|
||||||
|
data-projection-id="173"
|
||||||
|
>
|
||||||
|
{button}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
12
client/src/components/Input/Generations/Regenerate.tsx
Normal file
12
client/src/components/Input/Generations/Regenerate.tsx
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
import type { TGenButtonProps } from '~/common';
|
||||||
|
import { RegenerateIcon } from '~/components/svg';
|
||||||
|
import Button from './Button';
|
||||||
|
|
||||||
|
export default function Regenerate({ onClick }: TGenButtonProps) {
|
||||||
|
return (
|
||||||
|
<Button onClick={onClick}>
|
||||||
|
<RegenerateIcon className="h-3 w-3 flex-shrink-0 text-gray-600/90" />
|
||||||
|
Regenerate
|
||||||
|
</Button>
|
||||||
|
);
|
||||||
|
}
|
||||||
12
client/src/components/Input/Generations/Stop.tsx
Normal file
12
client/src/components/Input/Generations/Stop.tsx
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
import type { TGenButtonProps } from '~/common';
|
||||||
|
import { StopGeneratingIcon } from '~/components/svg';
|
||||||
|
import Button from './Button';
|
||||||
|
|
||||||
|
export default function Stop({ onClick }: TGenButtonProps) {
|
||||||
|
return (
|
||||||
|
<Button type="stop" onClick={onClick}>
|
||||||
|
<StopGeneratingIcon className="text-gray-600/90" />
|
||||||
|
Stop
|
||||||
|
</Button>
|
||||||
|
);
|
||||||
|
}
|
||||||
1
client/src/components/Input/Generations/index.ts
Normal file
1
client/src/components/Input/Generations/index.ts
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
export { default as GenerationButtons } from './GenerationButtons';
|
||||||
|
|
@ -12,7 +12,7 @@ import { Button } from '~/components/ui';
|
||||||
import { cn, cardStyle } from '~/utils/';
|
import { cn, cardStyle } from '~/utils/';
|
||||||
import { useSetOptions } from '~/hooks';
|
import { useSetOptions } from '~/hooks';
|
||||||
import { ModelSelect } from './ModelSelect';
|
import { ModelSelect } from './ModelSelect';
|
||||||
import GenerationButtons from './GenerationButtons';
|
import { GenerationButtons } from './Generations';
|
||||||
import store from '~/store';
|
import store from '~/store';
|
||||||
|
|
||||||
export default function OptionsBar() {
|
export default function OptionsBar() {
|
||||||
|
|
@ -76,7 +76,11 @@ export default function OptionsBar() {
|
||||||
: () => setShowPopover((prev) => !prev);
|
: () => setShowPopover((prev) => !prev);
|
||||||
return (
|
return (
|
||||||
<div className="relative py-2 last:mb-2 md:mx-4 md:mb-[-16px] md:py-4 md:pt-2 md:last:mb-6 lg:mx-auto lg:mb-[-32px] lg:max-w-2xl lg:pt-6 xl:max-w-3xl">
|
<div className="relative py-2 last:mb-2 md:mx-4 md:mb-[-16px] md:py-4 md:pt-2 md:last:mb-6 lg:mx-auto lg:mb-[-32px] lg:max-w-2xl lg:pt-6 xl:max-w-3xl">
|
||||||
<GenerationButtons showPopover={showPopover} opacityClass={opacityClass} />
|
<GenerationButtons
|
||||||
|
endpoint={endpoint}
|
||||||
|
showPopover={showPopover}
|
||||||
|
opacityClass={opacityClass}
|
||||||
|
/>
|
||||||
<span className="flex w-full flex-col items-center justify-center gap-0 md:order-none md:m-auto md:gap-2">
|
<span className="flex w-full flex-col items-center justify-center gap-0 md:order-none md:m-auto md:gap-2">
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
|
|
|
||||||
|
|
@ -1,16 +0,0 @@
|
||||||
import React from 'react';
|
|
||||||
|
|
||||||
export default function RowButton({ onClick, children, text, className }) {
|
|
||||||
return (
|
|
||||||
<button
|
|
||||||
onClick={onClick}
|
|
||||||
className={`input-panel-button btn btn-neutral flex justify-center gap-2 border-0 md:border ${className}`}
|
|
||||||
type="button"
|
|
||||||
>
|
|
||||||
{children}
|
|
||||||
<span className="hidden md:block">{text}</span>
|
|
||||||
{/* <RegenerateIcon />
|
|
||||||
<span className="hidden md:block">Regenerate response</span> */}
|
|
||||||
</button>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
@ -10,22 +10,18 @@ import { cn } from '~/utils';
|
||||||
import store from '~/store';
|
import store from '~/store';
|
||||||
|
|
||||||
export default function TextChat({ isSearchView = false }) {
|
export default function TextChat({ isSearchView = false }) {
|
||||||
const inputRef = useRef(null);
|
const { ask, isSubmitting, handleStopGenerating, latestMessage, endpointsConfig } =
|
||||||
const isComposing = useRef(false);
|
useMessageHandler();
|
||||||
|
const conversation = useRecoilValue(store.conversation);
|
||||||
|
const setShowBingToneSetting = useSetRecoilState(store.showBingToneSetting);
|
||||||
const [text, setText] = useRecoilState(store.text);
|
const [text, setText] = useRecoilState(store.text);
|
||||||
const { theme } = useContext(ThemeContext);
|
const { theme } = useContext(ThemeContext);
|
||||||
const conversation = useRecoilValue(store.conversation);
|
const isComposing = useRef(false);
|
||||||
const latestMessage = useRecoilValue(store.latestMessage);
|
const inputRef = useRef(null);
|
||||||
|
|
||||||
const endpointsConfig = useRecoilValue(store.endpointsConfig);
|
|
||||||
const isSubmitting = useRecoilValue(store.isSubmitting);
|
|
||||||
const setShowBingToneSetting = useSetRecoilState(store.showBingToneSetting);
|
|
||||||
|
|
||||||
// TODO: do we need this?
|
// TODO: do we need this?
|
||||||
const disabled = false;
|
const disabled = false;
|
||||||
|
|
||||||
const { ask, stopGenerating } = useMessageHandler();
|
|
||||||
const isNotAppendable = latestMessage?.unfinished & !isSubmitting || latestMessage?.error;
|
const isNotAppendable = latestMessage?.unfinished & !isSubmitting || latestMessage?.error;
|
||||||
const { conversationId, jailbreak } = conversation || {};
|
const { conversationId, jailbreak } = conversation || {};
|
||||||
|
|
||||||
|
|
@ -60,11 +56,6 @@ export default function TextChat({ isSearchView = false }) {
|
||||||
setText('');
|
setText('');
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleStopGenerating = (e) => {
|
|
||||||
e.preventDefault();
|
|
||||||
stopGenerating();
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleKeyDown = (e) => {
|
const handleKeyDown = (e) => {
|
||||||
if (e.key === 'Enter' && isSubmitting) {
|
if (e.key === 'Enter' && isSubmitting) {
|
||||||
return;
|
return;
|
||||||
|
|
|
||||||
|
|
@ -1,87 +0,0 @@
|
||||||
import React from 'react';
|
|
||||||
import { cn } from '~/utils/';
|
|
||||||
import Clipboard from '../svg/Clipboard';
|
|
||||||
import CheckMark from '../svg/CheckMark';
|
|
||||||
import EditIcon from '../svg/EditIcon';
|
|
||||||
import RegenerateIcon from '../svg/RegenerateIcon';
|
|
||||||
|
|
||||||
export default function HoverButtons({
|
|
||||||
isEditting,
|
|
||||||
enterEdit,
|
|
||||||
copyToClipboard,
|
|
||||||
conversation,
|
|
||||||
isSubmitting,
|
|
||||||
message,
|
|
||||||
regenerate,
|
|
||||||
}) {
|
|
||||||
const { endpoint } = conversation;
|
|
||||||
const [isCopied, setIsCopied] = React.useState(false);
|
|
||||||
|
|
||||||
const branchingSupported =
|
|
||||||
// azureOpenAI, openAI, chatGPTBrowser support branching, so edit enabled // 5/21/23: Bing is allowing editing and Message regenerating
|
|
||||||
!![
|
|
||||||
'azureOpenAI',
|
|
||||||
'openAI',
|
|
||||||
'chatGPTBrowser',
|
|
||||||
'google',
|
|
||||||
'bingAI',
|
|
||||||
'gptPlugins',
|
|
||||||
'anthropic',
|
|
||||||
].find((e) => e === endpoint);
|
|
||||||
// Sydney in bingAI supports branching, so edit enabled
|
|
||||||
|
|
||||||
const editEnabled =
|
|
||||||
!message?.error &&
|
|
||||||
message?.isCreatedByUser &&
|
|
||||||
!message?.searchResult &&
|
|
||||||
!isEditting &&
|
|
||||||
branchingSupported;
|
|
||||||
|
|
||||||
// for now, once branching is supported, regerate will be enabled
|
|
||||||
let regenerateEnabled =
|
|
||||||
// !message?.error &&
|
|
||||||
!message?.isCreatedByUser &&
|
|
||||||
!message?.searchResult &&
|
|
||||||
!isEditting &&
|
|
||||||
!isSubmitting &&
|
|
||||||
branchingSupported;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="visible mt-2 flex justify-center gap-3 self-end text-gray-400 md:gap-4 lg:absolute lg:right-0 lg:top-0 lg:mt-0 lg:translate-x-full lg:gap-1 lg:self-center lg:pl-2">
|
|
||||||
{editEnabled ? (
|
|
||||||
<button
|
|
||||||
className="hover-button rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible"
|
|
||||||
onClick={enterEdit}
|
|
||||||
type="button"
|
|
||||||
title="edit"
|
|
||||||
>
|
|
||||||
{/* <button className="rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400"> */}
|
|
||||||
<EditIcon />
|
|
||||||
</button>
|
|
||||||
) : null}
|
|
||||||
{regenerateEnabled ? (
|
|
||||||
<button
|
|
||||||
className="hover-button active rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible"
|
|
||||||
onClick={regenerate}
|
|
||||||
type="button"
|
|
||||||
title="regenerate"
|
|
||||||
>
|
|
||||||
{/* <button className="rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400"> */}
|
|
||||||
<RegenerateIcon />
|
|
||||||
</button>
|
|
||||||
) : null}
|
|
||||||
|
|
||||||
<button
|
|
||||||
className={cn(
|
|
||||||
'hover-button rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible',
|
|
||||||
message?.isCreatedByUser ? '' : 'active',
|
|
||||||
)}
|
|
||||||
onClick={() => copyToClipboard(setIsCopied)}
|
|
||||||
type="button"
|
|
||||||
title={isCopied ? 'Copied to clipboard' : 'Copy to clipboard'}
|
|
||||||
>
|
|
||||||
{isCopied ? <CheckMark /> : <Clipboard />}
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
86
client/src/components/Messages/HoverButtons.tsx
Normal file
86
client/src/components/Messages/HoverButtons.tsx
Normal file
|
|
@ -0,0 +1,86 @@
|
||||||
|
import { useState } from 'react';
|
||||||
|
import type { TConversation, TMessage } from 'librechat-data-provider';
|
||||||
|
import { Clipboard, CheckMark, EditIcon, RegenerateIcon, ContinueIcon } from '~/components/svg';
|
||||||
|
import { useGenerations } from '~/hooks';
|
||||||
|
import { cn } from '~/utils';
|
||||||
|
|
||||||
|
type THoverButtons = {
|
||||||
|
isEditing: boolean;
|
||||||
|
enterEdit: () => void;
|
||||||
|
copyToClipboard: (setIsCopied: (isCopied: boolean) => void) => void;
|
||||||
|
conversation: TConversation;
|
||||||
|
isSubmitting: boolean;
|
||||||
|
message: TMessage;
|
||||||
|
regenerate: () => void;
|
||||||
|
handleContinue: (e: React.MouseEvent<HTMLButtonElement>) => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default function HoverButtons({
|
||||||
|
isEditing,
|
||||||
|
enterEdit,
|
||||||
|
copyToClipboard,
|
||||||
|
conversation,
|
||||||
|
isSubmitting,
|
||||||
|
message,
|
||||||
|
regenerate,
|
||||||
|
handleContinue,
|
||||||
|
}: THoverButtons) {
|
||||||
|
const { endpoint } = conversation;
|
||||||
|
const [isCopied, setIsCopied] = useState(false);
|
||||||
|
const { editEnabled, regenerateEnabled, continueSupported } = useGenerations({
|
||||||
|
isEditing,
|
||||||
|
isSubmitting,
|
||||||
|
message,
|
||||||
|
endpoint: endpoint ?? '',
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="visible mt-2 flex justify-center gap-3 self-end text-gray-400 md:gap-4 lg:absolute lg:right-0 lg:top-0 lg:mt-0 lg:translate-x-full lg:gap-1 lg:self-center lg:pl-2">
|
||||||
|
<button
|
||||||
|
className={cn(
|
||||||
|
'hover-button rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible',
|
||||||
|
message?.isCreatedByUser ? '' : 'opacity-0',
|
||||||
|
)}
|
||||||
|
onClick={enterEdit}
|
||||||
|
type="button"
|
||||||
|
title="edit"
|
||||||
|
disabled={!editEnabled}
|
||||||
|
>
|
||||||
|
{/* <button className="rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400"> */}
|
||||||
|
<EditIcon />
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
className={cn(
|
||||||
|
'hover-button rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible',
|
||||||
|
message?.isCreatedByUser ? '' : 'active',
|
||||||
|
)}
|
||||||
|
onClick={() => copyToClipboard(setIsCopied)}
|
||||||
|
type="button"
|
||||||
|
title={isCopied ? 'Copied to clipboard' : 'Copy to clipboard'}
|
||||||
|
>
|
||||||
|
{isCopied ? <CheckMark /> : <Clipboard />}
|
||||||
|
</button>
|
||||||
|
{regenerateEnabled ? (
|
||||||
|
<button
|
||||||
|
className="hover-button active rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible"
|
||||||
|
onClick={regenerate}
|
||||||
|
type="button"
|
||||||
|
title="regenerate"
|
||||||
|
>
|
||||||
|
{/* <button className="rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400"> */}
|
||||||
|
<RegenerateIcon />
|
||||||
|
</button>
|
||||||
|
) : null}
|
||||||
|
{continueSupported ? (
|
||||||
|
<button
|
||||||
|
className="hover-button active rounded-md p-1 hover:bg-gray-100 hover:text-gray-700 dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible "
|
||||||
|
onClick={handleContinue}
|
||||||
|
type="button"
|
||||||
|
title="continue"
|
||||||
|
>
|
||||||
|
<ContinueIcon className="h-4 w-4" />
|
||||||
|
</button>
|
||||||
|
) : null}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
/* eslint-disable react-hooks/exhaustive-deps */
|
/* eslint-disable react-hooks/exhaustive-deps */
|
||||||
import { useState, useEffect, useRef } from 'react';
|
import { useState, useEffect, useRef } from 'react';
|
||||||
import { useRecoilValue, useSetRecoilState } from 'recoil';
|
import { useSetRecoilState } from 'recoil';
|
||||||
import copy from 'copy-to-clipboard';
|
import copy from 'copy-to-clipboard';
|
||||||
import Plugin from './Plugin';
|
import Plugin from './Plugin';
|
||||||
import SubRow from './Content/SubRow';
|
import SubRow from './Content/SubRow';
|
||||||
|
|
@ -25,13 +25,12 @@ export default function Message({
|
||||||
setSiblingIdx,
|
setSiblingIdx,
|
||||||
}) {
|
}) {
|
||||||
const { text, searchResult, isCreatedByUser, error, submitting, unfinished } = message;
|
const { text, searchResult, isCreatedByUser, error, submitting, unfinished } = message;
|
||||||
const isSubmitting = useRecoilValue(store.isSubmitting);
|
|
||||||
const setLatestMessage = useSetRecoilState(store.latestMessage);
|
const setLatestMessage = useSetRecoilState(store.latestMessage);
|
||||||
const [abortScroll, setAbort] = useState(false);
|
const [abortScroll, setAbort] = useState(false);
|
||||||
const textEditor = useRef(null);
|
const textEditor = useRef(null);
|
||||||
const last = !message?.children?.length;
|
const last = !message?.children?.length;
|
||||||
const edit = message.messageId == currentEditId;
|
const edit = message.messageId == currentEditId;
|
||||||
const { ask, regenerate } = useMessageHandler();
|
const { isSubmitting, ask, regenerate, handleContinue } = useMessageHandler();
|
||||||
const { switchToConversation } = store.useConversation();
|
const { switchToConversation } = store.useConversation();
|
||||||
const blinker = submitting && isSubmitting;
|
const blinker = submitting && isSubmitting;
|
||||||
const getConversationQuery = useGetConversationByIdQuery(message.conversationId, {
|
const getConversationQuery = useGetConversationByIdQuery(message.conversationId, {
|
||||||
|
|
@ -223,12 +222,13 @@ export default function Message({
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
<HoverButtons
|
<HoverButtons
|
||||||
isEditting={edit}
|
isEditing={edit}
|
||||||
isSubmitting={isSubmitting}
|
isSubmitting={isSubmitting}
|
||||||
message={message}
|
message={message}
|
||||||
conversation={conversation}
|
conversation={conversation}
|
||||||
enterEdit={() => enterEdit()}
|
enterEdit={() => enterEdit()}
|
||||||
regenerate={() => regenerateMessage()}
|
regenerate={() => regenerateMessage()}
|
||||||
|
handleContinue={handleContinue}
|
||||||
copyToClipboard={copyToClipboard}
|
copyToClipboard={copyToClipboard}
|
||||||
/>
|
/>
|
||||||
<SubRow subclasses="switch-container">
|
<SubRow subclasses="switch-container">
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,12 @@
|
||||||
import React from 'react';
|
type Props = {
|
||||||
|
scrollHandler: React.MouseEventHandler<HTMLButtonElement>;
|
||||||
|
};
|
||||||
|
|
||||||
export default function ScrollToBottom({ scrollHandler }) {
|
export default function ScrollToBottom({ scrollHandler }: Props) {
|
||||||
return (
|
return (
|
||||||
<button
|
<button
|
||||||
onClick={scrollHandler}
|
onClick={scrollHandler}
|
||||||
className="absolute bottom-[124px] right-6 z-[62] cursor-pointer rounded-full border border-gray-200 bg-gray-50 text-gray-600 dark:border-white/10 dark:bg-white/10 dark:text-gray-200 md:bottom-[120px]"
|
className="absolute bottom-[124px] right-6 z-[62] cursor-pointer rounded-full border border-gray-200 bg-gray-50 text-gray-600 dark:border-white/10 dark:bg-white/10 dark:text-gray-200 md:bottom-[180px] lg:bottom-[120px]"
|
||||||
>
|
>
|
||||||
<svg
|
<svg
|
||||||
stroke="currentColor"
|
stroke="currentColor"
|
||||||
|
|
@ -9,7 +9,7 @@ export default function Clipboard() {
|
||||||
viewBox="0 0 24 24"
|
viewBox="0 0 24 24"
|
||||||
strokeLinecap="round"
|
strokeLinecap="round"
|
||||||
strokeLinejoin="round"
|
strokeLinejoin="round"
|
||||||
className="h-4 w-4"
|
className="h-4 w-4 text-gray-600 dark:text-gray-400"
|
||||||
height="1em"
|
height="1em"
|
||||||
width="1em"
|
width="1em"
|
||||||
xmlns="http://www.w3.org/2000/svg"
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
|
|
||||||
21
client/src/components/svg/ContinueIcon.tsx
Normal file
21
client/src/components/svg/ContinueIcon.tsx
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
import { cn } from '~/utils';
|
||||||
|
|
||||||
|
export default function ContinueIcon({ className = '' }: { className?: string }) {
|
||||||
|
return (
|
||||||
|
<svg
|
||||||
|
stroke="currentColor"
|
||||||
|
fill="none"
|
||||||
|
strokeWidth="2"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
strokeLinecap="round"
|
||||||
|
strokeLinejoin="round"
|
||||||
|
className={cn('h-3 w-3 -rotate-180 text-gray-600 dark:text-gray-400', className)}
|
||||||
|
height="1em"
|
||||||
|
width="1em"
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
>
|
||||||
|
<polygon points="11 19 2 12 11 5 11 19" />
|
||||||
|
<polygon points="22 19 13 12 22 5 22 19" />
|
||||||
|
</svg>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import React from 'react';
|
import { cn } from '~/utils';
|
||||||
|
|
||||||
export default function Regenerate() {
|
export default function RegenerateIcon({ className = '' }: { className?: string }) {
|
||||||
return (
|
return (
|
||||||
<svg
|
<svg
|
||||||
stroke="currentColor"
|
stroke="currentColor"
|
||||||
|
|
@ -9,7 +9,7 @@ export default function Regenerate() {
|
||||||
viewBox="0 0 24 24"
|
viewBox="0 0 24 24"
|
||||||
strokeLinecap="round"
|
strokeLinecap="round"
|
||||||
strokeLinejoin="round"
|
strokeLinejoin="round"
|
||||||
className="h-4 w-4"
|
className={cn('h-4 w-4 text-gray-600 dark:text-gray-400', className)}
|
||||||
height="1em"
|
height="1em"
|
||||||
width="1em"
|
width="1em"
|
||||||
xmlns="http://www.w3.org/2000/svg"
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import React from 'react';
|
import { cn } from '~/utils';
|
||||||
|
|
||||||
export default function StopGeneratingIcon() {
|
export default function StopGeneratingIcon({ className = '' }: { className?: string }) {
|
||||||
return (
|
return (
|
||||||
<svg
|
<svg
|
||||||
stroke="currentColor"
|
stroke="currentColor"
|
||||||
|
|
@ -9,7 +9,7 @@ export default function StopGeneratingIcon() {
|
||||||
viewBox="0 0 24 24"
|
viewBox="0 0 24 24"
|
||||||
strokeLinecap="round"
|
strokeLinecap="round"
|
||||||
strokeLinejoin="round"
|
strokeLinejoin="round"
|
||||||
className="h-3 w-3"
|
className={cn('h-3 w-3 text-gray-600 dark:text-gray-400', className)}
|
||||||
height="1em"
|
height="1em"
|
||||||
width="1em"
|
width="1em"
|
||||||
xmlns="http://www.w3.org/2000/svg"
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,8 @@ export { default as CrossIcon } from './CrossIcon';
|
||||||
export { default as LogOutIcon } from './LogOutIcon';
|
export { default as LogOutIcon } from './LogOutIcon';
|
||||||
export { default as MessagesSquared } from './MessagesSquared';
|
export { default as MessagesSquared } from './MessagesSquared';
|
||||||
export { default as StopGeneratingIcon } from './StopGeneratingIcon';
|
export { default as StopGeneratingIcon } from './StopGeneratingIcon';
|
||||||
|
export { default as RegenerateIcon } from './RegenerateIcon';
|
||||||
|
export { default as ContinueIcon } from './ContinueIcon';
|
||||||
export { default as GoogleIcon } from './GoogleIcon';
|
export { default as GoogleIcon } from './GoogleIcon';
|
||||||
export { default as OpenIDIcon } from './OpenIDIcon';
|
export { default as OpenIDIcon } from './OpenIDIcon';
|
||||||
export { default as GithubIcon } from './GithubIcon';
|
export { default as GithubIcon } from './GithubIcon';
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ function SelectDropDown({
|
||||||
{({ open }) => (
|
{({ open }) => (
|
||||||
<>
|
<>
|
||||||
<Listbox.Button
|
<Listbox.Button
|
||||||
|
data-testid="select-dropdown-button"
|
||||||
className={cn(
|
className={cn(
|
||||||
'relative flex w-full cursor-default flex-col rounded-md border border-black/10 bg-white py-2 pl-3 pr-10 text-left focus:outline-none focus:ring-0 focus:ring-offset-0 dark:border-white/20 dark:bg-gray-800 sm:text-sm',
|
'relative flex w-full cursor-default flex-col rounded-md border border-black/10 bg-white py-2 pl-3 pr-10 text-left focus:outline-none focus:ring-0 focus:ring-offset-0 dark:border-white/20 dark:bg-gray-800 sm:text-sm',
|
||||||
className ?? '',
|
className ?? '',
|
||||||
|
|
|
||||||
|
|
@ -6,4 +6,5 @@ export { default as useDebounce } from './useDebounce';
|
||||||
export { default as useLocalize } from './useLocalize';
|
export { default as useLocalize } from './useLocalize';
|
||||||
export { default as useMediaQuery } from './useMediaQuery';
|
export { default as useMediaQuery } from './useMediaQuery';
|
||||||
export { default as useSetOptions } from './useSetOptions';
|
export { default as useSetOptions } from './useSetOptions';
|
||||||
|
export { default as useGenerations } from './useGenerations';
|
||||||
export { default as useMessageHandler } from './useMessageHandler';
|
export { default as useMessageHandler } from './useMessageHandler';
|
||||||
|
|
|
||||||
55
client/src/hooks/useGenerations.ts
Normal file
55
client/src/hooks/useGenerations.ts
Normal file
|
|
@ -0,0 +1,55 @@
|
||||||
|
import type { TMessage } from 'librechat-data-provider';
|
||||||
|
import { useRecoilValue } from 'recoil';
|
||||||
|
import store from '~/store';
|
||||||
|
|
||||||
|
type TUseGenerations = {
|
||||||
|
endpoint?: string;
|
||||||
|
message: TMessage;
|
||||||
|
isSubmitting: boolean;
|
||||||
|
isEditing?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default function useGenerations({
|
||||||
|
endpoint,
|
||||||
|
message,
|
||||||
|
isSubmitting,
|
||||||
|
isEditing = false,
|
||||||
|
}: TUseGenerations) {
|
||||||
|
const latestMessage = useRecoilValue(store.latestMessage);
|
||||||
|
|
||||||
|
const { error, messageId, searchResult, finish_reason, isCreatedByUser } = message ?? {};
|
||||||
|
|
||||||
|
const continueSupported =
|
||||||
|
latestMessage?.messageId === messageId &&
|
||||||
|
finish_reason &&
|
||||||
|
finish_reason !== 'stop' &&
|
||||||
|
!!['azureOpenAI', 'openAI', 'gptPlugins', 'anthropic'].find((e) => e === endpoint);
|
||||||
|
|
||||||
|
const branchingSupported =
|
||||||
|
// 5/21/23: Bing is allowing editing and Message regenerating
|
||||||
|
!![
|
||||||
|
'azureOpenAI',
|
||||||
|
'openAI',
|
||||||
|
'chatGPTBrowser',
|
||||||
|
'google',
|
||||||
|
'bingAI',
|
||||||
|
'gptPlugins',
|
||||||
|
'anthropic',
|
||||||
|
].find((e) => e === endpoint);
|
||||||
|
|
||||||
|
const editEnabled =
|
||||||
|
!error &&
|
||||||
|
isCreatedByUser && // TODO: allow AI editing
|
||||||
|
!searchResult &&
|
||||||
|
!isEditing &&
|
||||||
|
branchingSupported;
|
||||||
|
|
||||||
|
const regenerateEnabled =
|
||||||
|
!isCreatedByUser && !searchResult && !isEditing && !isSubmitting && branchingSupported;
|
||||||
|
|
||||||
|
return {
|
||||||
|
continueSupported,
|
||||||
|
editEnabled,
|
||||||
|
regenerateEnabled,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
@ -1,219 +0,0 @@
|
||||||
import { v4 } from 'uuid';
|
|
||||||
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil';
|
|
||||||
import store from '~/store';
|
|
||||||
|
|
||||||
const useMessageHandler = () => {
|
|
||||||
const currentConversation = useRecoilValue(store.conversation) || {};
|
|
||||||
const setSubmission = useSetRecoilState(store.submission);
|
|
||||||
const isSubmitting = useRecoilValue(store.isSubmitting);
|
|
||||||
const endpointsConfig = useRecoilValue(store.endpointsConfig);
|
|
||||||
|
|
||||||
const { getToken } = store.useToken(currentConversation?.endpoint);
|
|
||||||
|
|
||||||
const latestMessage = useRecoilValue(store.latestMessage);
|
|
||||||
|
|
||||||
const [messages, setMessages] = useRecoilState(store.messages);
|
|
||||||
|
|
||||||
const ask = (
|
|
||||||
{ text, parentMessageId = null, conversationId = null, messageId = null },
|
|
||||||
{ isRegenerate = false } = {},
|
|
||||||
) => {
|
|
||||||
if (!!isSubmitting || text === '') {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// determine the model to be used
|
|
||||||
const { endpoint } = currentConversation;
|
|
||||||
let endpointOption = {};
|
|
||||||
let responseSender = '';
|
|
||||||
if (endpoint === 'azureOpenAI' || endpoint === 'openAI') {
|
|
||||||
endpointOption = {
|
|
||||||
endpoint,
|
|
||||||
model:
|
|
||||||
currentConversation?.model ??
|
|
||||||
endpointsConfig[endpoint]?.availableModels?.[0] ??
|
|
||||||
'gpt-3.5-turbo',
|
|
||||||
chatGptLabel: currentConversation?.chatGptLabel ?? null,
|
|
||||||
promptPrefix: currentConversation?.promptPrefix ?? null,
|
|
||||||
temperature: currentConversation?.temperature ?? 1,
|
|
||||||
top_p: currentConversation?.top_p ?? 1,
|
|
||||||
presence_penalty: currentConversation?.presence_penalty ?? 0,
|
|
||||||
frequency_penalty: currentConversation?.frequency_penalty ?? 0,
|
|
||||||
token: endpointsConfig[endpoint]?.userProvide ? getToken() : null,
|
|
||||||
};
|
|
||||||
responseSender = endpointOption.chatGptLabel ?? 'ChatGPT';
|
|
||||||
} else if (endpoint === 'google') {
|
|
||||||
endpointOption = {
|
|
||||||
endpoint,
|
|
||||||
model:
|
|
||||||
currentConversation?.model ??
|
|
||||||
endpointsConfig[endpoint]?.availableModels?.[0] ??
|
|
||||||
'chat-bison',
|
|
||||||
modelLabel: currentConversation?.modelLabel ?? null,
|
|
||||||
promptPrefix: currentConversation?.promptPrefix ?? null,
|
|
||||||
examples: currentConversation?.examples ?? [
|
|
||||||
{ input: { content: '' }, output: { content: '' } },
|
|
||||||
],
|
|
||||||
temperature: currentConversation?.temperature ?? 0.2,
|
|
||||||
maxOutputTokens: currentConversation?.maxOutputTokens ?? 1024,
|
|
||||||
topP: currentConversation?.topP ?? 0.95,
|
|
||||||
topK: currentConversation?.topK ?? 40,
|
|
||||||
token: endpointsConfig[endpoint]?.userProvide ? getToken() : null,
|
|
||||||
};
|
|
||||||
responseSender = endpointOption.chatGptLabel ?? 'ChatGPT';
|
|
||||||
} else if (endpoint === 'bingAI') {
|
|
||||||
endpointOption = {
|
|
||||||
endpoint,
|
|
||||||
jailbreak: currentConversation?.jailbreak ?? false,
|
|
||||||
systemMessage: currentConversation?.systemMessage ?? null,
|
|
||||||
context: currentConversation?.context ?? null,
|
|
||||||
toneStyle: currentConversation?.toneStyle ?? 'creative',
|
|
||||||
jailbreakConversationId: currentConversation?.jailbreakConversationId ?? null,
|
|
||||||
conversationSignature: currentConversation?.conversationSignature ?? null,
|
|
||||||
clientId: currentConversation?.clientId ?? null,
|
|
||||||
invocationId: currentConversation?.invocationId ?? 1,
|
|
||||||
token: endpointsConfig[endpoint]?.userProvide ? getToken() : null,
|
|
||||||
};
|
|
||||||
responseSender = endpointOption.jailbreak ? 'Sydney' : 'BingAI';
|
|
||||||
} else if (endpoint === 'anthropic') {
|
|
||||||
endpointOption = {
|
|
||||||
endpoint,
|
|
||||||
model:
|
|
||||||
currentConversation?.model ??
|
|
||||||
endpointsConfig[endpoint]?.availableModels?.[0] ??
|
|
||||||
'claude-1',
|
|
||||||
modelLabel: currentConversation?.modelLabel ?? null,
|
|
||||||
promptPrefix: currentConversation?.promptPrefix ?? null,
|
|
||||||
temperature: currentConversation?.temperature ?? 1,
|
|
||||||
maxOutputTokens: currentConversation?.maxOutputTokens ?? 1024,
|
|
||||||
topP: currentConversation?.topP ?? 0.7,
|
|
||||||
topK: currentConversation?.topK ?? 5,
|
|
||||||
token: endpointsConfig[endpoint]?.userProvide ? getToken() : null,
|
|
||||||
};
|
|
||||||
responseSender = 'Anthropic';
|
|
||||||
} else if (endpoint === 'chatGPTBrowser') {
|
|
||||||
endpointOption = {
|
|
||||||
endpoint,
|
|
||||||
model:
|
|
||||||
currentConversation?.model ??
|
|
||||||
endpointsConfig[endpoint]?.availableModels?.[0] ??
|
|
||||||
'text-davinci-002-render-sha',
|
|
||||||
token: endpointsConfig[endpoint]?.userProvide ? getToken() : null,
|
|
||||||
};
|
|
||||||
responseSender = 'ChatGPT';
|
|
||||||
} else if (endpoint === 'gptPlugins') {
|
|
||||||
const agentOptions = currentConversation?.agentOptions ?? {
|
|
||||||
agent: 'functions',
|
|
||||||
skipCompletion: true,
|
|
||||||
model: 'gpt-3.5-turbo',
|
|
||||||
temperature: 0,
|
|
||||||
};
|
|
||||||
endpointOption = {
|
|
||||||
endpoint,
|
|
||||||
tools: currentConversation?.tools ?? [],
|
|
||||||
model:
|
|
||||||
currentConversation?.model ??
|
|
||||||
endpointsConfig[endpoint]?.availableModels?.[0] ??
|
|
||||||
'gpt-3.5-turbo',
|
|
||||||
chatGptLabel: currentConversation?.chatGptLabel ?? null,
|
|
||||||
promptPrefix: currentConversation?.promptPrefix ?? null,
|
|
||||||
temperature: currentConversation?.temperature ?? 0.8,
|
|
||||||
top_p: currentConversation?.top_p ?? 1,
|
|
||||||
presence_penalty: currentConversation?.presence_penalty ?? 0,
|
|
||||||
frequency_penalty: currentConversation?.frequency_penalty ?? 0,
|
|
||||||
token: endpointsConfig[endpoint]?.userProvide ? getToken() : null,
|
|
||||||
agentOptions,
|
|
||||||
};
|
|
||||||
responseSender = 'ChatGPT';
|
|
||||||
} else if (endpoint === null) {
|
|
||||||
console.error('No endpoint available');
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
console.error(`Unknown endpoint ${endpoint}`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let currentMessages = messages;
|
|
||||||
|
|
||||||
// construct the query message
|
|
||||||
// this is not a real messageId, it is used as placeholder before real messageId returned
|
|
||||||
text = text.trim();
|
|
||||||
const fakeMessageId = v4();
|
|
||||||
parentMessageId =
|
|
||||||
parentMessageId || latestMessage?.messageId || '00000000-0000-0000-0000-000000000000';
|
|
||||||
conversationId = conversationId || currentConversation?.conversationId;
|
|
||||||
if (conversationId == 'search') {
|
|
||||||
console.error('cannot send any message under search view!');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (conversationId == 'new') {
|
|
||||||
parentMessageId = '00000000-0000-0000-0000-000000000000';
|
|
||||||
currentMessages = [];
|
|
||||||
conversationId = null;
|
|
||||||
}
|
|
||||||
const currentMsg = {
|
|
||||||
sender: 'User',
|
|
||||||
text,
|
|
||||||
current: true,
|
|
||||||
isCreatedByUser: true,
|
|
||||||
parentMessageId,
|
|
||||||
conversationId,
|
|
||||||
messageId: fakeMessageId,
|
|
||||||
};
|
|
||||||
|
|
||||||
// construct the placeholder response message
|
|
||||||
const initialResponse = {
|
|
||||||
sender: responseSender,
|
|
||||||
text: '<span className="result-streaming">█</span>',
|
|
||||||
parentMessageId: isRegenerate ? messageId : fakeMessageId,
|
|
||||||
messageId: (isRegenerate ? messageId : fakeMessageId) + '_',
|
|
||||||
conversationId,
|
|
||||||
unfinished: false,
|
|
||||||
submitting: true,
|
|
||||||
};
|
|
||||||
|
|
||||||
const submission = {
|
|
||||||
conversation: {
|
|
||||||
...currentConversation,
|
|
||||||
conversationId,
|
|
||||||
},
|
|
||||||
endpointOption,
|
|
||||||
message: {
|
|
||||||
...currentMsg,
|
|
||||||
overrideParentMessageId: isRegenerate ? messageId : null,
|
|
||||||
},
|
|
||||||
messages: currentMessages,
|
|
||||||
isRegenerate,
|
|
||||||
initialResponse,
|
|
||||||
};
|
|
||||||
|
|
||||||
console.log('User Input:', text, submission);
|
|
||||||
|
|
||||||
if (isRegenerate) {
|
|
||||||
setMessages([...currentMessages, initialResponse]);
|
|
||||||
} else {
|
|
||||||
setMessages([...currentMessages, currentMsg, initialResponse]);
|
|
||||||
}
|
|
||||||
setSubmission(submission);
|
|
||||||
};
|
|
||||||
|
|
||||||
const regenerate = ({ parentMessageId }) => {
|
|
||||||
const parentMessage = messages?.find((element) => element.messageId == parentMessageId);
|
|
||||||
|
|
||||||
if (parentMessage && parentMessage.isCreatedByUser) {
|
|
||||||
ask({ ...parentMessage }, { isRegenerate: true });
|
|
||||||
} else {
|
|
||||||
console.error(
|
|
||||||
'Failed to regenerate the message: parentMessage not found or not created by user.',
|
|
||||||
);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const stopGenerating = () => {
|
|
||||||
setSubmission(null);
|
|
||||||
};
|
|
||||||
|
|
||||||
return { ask, regenerate, stopGenerating };
|
|
||||||
};
|
|
||||||
|
|
||||||
export default useMessageHandler;
|
|
||||||
201
client/src/hooks/useMessageHandler.ts
Normal file
201
client/src/hooks/useMessageHandler.ts
Normal file
|
|
@ -0,0 +1,201 @@
|
||||||
|
import { v4 } from 'uuid';
|
||||||
|
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil';
|
||||||
|
import { parseConvo, getResponseSender } from 'librechat-data-provider';
|
||||||
|
import type { TMessage } from 'librechat-data-provider';
|
||||||
|
import store from '~/store';
|
||||||
|
|
||||||
|
type TAskProps = {
|
||||||
|
text: string;
|
||||||
|
parentMessageId?: string | null;
|
||||||
|
conversationId?: string | null;
|
||||||
|
messageId?: string | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
const useMessageHandler = () => {
|
||||||
|
const currentConversation = useRecoilValue(store.conversation) || { endpoint: null };
|
||||||
|
const setSubmission = useSetRecoilState(store.submission);
|
||||||
|
const isSubmitting = useRecoilValue(store.isSubmitting);
|
||||||
|
const endpointsConfig = useRecoilValue(store.endpointsConfig);
|
||||||
|
const latestMessage = useRecoilValue(store.latestMessage);
|
||||||
|
const [messages, setMessages] = useRecoilState(store.messages);
|
||||||
|
const { endpoint } = currentConversation;
|
||||||
|
const { getToken } = store.useToken(endpoint ?? '');
|
||||||
|
|
||||||
|
const ask = (
|
||||||
|
{ text, parentMessageId = null, conversationId = null, messageId = null }: TAskProps,
|
||||||
|
{ isRegenerate = false, isEdited = false } = {},
|
||||||
|
) => {
|
||||||
|
if (!!isSubmitting || text === '') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (endpoint === null) {
|
||||||
|
console.error('No endpoint available');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
conversationId = conversationId ?? currentConversation?.conversationId;
|
||||||
|
if (conversationId == 'search') {
|
||||||
|
console.error('cannot send any message under search view!');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isEdited && !latestMessage) {
|
||||||
|
console.error('cannot edit AI message without latestMessage!');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { userProvide } = endpointsConfig[endpoint] ?? {};
|
||||||
|
|
||||||
|
// set the endpoint option
|
||||||
|
const convo = parseConvo(endpoint, currentConversation);
|
||||||
|
const endpointOption = {
|
||||||
|
endpoint,
|
||||||
|
...convo,
|
||||||
|
token: userProvide ? getToken() : null,
|
||||||
|
};
|
||||||
|
const responseSender = getResponseSender(endpointOption);
|
||||||
|
|
||||||
|
let currentMessages: TMessage[] | null = messages ?? [];
|
||||||
|
|
||||||
|
// construct the query message
|
||||||
|
// this is not a real messageId, it is used as placeholder before real messageId returned
|
||||||
|
text = text.trim();
|
||||||
|
const fakeMessageId = v4();
|
||||||
|
parentMessageId =
|
||||||
|
parentMessageId || latestMessage?.messageId || '00000000-0000-0000-0000-000000000000';
|
||||||
|
|
||||||
|
if (conversationId == 'new') {
|
||||||
|
parentMessageId = '00000000-0000-0000-0000-000000000000';
|
||||||
|
currentMessages = [];
|
||||||
|
conversationId = null;
|
||||||
|
}
|
||||||
|
const currentMsg: TMessage = {
|
||||||
|
sender: 'User',
|
||||||
|
text,
|
||||||
|
current: true,
|
||||||
|
isCreatedByUser: true,
|
||||||
|
parentMessageId,
|
||||||
|
conversationId,
|
||||||
|
messageId: isEdited && messageId ? messageId : fakeMessageId,
|
||||||
|
error: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
// construct the placeholder response message
|
||||||
|
const generation = latestMessage?.text ?? '';
|
||||||
|
const responseText = isEdited ? generation : '<span className="result-streaming">█</span>';
|
||||||
|
|
||||||
|
const responseMessageId = isEdited ? latestMessage?.messageId : null;
|
||||||
|
const initialResponse: TMessage = {
|
||||||
|
sender: responseSender,
|
||||||
|
text: responseText,
|
||||||
|
parentMessageId: isRegenerate ? messageId : fakeMessageId,
|
||||||
|
messageId: responseMessageId ?? `${isRegenerate ? messageId : fakeMessageId}_`,
|
||||||
|
conversationId,
|
||||||
|
unfinished: false,
|
||||||
|
submitting: true,
|
||||||
|
isCreatedByUser: false,
|
||||||
|
error: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
const submission = {
|
||||||
|
conversation: {
|
||||||
|
...currentConversation,
|
||||||
|
conversationId,
|
||||||
|
},
|
||||||
|
endpointOption,
|
||||||
|
message: {
|
||||||
|
...currentMsg,
|
||||||
|
generation,
|
||||||
|
responseMessageId,
|
||||||
|
overrideParentMessageId: isRegenerate ? messageId : null,
|
||||||
|
},
|
||||||
|
messages: currentMessages,
|
||||||
|
isEdited,
|
||||||
|
isRegenerate,
|
||||||
|
initialResponse,
|
||||||
|
};
|
||||||
|
|
||||||
|
console.log('User Input:', text, submission);
|
||||||
|
|
||||||
|
if (isRegenerate) {
|
||||||
|
setMessages([
|
||||||
|
...(isEdited ? currentMessages.slice(0, -1) : currentMessages),
|
||||||
|
initialResponse,
|
||||||
|
]);
|
||||||
|
} else {
|
||||||
|
setMessages([...currentMessages, currentMsg, initialResponse]);
|
||||||
|
}
|
||||||
|
setSubmission(submission);
|
||||||
|
};
|
||||||
|
|
||||||
|
const regenerate = ({ parentMessageId }) => {
|
||||||
|
const parentMessage = messages?.find((element) => element.messageId == parentMessageId);
|
||||||
|
|
||||||
|
if (parentMessage && parentMessage.isCreatedByUser) {
|
||||||
|
ask({ ...parentMessage }, { isRegenerate: true });
|
||||||
|
} else {
|
||||||
|
console.error(
|
||||||
|
'Failed to regenerate the message: parentMessage not found or not created by user.',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const continueGeneration = () => {
|
||||||
|
if (!latestMessage) {
|
||||||
|
console.error('Failed to regenerate the message: latestMessage not found.');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const parentMessage = messages?.find(
|
||||||
|
(element) => element.messageId == latestMessage.parentMessageId,
|
||||||
|
);
|
||||||
|
|
||||||
|
if (parentMessage && parentMessage.isCreatedByUser) {
|
||||||
|
ask({ ...parentMessage }, { isRegenerate: true, isEdited: true });
|
||||||
|
} else {
|
||||||
|
console.error(
|
||||||
|
'Failed to regenerate the message: parentMessage not found, or not created by user.',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const stopGenerating = () => {
|
||||||
|
setSubmission(null);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleStopGenerating = (e: React.MouseEvent<HTMLButtonElement>) => {
|
||||||
|
e.preventDefault();
|
||||||
|
stopGenerating();
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleRegenerate = (e: React.MouseEvent<HTMLButtonElement>) => {
|
||||||
|
e.preventDefault();
|
||||||
|
const parentMessageId = latestMessage?.parentMessageId;
|
||||||
|
if (!parentMessageId) {
|
||||||
|
console.error('Failed to regenerate the message: parentMessageId not found.');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
regenerate({ parentMessageId });
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleContinue = (e: React.MouseEvent<HTMLButtonElement>) => {
|
||||||
|
e.preventDefault();
|
||||||
|
continueGeneration();
|
||||||
|
};
|
||||||
|
|
||||||
|
return {
|
||||||
|
ask,
|
||||||
|
regenerate,
|
||||||
|
stopGenerating,
|
||||||
|
handleStopGenerating,
|
||||||
|
handleRegenerate,
|
||||||
|
handleContinue,
|
||||||
|
endpointsConfig,
|
||||||
|
latestMessage,
|
||||||
|
isSubmitting,
|
||||||
|
messages,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export default useMessageHandler;
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import { useEffect } from 'react';
|
import { useEffect } from 'react';
|
||||||
import { useRecoilValue, useResetRecoilState, useSetRecoilState } from 'recoil';
|
import { useRecoilValue, useResetRecoilState, useSetRecoilState } from 'recoil';
|
||||||
import { SSE, createPayload } from 'librechat-data-provider';
|
import { SSE, createPayload } from 'librechat-data-provider';
|
||||||
import store from '~/store';
|
|
||||||
import { useAuthContext } from '~/hooks/AuthContext';
|
import { useAuthContext } from '~/hooks/AuthContext';
|
||||||
|
import store from '~/store';
|
||||||
|
|
||||||
export default function MessageHandler() {
|
export default function MessageHandler() {
|
||||||
const submission = useRecoilValue(store.submission);
|
const submission = useRecoilValue(store.submission);
|
||||||
|
|
@ -15,11 +15,18 @@ export default function MessageHandler() {
|
||||||
const { refreshConversations } = store.useConversations();
|
const { refreshConversations } = store.useConversations();
|
||||||
|
|
||||||
const messageHandler = (data, submission) => {
|
const messageHandler = (data, submission) => {
|
||||||
const { messages, message, plugin, initialResponse, isRegenerate = false } = submission;
|
const {
|
||||||
|
messages,
|
||||||
|
message,
|
||||||
|
plugin,
|
||||||
|
initialResponse,
|
||||||
|
isRegenerate = false,
|
||||||
|
isEdited = false,
|
||||||
|
} = submission;
|
||||||
|
|
||||||
if (isRegenerate) {
|
if (isRegenerate) {
|
||||||
setMessages([
|
setMessages([
|
||||||
...messages,
|
...(isEdited ? messages.slice(0, -1) : messages),
|
||||||
{
|
{
|
||||||
...initialResponse,
|
...initialResponse,
|
||||||
text: data,
|
text: data,
|
||||||
|
|
@ -48,13 +55,12 @@ export default function MessageHandler() {
|
||||||
};
|
};
|
||||||
|
|
||||||
const cancelHandler = (data, submission) => {
|
const cancelHandler = (data, submission) => {
|
||||||
const { messages, isRegenerate = false } = submission;
|
|
||||||
|
|
||||||
const { requestMessage, responseMessage, conversation } = data;
|
const { requestMessage, responseMessage, conversation } = data;
|
||||||
|
const { messages, isRegenerate = false, isEdited = false } = submission;
|
||||||
|
|
||||||
// update the messages
|
// update the messages
|
||||||
if (isRegenerate) {
|
if (isRegenerate) {
|
||||||
setMessages([...messages, responseMessage]);
|
setMessages([...(isEdited ? messages.slice(0, -1) : messages), responseMessage]);
|
||||||
} else {
|
} else {
|
||||||
setMessages([...messages, requestMessage, responseMessage]);
|
setMessages([...messages, requestMessage, responseMessage]);
|
||||||
}
|
}
|
||||||
|
|
@ -79,11 +85,17 @@ export default function MessageHandler() {
|
||||||
};
|
};
|
||||||
|
|
||||||
const createdHandler = (data, submission) => {
|
const createdHandler = (data, submission) => {
|
||||||
const { messages, message, initialResponse, isRegenerate = false } = submission;
|
const {
|
||||||
|
messages,
|
||||||
|
message,
|
||||||
|
initialResponse,
|
||||||
|
isRegenerate = false,
|
||||||
|
isEdited = false,
|
||||||
|
} = submission;
|
||||||
|
|
||||||
if (isRegenerate) {
|
if (isRegenerate) {
|
||||||
setMessages([
|
setMessages([
|
||||||
...messages,
|
...(isEdited ? messages.slice(0, -1) : messages),
|
||||||
{
|
{
|
||||||
...initialResponse,
|
...initialResponse,
|
||||||
parentMessageId: message?.overrideParentMessageId,
|
parentMessageId: message?.overrideParentMessageId,
|
||||||
|
|
@ -113,13 +125,12 @@ export default function MessageHandler() {
|
||||||
};
|
};
|
||||||
|
|
||||||
const finalHandler = (data, submission) => {
|
const finalHandler = (data, submission) => {
|
||||||
const { messages, isRegenerate = false } = submission;
|
|
||||||
|
|
||||||
const { requestMessage, responseMessage, conversation } = data;
|
const { requestMessage, responseMessage, conversation } = data;
|
||||||
|
const { messages, isRegenerate = false, isEdited = false } = submission;
|
||||||
|
|
||||||
// update the messages
|
// update the messages
|
||||||
if (isRegenerate) {
|
if (isRegenerate) {
|
||||||
setMessages([...messages, responseMessage]);
|
setMessages([...(isEdited ? messages.slice(0, -1) : messages), responseMessage]);
|
||||||
} else {
|
} else {
|
||||||
setMessages([...messages, requestMessage, responseMessage]);
|
setMessages([...messages, requestMessage, responseMessage]);
|
||||||
}
|
}
|
||||||
|
|
@ -1,16 +1,16 @@
|
||||||
/* eslint-disable react-hooks/exhaustive-deps */
|
/* eslint-disable react-hooks/exhaustive-deps */
|
||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useState } from 'react';
|
||||||
|
import { useSetRecoilState } from 'recoil';
|
||||||
|
import { Outlet } from 'react-router-dom';
|
||||||
import {
|
import {
|
||||||
useGetEndpointsQuery,
|
useGetEndpointsQuery,
|
||||||
useGetPresetsQuery,
|
useGetPresetsQuery,
|
||||||
useGetSearchEnabledQuery,
|
useGetSearchEnabledQuery,
|
||||||
} from 'librechat-data-provider';
|
} from 'librechat-data-provider';
|
||||||
|
|
||||||
import MessageHandler from '../components/MessageHandler';
|
import { Nav, MobileNav } from '~/components/Nav';
|
||||||
import { Nav, MobileNav } from '../components/Nav';
|
|
||||||
import { Outlet } from 'react-router-dom';
|
|
||||||
import { useAuthContext } from '~/hooks/AuthContext';
|
import { useAuthContext } from '~/hooks/AuthContext';
|
||||||
import { useSetRecoilState } from 'recoil';
|
import MessageHandler from './MessageHandler';
|
||||||
import store from '~/store';
|
import store from '~/store';
|
||||||
|
|
||||||
export default function Root() {
|
export default function Root() {
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,13 @@ import {
|
||||||
useResetRecoilState,
|
useResetRecoilState,
|
||||||
useRecoilCallback,
|
useRecoilCallback,
|
||||||
} from 'recoil';
|
} from 'recoil';
|
||||||
import { TConversation, TMessagesAtom, TSubmission, TPreset } from 'librechat-data-provider';
|
import {
|
||||||
|
TConversation,
|
||||||
|
TMessagesAtom,
|
||||||
|
TMessage,
|
||||||
|
TSubmission,
|
||||||
|
TPreset,
|
||||||
|
} from 'librechat-data-provider';
|
||||||
import { buildTree, getDefaultConversation } from '~/utils';
|
import { buildTree, getDefaultConversation } from '~/utils';
|
||||||
import submission from './submission';
|
import submission from './submission';
|
||||||
import endpoints from './endpoints';
|
import endpoints from './endpoints';
|
||||||
|
|
@ -32,7 +38,7 @@ const messagesTree = selector({
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const latestMessage = atom({
|
const latestMessage = atom<TMessage | null>({
|
||||||
key: 'latestMessage',
|
key: 'latestMessage',
|
||||||
default: null,
|
default: null,
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ export default function buildTree(messages: TMessage[] | null, groupAll = false)
|
||||||
messages.forEach((message) => {
|
messages.forEach((message) => {
|
||||||
messageMap[message.messageId] = { ...message, children: [] };
|
messageMap[message.messageId] = { ...message, children: [] };
|
||||||
|
|
||||||
const parentMessage = messageMap[message.parentMessageId];
|
const parentMessage = messageMap[message.parentMessageId ?? ''];
|
||||||
if (parentMessage) {
|
if (parentMessage) {
|
||||||
parentMessage.children.push(messageMap[message.messageId]);
|
parentMessage.children.push(messageMap[message.messageId]);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue