mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-31 23:58:50 +01:00
Merge branch 'main' into re-add-download-audio
This commit is contained in:
commit
32d84c85ea
408 changed files with 16931 additions and 4946 deletions
|
|
@ -80,7 +80,7 @@ PROXY=
|
|||
#============#
|
||||
|
||||
ANTHROPIC_API_KEY=user_provided
|
||||
# ANTHROPIC_MODELS=claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k
|
||||
# ANTHROPIC_MODELS=claude-3-5-sonnet-20240620,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k
|
||||
# ANTHROPIC_REVERSE_PROXY=
|
||||
|
||||
#============#
|
||||
|
|
@ -123,6 +123,8 @@ GOOGLE_KEY=user_provided
|
|||
# Vertex AI
|
||||
# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0514,gemini-1.0-pro-vision-001,gemini-1.0-pro-002,gemini-1.0-pro-001,gemini-pro-vision,gemini-1.0-pro
|
||||
|
||||
# GOOGLE_TITLE_MODEL=gemini-pro
|
||||
|
||||
# Google Gemini Safety Settings
|
||||
# NOTE (Vertex AI): You do not have access to the BLOCK_NONE setting by default.
|
||||
# To use this restricted HarmBlockThreshold setting, you will need to either:
|
||||
|
|
@ -372,6 +374,9 @@ LDAP_BIND_CREDENTIALS=
|
|||
LDAP_USER_SEARCH_BASE=
|
||||
LDAP_SEARCH_FILTER=mail={{username}}
|
||||
LDAP_CA_CERT_PATH=
|
||||
# LDAP_ID=
|
||||
# LDAP_USERNAME=
|
||||
# LDAP_FULL_NAME=
|
||||
|
||||
#========================#
|
||||
# Email Password Reset #
|
||||
|
|
|
|||
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -11,6 +11,7 @@ logs
|
|||
pids
|
||||
*.pid
|
||||
*.seed
|
||||
.git
|
||||
|
||||
# Directory for instrumented libs generated by jscoverage/JSCover
|
||||
lib-cov
|
||||
|
|
@ -45,6 +46,7 @@ api/node_modules/
|
|||
client/node_modules/
|
||||
bower_components/
|
||||
*.d.ts
|
||||
!vite-env.d.ts
|
||||
|
||||
# Floobits
|
||||
.floo
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@
|
|||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://railway.app/template/b5k2mn?referralCode=myKrVZ">
|
||||
<a href="https://railway.app/template/b5k2mn?referralCode=HI9hWz">
|
||||
<img src="https://railway.app/button.svg" alt="Deploy on Railway" height="30">
|
||||
</a>
|
||||
<a href="https://zeabur.com/templates/0X2ZY8">
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
const crypto = require('crypto');
|
||||
const fetch = require('node-fetch');
|
||||
const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
|
||||
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
|
||||
const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
|
||||
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
|
||||
const checkBalance = require('~/models/checkBalance');
|
||||
const { getFiles } = require('~/models/File');
|
||||
|
|
@ -19,6 +19,14 @@ class BaseClient {
|
|||
day: 'numeric',
|
||||
});
|
||||
this.fetch = this.fetch.bind(this);
|
||||
/** @type {boolean} */
|
||||
this.skipSaveConvo = false;
|
||||
/** @type {boolean} */
|
||||
this.skipSaveUserMessage = false;
|
||||
/** @type {ClientDatabaseSavePromise} */
|
||||
this.userMessagePromise;
|
||||
/** @type {ClientDatabaseSavePromise} */
|
||||
this.responsePromise;
|
||||
}
|
||||
|
||||
setOptions() {
|
||||
|
|
@ -84,19 +92,45 @@ class BaseClient {
|
|||
await stream.processTextStream(onProgress);
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns {[string|undefined, string|undefined]}
|
||||
*/
|
||||
processOverideIds() {
|
||||
/** @type {Record<string, string | undefined>} */
|
||||
let { overrideConvoId, overrideUserMessageId } = this.options?.req?.body ?? {};
|
||||
if (overrideConvoId) {
|
||||
const [conversationId, index] = overrideConvoId.split(Constants.COMMON_DIVIDER);
|
||||
overrideConvoId = conversationId;
|
||||
if (index !== '0') {
|
||||
this.skipSaveConvo = true;
|
||||
}
|
||||
}
|
||||
if (overrideUserMessageId) {
|
||||
const [userMessageId, index] = overrideUserMessageId.split(Constants.COMMON_DIVIDER);
|
||||
overrideUserMessageId = userMessageId;
|
||||
if (index !== '0') {
|
||||
this.skipSaveUserMessage = true;
|
||||
}
|
||||
}
|
||||
|
||||
return [overrideConvoId, overrideUserMessageId];
|
||||
}
|
||||
|
||||
async setMessageOptions(opts = {}) {
|
||||
if (opts && opts.replaceOptions) {
|
||||
this.setOptions(opts);
|
||||
}
|
||||
|
||||
const [overrideConvoId, overrideUserMessageId] = this.processOverideIds();
|
||||
const { isEdited, isContinued } = opts;
|
||||
const user = opts.user ?? null;
|
||||
this.user = user;
|
||||
const saveOptions = this.getSaveOptions();
|
||||
this.abortController = opts.abortController ?? new AbortController();
|
||||
const conversationId = opts.conversationId ?? crypto.randomUUID();
|
||||
const conversationId = overrideConvoId ?? opts.conversationId ?? crypto.randomUUID();
|
||||
const parentMessageId = opts.parentMessageId ?? Constants.NO_PARENT;
|
||||
const userMessageId = opts.overrideParentMessageId ?? crypto.randomUUID();
|
||||
const userMessageId =
|
||||
overrideUserMessageId ?? opts.overrideParentMessageId ?? crypto.randomUUID();
|
||||
let responseMessageId = opts.responseMessageId ?? crypto.randomUUID();
|
||||
let head = isEdited ? responseMessageId : parentMessageId;
|
||||
this.currentMessages = (await this.loadHistory(conversationId, head)) ?? [];
|
||||
|
|
@ -160,7 +194,7 @@ class BaseClient {
|
|||
}
|
||||
|
||||
if (typeof opts?.onStart === 'function') {
|
||||
opts.onStart(userMessage);
|
||||
opts.onStart(userMessage, responseMessageId);
|
||||
}
|
||||
|
||||
return {
|
||||
|
|
@ -450,8 +484,13 @@ class BaseClient {
|
|||
this.handleTokenCountMap(tokenCountMap);
|
||||
}
|
||||
|
||||
if (!isEdited) {
|
||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
if (!isEdited && !this.skipSaveUserMessage) {
|
||||
this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
if (typeof opts?.getReqData === 'function') {
|
||||
opts.getReqData({
|
||||
userMessagePromise: this.userMessagePromise,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
|
|
@ -500,15 +539,11 @@ class BaseClient {
|
|||
const completionTokens = this.getTokenCount(completion);
|
||||
await this.recordTokenUsage({ promptTokens, completionTokens });
|
||||
}
|
||||
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
delete responseMessage.tokenCount;
|
||||
return responseMessage;
|
||||
}
|
||||
|
||||
async getConversation(conversationId, user = null) {
|
||||
return await getConvo(user, conversationId);
|
||||
}
|
||||
|
||||
async loadHistory(conversationId, parentMessageId = null) {
|
||||
logger.debug('[BaseClient] Loading history:', { conversationId, parentMessageId });
|
||||
|
||||
|
|
@ -563,18 +598,24 @@ class BaseClient {
|
|||
* @param {string | null} user
|
||||
*/
|
||||
async saveMessageToDatabase(message, endpointOptions, user = null) {
|
||||
await saveMessage({
|
||||
const savedMessage = await saveMessage({
|
||||
...message,
|
||||
endpoint: this.options.endpoint,
|
||||
unfinished: false,
|
||||
user,
|
||||
});
|
||||
await saveConvo(user, {
|
||||
|
||||
if (this.skipSaveConvo) {
|
||||
return { message: savedMessage };
|
||||
}
|
||||
const conversation = await saveConvo(user, {
|
||||
conversationId: message.conversationId,
|
||||
endpoint: this.options.endpoint,
|
||||
endpointType: this.options.endpointType,
|
||||
...endpointOptions,
|
||||
});
|
||||
|
||||
return { message: savedMessage, conversation };
|
||||
}
|
||||
|
||||
async updateMessageInDatabase(message) {
|
||||
|
|
|
|||
|
|
@ -16,10 +16,15 @@ const {
|
|||
AuthKeys,
|
||||
} = require('librechat-data-provider');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images');
|
||||
const { formatMessage, createContextHandlers } = require('./prompts');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { logger } = require('~/config');
|
||||
const {
|
||||
formatMessage,
|
||||
createContextHandlers,
|
||||
titleInstruction,
|
||||
truncateText,
|
||||
} = require('./prompts');
|
||||
const BaseClient = require('./BaseClient');
|
||||
|
||||
const loc = 'us-central1';
|
||||
const publisher = 'google';
|
||||
|
|
@ -591,12 +596,16 @@ class GoogleClient extends BaseClient {
|
|||
createLLM(clientOptions) {
|
||||
const model = clientOptions.modelName ?? clientOptions.model;
|
||||
if (this.project_id && this.isTextModel) {
|
||||
logger.debug('Creating Google VertexAI client');
|
||||
return new GoogleVertexAI(clientOptions);
|
||||
} else if (this.project_id && this.isChatModel) {
|
||||
logger.debug('Creating Chat Google VertexAI client');
|
||||
return new ChatGoogleVertexAI(clientOptions);
|
||||
} else if (this.project_id) {
|
||||
logger.debug('Creating VertexAI client');
|
||||
return new ChatVertexAI(clientOptions);
|
||||
} else if (model.includes('1.5')) {
|
||||
logger.debug('Creating GenAI client');
|
||||
return new GenAI(this.apiKey).getGenerativeModel(
|
||||
{
|
||||
...clientOptions,
|
||||
|
|
@ -606,6 +615,7 @@ class GoogleClient extends BaseClient {
|
|||
);
|
||||
}
|
||||
|
||||
logger.debug('Creating Chat Google Generative AI client');
|
||||
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
|
||||
}
|
||||
|
||||
|
|
@ -717,6 +727,123 @@ class GoogleClient extends BaseClient {
|
|||
return reply;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stripped-down logic for generating a title. This uses the non-streaming APIs, since the user does not see titles streaming
|
||||
*/
|
||||
async titleChatCompletion(_payload, options = {}) {
|
||||
const { abortController } = options;
|
||||
const { parameters, instances } = _payload;
|
||||
const { messages: _messages, examples: _examples } = instances?.[0] ?? {};
|
||||
|
||||
let clientOptions = { ...parameters, maxRetries: 2 };
|
||||
|
||||
logger.debug('Initialized title client options');
|
||||
|
||||
if (this.project_id) {
|
||||
clientOptions['authOptions'] = {
|
||||
credentials: {
|
||||
...this.serviceKey,
|
||||
},
|
||||
projectId: this.project_id,
|
||||
};
|
||||
}
|
||||
|
||||
if (!parameters) {
|
||||
clientOptions = { ...clientOptions, ...this.modelOptions };
|
||||
}
|
||||
|
||||
if (this.isGenerativeModel && !this.project_id) {
|
||||
clientOptions.modelName = clientOptions.model;
|
||||
delete clientOptions.model;
|
||||
}
|
||||
|
||||
const model = this.createLLM(clientOptions);
|
||||
|
||||
let reply = '';
|
||||
const messages = this.isTextModel ? _payload.trim() : _messages;
|
||||
|
||||
const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
|
||||
if (modelName?.includes('1.5') && !this.project_id) {
|
||||
logger.debug('Identified titling model as 1.5 version');
|
||||
/** @type {GenerativeModel} */
|
||||
const client = model;
|
||||
const requestOptions = {
|
||||
contents: _payload,
|
||||
};
|
||||
|
||||
if (this.options?.promptPrefix?.length) {
|
||||
requestOptions.systemInstruction = {
|
||||
parts: [
|
||||
{
|
||||
text: this.options.promptPrefix,
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
const safetySettings = _payload.safetySettings;
|
||||
requestOptions.safetySettings = safetySettings;
|
||||
|
||||
const result = await client.generateContent(requestOptions);
|
||||
|
||||
reply = result.response?.text();
|
||||
|
||||
return reply;
|
||||
} else {
|
||||
logger.debug('Beginning titling');
|
||||
const safetySettings = _payload.safetySettings;
|
||||
|
||||
const titleResponse = await model.invoke(messages, {
|
||||
signal: abortController.signal,
|
||||
timeout: 7000,
|
||||
safetySettings: safetySettings,
|
||||
});
|
||||
|
||||
reply = titleResponse.content;
|
||||
|
||||
return reply;
|
||||
}
|
||||
}
|
||||
|
||||
async titleConvo({ text, responseText = '' }) {
|
||||
let title = 'New Chat';
|
||||
const convo = `||>User:
|
||||
"${truncateText(text)}"
|
||||
||>Response:
|
||||
"${JSON.stringify(truncateText(responseText))}"`;
|
||||
|
||||
let { prompt: payload } = await this.buildMessages([
|
||||
{
|
||||
text: `Please generate ${titleInstruction}
|
||||
|
||||
${convo}
|
||||
|
||||
||>Title:`,
|
||||
isCreatedByUser: true,
|
||||
author: this.userLabel,
|
||||
},
|
||||
]);
|
||||
|
||||
if (this.isVisionModel) {
|
||||
logger.warn(
|
||||
`Current vision model does not support titling without an attachment; falling back to default model ${settings.model.default}`,
|
||||
);
|
||||
|
||||
payload.parameters = { ...payload.parameters, model: settings.model.default };
|
||||
}
|
||||
|
||||
try {
|
||||
title = await this.titleChatCompletion(payload, {
|
||||
abortController: new AbortController(),
|
||||
onProgress: () => {},
|
||||
});
|
||||
} catch (e) {
|
||||
logger.error('[GoogleClient] There was an issue generating the title', e);
|
||||
}
|
||||
logger.debug(`Title response: ${title}`);
|
||||
return title;
|
||||
}
|
||||
|
||||
getSaveOptions() {
|
||||
return {
|
||||
promptPrefix: this.options.promptPrefix,
|
||||
|
|
|
|||
|
|
@ -27,7 +27,6 @@ const {
|
|||
createContextHandlers,
|
||||
} = require('./prompts');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { updateTokenWebsocket } = require('~/server/services/Files/Audio');
|
||||
const { isEnabled, sleep } = require('~/server/utils');
|
||||
const { handleOpenAIErrors } = require('./tools/util');
|
||||
const spendTokens = require('~/models/spendTokens');
|
||||
|
|
@ -595,7 +594,6 @@ class OpenAIClient extends BaseClient {
|
|||
payload,
|
||||
(progressMessage) => {
|
||||
if (progressMessage === '[DONE]') {
|
||||
updateTokenWebsocket('[DONE]');
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -238,12 +238,23 @@ class PluginsClient extends OpenAIClient {
|
|||
await this.recordTokenUsage(responseMessage);
|
||||
}
|
||||
|
||||
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
delete responseMessage.tokenCount;
|
||||
return { ...responseMessage, ...result };
|
||||
}
|
||||
|
||||
async sendMessage(message, opts = {}) {
|
||||
/** @type {{ filteredTools: string[], includedTools: string[] }} */
|
||||
const { filteredTools = [], includedTools = [] } = this.options.req.app.locals;
|
||||
|
||||
if (includedTools.length > 0) {
|
||||
const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin));
|
||||
this.options.tools = tools;
|
||||
} else {
|
||||
const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin));
|
||||
this.options.tools = tools;
|
||||
}
|
||||
|
||||
// If a message is edited, no tools can be used.
|
||||
const completionMode = this.options.tools.length === 0 || opts.isEdited;
|
||||
if (completionMode) {
|
||||
|
|
@ -301,7 +312,15 @@ class PluginsClient extends OpenAIClient {
|
|||
if (payload) {
|
||||
this.currentMessages = payload;
|
||||
}
|
||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
|
||||
if (!this.skipSaveUserMessage) {
|
||||
this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
if (typeof opts?.getReqData === 'function') {
|
||||
opts.getReqData({
|
||||
userMessagePromise: this.userMessagePromise,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (isEnabled(process.env.CHECK_BALANCE)) {
|
||||
await checkBalance({
|
||||
|
|
|
|||
|
|
@ -1,44 +1,3 @@
|
|||
/*
|
||||
module.exports = `You are ChatGPT, a Large Language model with useful tools.
|
||||
|
||||
Talk to the human and provide meaningful answers when questions are asked.
|
||||
|
||||
Use the tools when you need them, but use your own knowledge if you are confident of the answer. Keep answers short and concise.
|
||||
|
||||
A tool is not usually needed for creative requests, so do your best to answer them without tools.
|
||||
|
||||
Avoid repeating identical answers if it appears before. Only fulfill the human's requests, do not create extra steps beyond what the human has asked for.
|
||||
|
||||
Your input for 'Action' should be the name of tool used only.
|
||||
|
||||
Be honest. If you can't answer something, or a tool is not appropriate, say you don't know or answer to the best of your ability.
|
||||
|
||||
Attempt to fulfill the human's requests in as few actions as possible`;
|
||||
*/
|
||||
|
||||
// module.exports = `You are ChatGPT, a highly knowledgeable and versatile large language model.
|
||||
|
||||
// Engage with the Human conversationally, providing concise and meaningful answers to questions. Utilize built-in tools when necessary, except for creative requests, where relying on your own knowledge is preferred. Aim for variety and avoid repetitive answers.
|
||||
|
||||
// For your 'Action' input, state the name of the tool used only, and honor user requests without adding extra steps. Always be honest; if you cannot provide an appropriate answer or tool, admit that or do your best.
|
||||
|
||||
// Strive to meet the user's needs efficiently with minimal actions.`;
|
||||
|
||||
// import {
|
||||
// BasePromptTemplate,
|
||||
// BaseStringPromptTemplate,
|
||||
// SerializedBasePromptTemplate,
|
||||
// renderTemplate,
|
||||
// } from "langchain/prompts";
|
||||
|
||||
// prefix: `You are ChatGPT, a highly knowledgeable and versatile large language model.
|
||||
// Your objective is to help users by understanding their intent and choosing the best action. Prioritize direct, specific responses. Use concise, varied answers and rely on your knowledge for creative tasks. Utilize tools when needed, and structure results for machine compatibility.
|
||||
// prefix: `Objective: to comprehend human intentions based on user input and available tools. Goal: identify the best action to directly address the human's query. In your subsequent steps, you will utilize the chosen action. You may select multiple actions and list them in a meaningful order. Prioritize actions that directly relate to the user's query over general ones. Ensure that the generated thought is highly specific and explicit to best match the user's expectations. Construct the result in a manner that an online open-API would most likely expect. Provide concise and meaningful answers to human queries. Utilize tools when necessary. Relying on your own knowledge is preferred for creative requests. Aim for variety and avoid repetitive answers.
|
||||
|
||||
// # Available Actions & Tools:
|
||||
// N/A: no suitable action, use your own knowledge.`,
|
||||
// suffix: `Remember, all your responses MUST adhere to the described format and only respond if the format is followed. Output exactly with the requested format, avoiding any other text as this will be parsed by a machine. Following 'Action:', provide only one of the actions listed above. If a tool is not necessary, deduce this quickly and finish your response. Honor the human's requests without adding extra steps. Carry out tasks in the sequence written by the human. Always be honest; if you cannot provide an appropriate answer or tool, do your best with your own knowledge. Strive to meet the user's needs efficiently with minimal actions.`;
|
||||
|
||||
module.exports = {
|
||||
'gpt3-v1': {
|
||||
prefix: `Objective: Understand human intentions using user input and available tools. Goal: Identify the most suitable actions to directly address user queries.
|
||||
|
|
|
|||
|
|
@ -576,7 +576,11 @@ describe('BaseClient', () => {
|
|||
const onStart = jest.fn();
|
||||
const opts = { onStart };
|
||||
await TestClient.sendMessage('Hello, world!', opts);
|
||||
expect(onStart).toHaveBeenCalledWith(expect.objectContaining({ text: 'Hello, world!' }));
|
||||
|
||||
expect(onStart).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ text: 'Hello, world!' }),
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
|
||||
test('saveMessageToDatabase is called with the correct arguments', async () => {
|
||||
|
|
|
|||
|
|
@ -194,6 +194,7 @@ describe('PluginsClient', () => {
|
|||
expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Azure OpenAI tests specific to Plugins', () => {
|
||||
// TODO: add more tests for Azure OpenAI integration with Plugins
|
||||
// let client;
|
||||
|
|
@ -220,4 +221,94 @@ describe('PluginsClient', () => {
|
|||
spy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('sendMessage with filtered tools', () => {
|
||||
let TestAgent;
|
||||
const apiKey = 'fake-api-key';
|
||||
const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }];
|
||||
|
||||
beforeEach(() => {
|
||||
TestAgent = new PluginsClient(apiKey, {
|
||||
tools: mockTools,
|
||||
modelOptions: {
|
||||
model: 'gpt-3.5-turbo',
|
||||
temperature: 0,
|
||||
max_tokens: 2,
|
||||
},
|
||||
agentOptions: {
|
||||
model: 'gpt-3.5-turbo',
|
||||
},
|
||||
});
|
||||
|
||||
TestAgent.options.req = {
|
||||
app: {
|
||||
locals: {},
|
||||
},
|
||||
};
|
||||
|
||||
TestAgent.sendMessage = jest.fn().mockImplementation(async () => {
|
||||
const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals;
|
||||
|
||||
if (includedTools.length > 0) {
|
||||
const tools = TestAgent.options.tools.filter((plugin) =>
|
||||
includedTools.includes(plugin.name),
|
||||
);
|
||||
TestAgent.options.tools = tools;
|
||||
} else {
|
||||
const tools = TestAgent.options.tools.filter(
|
||||
(plugin) => !filteredTools.includes(plugin.name),
|
||||
);
|
||||
TestAgent.options.tools = tools;
|
||||
}
|
||||
|
||||
return {
|
||||
text: 'Mocked response',
|
||||
tools: TestAgent.options.tools,
|
||||
};
|
||||
});
|
||||
});
|
||||
|
||||
test('should filter out tools when filteredTools is provided', async () => {
|
||||
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(2);
|
||||
expect(response.tools).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({ name: 'tool2' }),
|
||||
expect.objectContaining({ name: 'tool4' }),
|
||||
]),
|
||||
);
|
||||
});
|
||||
|
||||
test('should only include specified tools when includedTools is provided', async () => {
|
||||
TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4'];
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(2);
|
||||
expect(response.tools).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({ name: 'tool2' }),
|
||||
expect.objectContaining({ name: 'tool4' }),
|
||||
]),
|
||||
);
|
||||
});
|
||||
|
||||
test('should prioritize includedTools over filteredTools', async () => {
|
||||
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
|
||||
TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2'];
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(2);
|
||||
expect(response.tools).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({ name: 'tool1' }),
|
||||
expect.objectContaining({ name: 'tool2' }),
|
||||
]),
|
||||
);
|
||||
});
|
||||
|
||||
test('should not modify tools when no filters are provided', async () => {
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(4);
|
||||
expect(response.tools).toEqual(expect.arrayContaining(mockTools));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
5
api/cache/getLogStores.js
vendored
5
api/cache/getLogStores.js
vendored
|
|
@ -25,6 +25,10 @@ const config = isEnabled(USE_REDIS)
|
|||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.CONFIG_STORE });
|
||||
|
||||
const roles = isEnabled(USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.ROLES });
|
||||
|
||||
const audioRuns = isEnabled(USE_REDIS) // ttl: 30 minutes
|
||||
? new Keyv({ store: keyvRedis, ttl: TEN_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: TEN_MINUTES });
|
||||
|
|
@ -46,6 +50,7 @@ const abortKeys = isEnabled(USE_REDIS)
|
|||
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: 600000 });
|
||||
|
||||
const namespaces = {
|
||||
[CacheKeys.ROLES]: roles,
|
||||
[CacheKeys.CONFIG_STORE]: config,
|
||||
pending_req,
|
||||
[ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
|
||||
|
|
|
|||
61
api/models/Categories.js
Normal file
61
api/models/Categories.js
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
const { logger } = require('~/config');
|
||||
// const { Categories } = require('./schema/categories');
|
||||
const options = [
|
||||
{
|
||||
label: '',
|
||||
value: '',
|
||||
},
|
||||
{
|
||||
label: 'idea',
|
||||
value: 'idea',
|
||||
},
|
||||
{
|
||||
label: 'travel',
|
||||
value: 'travel',
|
||||
},
|
||||
{
|
||||
label: 'teach_or_explain',
|
||||
value: 'teach_or_explain',
|
||||
},
|
||||
{
|
||||
label: 'write',
|
||||
value: 'write',
|
||||
},
|
||||
{
|
||||
label: 'shop',
|
||||
value: 'shop',
|
||||
},
|
||||
{
|
||||
label: 'code',
|
||||
value: 'code',
|
||||
},
|
||||
{
|
||||
label: 'misc',
|
||||
value: 'misc',
|
||||
},
|
||||
{
|
||||
label: 'roleplay',
|
||||
value: 'roleplay',
|
||||
},
|
||||
{
|
||||
label: 'finance',
|
||||
value: 'finance',
|
||||
},
|
||||
];
|
||||
|
||||
module.exports = {
|
||||
/**
|
||||
* Retrieves the categories asynchronously.
|
||||
* @returns {Promise<TGetCategoriesResponse>} An array of category objects.
|
||||
* @throws {Error} If there is an error retrieving the categories.
|
||||
*/
|
||||
getCategories: async () => {
|
||||
try {
|
||||
// const categories = await Categories.find();
|
||||
return options;
|
||||
} catch (error) {
|
||||
logger.error('Error getting categories', error);
|
||||
return [];
|
||||
}
|
||||
},
|
||||
};
|
||||
|
|
@ -27,10 +27,12 @@ module.exports = {
|
|||
update.conversationId = newConversationId;
|
||||
}
|
||||
|
||||
return await Conversation.findOneAndUpdate({ conversationId: conversationId, user }, update, {
|
||||
const conversation = await Conversation.findOneAndUpdate({ conversationId, user }, update, {
|
||||
new: true,
|
||||
upsert: true,
|
||||
});
|
||||
|
||||
return conversation.toObject();
|
||||
} catch (error) {
|
||||
logger.error('[saveConvo] Error saving conversation', error);
|
||||
return { message: 'Error saving conversation' };
|
||||
|
|
|
|||
|
|
@ -57,18 +57,13 @@ module.exports = {
|
|||
if (files) {
|
||||
update.files = files;
|
||||
}
|
||||
// may also need to update the conversation here
|
||||
await Message.findOneAndUpdate({ messageId }, update, { upsert: true, new: true });
|
||||
|
||||
return {
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
sender,
|
||||
text,
|
||||
isCreatedByUser,
|
||||
tokenCount,
|
||||
};
|
||||
const message = await Message.findOneAndUpdate({ messageId }, update, {
|
||||
upsert: true,
|
||||
new: true,
|
||||
});
|
||||
|
||||
return message.toObject();
|
||||
} catch (err) {
|
||||
logger.error('Error saving message:', err);
|
||||
throw new Error('Failed to save message.');
|
||||
|
|
|
|||
90
api/models/Project.js
Normal file
90
api/models/Project.js
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
const { model } = require('mongoose');
|
||||
const projectSchema = require('~/models/schema/projectSchema');
|
||||
|
||||
const Project = model('Project', projectSchema);
|
||||
|
||||
/**
|
||||
* Retrieve a project by ID and convert the found project document to a plain object.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to find and return as a plain object.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<MongoProject>} A plain object representing the project document, or `null` if no project is found.
|
||||
*/
|
||||
const getProjectById = async function (projectId, fieldsToSelect = null) {
|
||||
const query = Project.findById(projectId);
|
||||
|
||||
if (fieldsToSelect) {
|
||||
query.select(fieldsToSelect);
|
||||
}
|
||||
|
||||
return await query.lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieve a project by name and convert the found project document to a plain object.
|
||||
* If the project with the given name doesn't exist and the name is "instance", create it and return the lean version.
|
||||
*
|
||||
* @param {string} projectName - The name of the project to find or create.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<MongoProject>} A plain object representing the project document.
|
||||
*/
|
||||
const getProjectByName = async function (projectName, fieldsToSelect = null) {
|
||||
const query = { name: projectName };
|
||||
const update = { $setOnInsert: { name: projectName } };
|
||||
const options = {
|
||||
new: true,
|
||||
upsert: projectName === 'instance',
|
||||
lean: true,
|
||||
select: fieldsToSelect,
|
||||
};
|
||||
|
||||
return await Project.findOneAndUpdate(query, update, options);
|
||||
};
|
||||
|
||||
/**
|
||||
* Add an array of prompt group IDs to a project's promptGroupIds array, ensuring uniqueness.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to update.
|
||||
* @param {string[]} promptGroupIds - The array of prompt group IDs to add to the project.
|
||||
* @returns {Promise<MongoProject>} The updated project document.
|
||||
*/
|
||||
const addGroupIdsToProject = async function (projectId, promptGroupIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $addToSet: { promptGroupIds: { $each: promptGroupIds } } },
|
||||
{ new: true },
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Remove an array of prompt group IDs from a project's promptGroupIds array.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to update.
|
||||
* @param {string[]} promptGroupIds - The array of prompt group IDs to remove from the project.
|
||||
* @returns {Promise<MongoProject>} The updated project document.
|
||||
*/
|
||||
const removeGroupIdsFromProject = async function (projectId, promptGroupIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $pull: { promptGroupIds: { $in: promptGroupIds } } },
|
||||
{ new: true },
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Remove a prompt group ID from all projects.
|
||||
*
|
||||
* @param {string} promptGroupId - The ID of the prompt group to remove from projects.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const removeGroupFromAllProjects = async (promptGroupId) => {
|
||||
await Project.updateMany({}, { $pull: { promptGroupIds: promptGroupId } });
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getProjectById,
|
||||
getProjectByName,
|
||||
addGroupIdsToProject,
|
||||
removeGroupIdsFromProject,
|
||||
removeGroupFromAllProjects,
|
||||
};
|
||||
|
|
@ -1,52 +1,528 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { ObjectId } = require('mongodb');
|
||||
const { SystemRoles, SystemCategories } = require('librechat-data-provider');
|
||||
const {
|
||||
getProjectByName,
|
||||
addGroupIdsToProject,
|
||||
removeGroupIdsFromProject,
|
||||
removeGroupFromAllProjects,
|
||||
} = require('./Project');
|
||||
const { Prompt, PromptGroup } = require('./schema/promptSchema');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const promptSchema = mongoose.Schema(
|
||||
{
|
||||
title: {
|
||||
type: String,
|
||||
required: true,
|
||||
/**
|
||||
* Create a pipeline for the aggregation to get prompt groups
|
||||
* @param {Object} query
|
||||
* @param {number} skip
|
||||
* @param {number} limit
|
||||
* @returns {[Object]} - The pipeline for the aggregation
|
||||
*/
|
||||
const createGroupPipeline = (query, skip, limit) => {
|
||||
return [
|
||||
{ $match: query },
|
||||
{ $sort: { createdAt: -1 } },
|
||||
{ $skip: skip },
|
||||
{ $limit: limit },
|
||||
{
|
||||
$lookup: {
|
||||
from: 'prompts',
|
||||
localField: 'productionId',
|
||||
foreignField: '_id',
|
||||
as: 'productionPrompt',
|
||||
},
|
||||
},
|
||||
prompt: {
|
||||
type: String,
|
||||
required: true,
|
||||
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
|
||||
{
|
||||
$project: {
|
||||
name: 1,
|
||||
numberOfGenerations: 1,
|
||||
oneliner: 1,
|
||||
category: 1,
|
||||
projectIds: 1,
|
||||
productionId: 1,
|
||||
author: 1,
|
||||
authorName: 1,
|
||||
createdAt: 1,
|
||||
updatedAt: 1,
|
||||
'productionPrompt.prompt': 1,
|
||||
// 'productionPrompt._id': 1,
|
||||
// 'productionPrompt.type': 1,
|
||||
},
|
||||
},
|
||||
category: {
|
||||
type: String,
|
||||
},
|
||||
},
|
||||
{ timestamps: true },
|
||||
);
|
||||
];
|
||||
};
|
||||
|
||||
const Prompt = mongoose.models.Prompt || mongoose.model('Prompt', promptSchema);
|
||||
/**
|
||||
* Create a pipeline for the aggregation to get all prompt groups
|
||||
* @param {Object} query
|
||||
* @param {Partial<MongoPromptGroup>} $project
|
||||
* @returns {[Object]} - The pipeline for the aggregation
|
||||
*/
|
||||
const createAllGroupsPipeline = (
|
||||
query,
|
||||
$project = {
|
||||
name: 1,
|
||||
oneliner: 1,
|
||||
category: 1,
|
||||
author: 1,
|
||||
authorName: 1,
|
||||
createdAt: 1,
|
||||
updatedAt: 1,
|
||||
command: 1,
|
||||
'productionPrompt.prompt': 1,
|
||||
},
|
||||
) => {
|
||||
return [
|
||||
{ $match: query },
|
||||
{ $sort: { createdAt: -1 } },
|
||||
{
|
||||
$lookup: {
|
||||
from: 'prompts',
|
||||
localField: 'productionId',
|
||||
foreignField: '_id',
|
||||
as: 'productionPrompt',
|
||||
},
|
||||
},
|
||||
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
|
||||
{
|
||||
$project,
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
/**
|
||||
* Get all prompt groups with filters
|
||||
* @param {Object} req
|
||||
* @param {TPromptGroupsWithFilterRequest} filter
|
||||
* @returns {Promise<PromptGroupListResponse>}
|
||||
*/
|
||||
const getAllPromptGroups = async (req, filter) => {
|
||||
try {
|
||||
const { name, ...query } = filter;
|
||||
|
||||
if (!query.author) {
|
||||
throw new Error('Author is required');
|
||||
}
|
||||
|
||||
let searchShared = true;
|
||||
let searchSharedOnly = false;
|
||||
if (name) {
|
||||
query.name = new RegExp(name, 'i');
|
||||
}
|
||||
if (!query.category) {
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.MY_PROMPTS) {
|
||||
searchShared = false;
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.NO_CATEGORY) {
|
||||
query.category = '';
|
||||
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
|
||||
searchSharedOnly = true;
|
||||
delete query.category;
|
||||
}
|
||||
|
||||
let combinedQuery = query;
|
||||
|
||||
if (searchShared) {
|
||||
const project = await getProjectByName('instance', 'promptGroupIds');
|
||||
if (project && project.promptGroupIds.length > 0) {
|
||||
const projectQuery = { _id: { $in: project.promptGroupIds }, ...query };
|
||||
delete projectQuery.author;
|
||||
combinedQuery = searchSharedOnly ? projectQuery : { $or: [projectQuery, query] };
|
||||
}
|
||||
}
|
||||
|
||||
const promptGroupsPipeline = createAllGroupsPipeline(combinedQuery);
|
||||
return await PromptGroup.aggregate(promptGroupsPipeline).exec();
|
||||
} catch (error) {
|
||||
console.error('Error getting all prompt groups', error);
|
||||
return { message: 'Error getting all prompt groups' };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get prompt groups with filters
|
||||
* @param {Object} req
|
||||
* @param {TPromptGroupsWithFilterRequest} filter
|
||||
* @returns {Promise<PromptGroupListResponse>}
|
||||
*/
|
||||
const getPromptGroups = async (req, filter) => {
|
||||
try {
|
||||
const { pageNumber = 1, pageSize = 10, name, ...query } = filter;
|
||||
|
||||
const validatedPageNumber = Math.max(parseInt(pageNumber, 10), 1);
|
||||
const validatedPageSize = Math.max(parseInt(pageSize, 10), 1);
|
||||
|
||||
if (!query.author) {
|
||||
throw new Error('Author is required');
|
||||
}
|
||||
|
||||
let searchShared = true;
|
||||
let searchSharedOnly = false;
|
||||
if (name) {
|
||||
query.name = new RegExp(name, 'i');
|
||||
}
|
||||
if (!query.category) {
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.MY_PROMPTS) {
|
||||
searchShared = false;
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.NO_CATEGORY) {
|
||||
query.category = '';
|
||||
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
|
||||
searchSharedOnly = true;
|
||||
delete query.category;
|
||||
}
|
||||
|
||||
let combinedQuery = query;
|
||||
|
||||
if (searchShared) {
|
||||
// const projects = req.user.projects || []; // TODO: handle multiple projects
|
||||
const project = await getProjectByName('instance', 'promptGroupIds');
|
||||
if (project && project.promptGroupIds.length > 0) {
|
||||
const projectQuery = { _id: { $in: project.promptGroupIds }, ...query };
|
||||
delete projectQuery.author;
|
||||
combinedQuery = searchSharedOnly ? projectQuery : { $or: [projectQuery, query] };
|
||||
}
|
||||
}
|
||||
|
||||
const skip = (validatedPageNumber - 1) * validatedPageSize;
|
||||
const limit = validatedPageSize;
|
||||
|
||||
const promptGroupsPipeline = createGroupPipeline(combinedQuery, skip, limit);
|
||||
const totalPromptGroupsPipeline = [{ $match: combinedQuery }, { $count: 'total' }];
|
||||
|
||||
const [promptGroupsResults, totalPromptGroupsResults] = await Promise.all([
|
||||
PromptGroup.aggregate(promptGroupsPipeline).exec(),
|
||||
PromptGroup.aggregate(totalPromptGroupsPipeline).exec(),
|
||||
]);
|
||||
|
||||
const promptGroups = promptGroupsResults;
|
||||
const totalPromptGroups =
|
||||
totalPromptGroupsResults.length > 0 ? totalPromptGroupsResults[0].total : 0;
|
||||
|
||||
return {
|
||||
promptGroups,
|
||||
pageNumber: validatedPageNumber.toString(),
|
||||
pageSize: validatedPageSize.toString(),
|
||||
pages: Math.ceil(totalPromptGroups / validatedPageSize).toString(),
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
savePrompt: async ({ title, prompt }) => {
|
||||
getPromptGroups,
|
||||
getAllPromptGroups,
|
||||
/**
|
||||
* Create a prompt and its respective group
|
||||
* @param {TCreatePromptRecord} saveData
|
||||
* @returns {Promise<TCreatePromptResponse>}
|
||||
*/
|
||||
createPromptGroup: async (saveData) => {
|
||||
try {
|
||||
await Prompt.create({
|
||||
title,
|
||||
prompt,
|
||||
});
|
||||
return { title, prompt };
|
||||
const { prompt, group, author, authorName } = saveData;
|
||||
|
||||
let newPromptGroup = await PromptGroup.findOneAndUpdate(
|
||||
{ ...group, author, authorName, productionId: null },
|
||||
{ $setOnInsert: { ...group, author, authorName, productionId: null } },
|
||||
{ new: true, upsert: true },
|
||||
)
|
||||
.lean()
|
||||
.select('-__v')
|
||||
.exec();
|
||||
|
||||
const newPrompt = await Prompt.findOneAndUpdate(
|
||||
{ ...prompt, author, groupId: newPromptGroup._id },
|
||||
{ $setOnInsert: { ...prompt, author, groupId: newPromptGroup._id } },
|
||||
{ new: true, upsert: true },
|
||||
)
|
||||
.lean()
|
||||
.select('-__v')
|
||||
.exec();
|
||||
|
||||
newPromptGroup = await PromptGroup.findByIdAndUpdate(
|
||||
newPromptGroup._id,
|
||||
{ productionId: newPrompt._id },
|
||||
{ new: true },
|
||||
)
|
||||
.lean()
|
||||
.select('-__v')
|
||||
.exec();
|
||||
|
||||
return {
|
||||
prompt: newPrompt,
|
||||
group: {
|
||||
...newPromptGroup,
|
||||
productionPrompt: { prompt: newPrompt.prompt },
|
||||
},
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('Error saving prompt group', error);
|
||||
throw new Error('Error saving prompt group');
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Save a prompt
|
||||
* @param {TCreatePromptRecord} saveData
|
||||
* @returns {Promise<TCreatePromptResponse>}
|
||||
*/
|
||||
savePrompt: async (saveData) => {
|
||||
try {
|
||||
const { prompt, author } = saveData;
|
||||
const newPromptData = {
|
||||
...prompt,
|
||||
author,
|
||||
};
|
||||
|
||||
/** @type {TPrompt} */
|
||||
let newPrompt;
|
||||
try {
|
||||
newPrompt = await Prompt.create(newPromptData);
|
||||
} catch (error) {
|
||||
if (error?.message?.includes('groupId_1_version_1')) {
|
||||
await Prompt.db.collection('prompts').dropIndex('groupId_1_version_1');
|
||||
} else {
|
||||
throw error;
|
||||
}
|
||||
newPrompt = await Prompt.create(newPromptData);
|
||||
}
|
||||
|
||||
return { prompt: newPrompt };
|
||||
} catch (error) {
|
||||
logger.error('Error saving prompt', error);
|
||||
return { prompt: 'Error saving prompt' };
|
||||
return { message: 'Error saving prompt' };
|
||||
}
|
||||
},
|
||||
getPrompts: async (filter) => {
|
||||
try {
|
||||
return await Prompt.find(filter).lean();
|
||||
return await Prompt.find(filter).sort({ createdAt: -1 }).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompts', error);
|
||||
return { prompt: 'Error getting prompts' };
|
||||
return { message: 'Error getting prompts' };
|
||||
}
|
||||
},
|
||||
deletePrompts: async (filter) => {
|
||||
getPrompt: async (filter) => {
|
||||
try {
|
||||
return await Prompt.deleteMany(filter);
|
||||
if (filter.groupId) {
|
||||
filter.groupId = new ObjectId(filter.groupId);
|
||||
}
|
||||
return await Prompt.findOne(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error deleting prompts', error);
|
||||
return { prompt: 'Error deleting prompts' };
|
||||
logger.error('Error getting prompt', error);
|
||||
return { message: 'Error getting prompt' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Get prompt groups with filters
|
||||
* @param {TGetRandomPromptsRequest} filter
|
||||
* @returns {Promise<TGetRandomPromptsResponse>}
|
||||
*/
|
||||
getRandomPromptGroups: async (filter) => {
|
||||
try {
|
||||
const result = await PromptGroup.aggregate([
|
||||
{
|
||||
$match: {
|
||||
category: { $ne: '' },
|
||||
},
|
||||
},
|
||||
{
|
||||
$group: {
|
||||
_id: '$category',
|
||||
promptGroup: { $first: '$$ROOT' },
|
||||
},
|
||||
},
|
||||
{
|
||||
$replaceRoot: { newRoot: '$promptGroup' },
|
||||
},
|
||||
{
|
||||
$sample: { size: +filter.limit + +filter.skip },
|
||||
},
|
||||
{
|
||||
$skip: +filter.skip,
|
||||
},
|
||||
{
|
||||
$limit: +filter.limit,
|
||||
},
|
||||
]);
|
||||
return { prompts: result };
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
},
|
||||
getPromptGroupsWithPrompts: async (filter) => {
|
||||
try {
|
||||
return await PromptGroup.findOne(filter)
|
||||
.populate({
|
||||
path: 'prompts',
|
||||
select: '-_id -__v -user',
|
||||
})
|
||||
.select('-_id -__v -user')
|
||||
.lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
},
|
||||
getPromptGroup: async (filter) => {
|
||||
try {
|
||||
return await PromptGroup.findOne(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt group', error);
|
||||
return { message: 'Error getting prompt group' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Deletes a prompt and its corresponding prompt group if it is the last prompt in the group.
|
||||
*
|
||||
* @param {Object} options - The options for deleting the prompt.
|
||||
* @param {ObjectId|string} options.promptId - The ID of the prompt to delete.
|
||||
* @param {ObjectId|string} options.groupId - The ID of the prompt's group.
|
||||
* @param {ObjectId|string} options.author - The ID of the prompt's author.
|
||||
* @param {string} options.role - The role of the prompt's author.
|
||||
* @return {Promise<TDeletePromptResponse>} An object containing the result of the deletion.
|
||||
* If the prompt was deleted successfully, the object will have a property 'prompt' with the value 'Prompt deleted successfully'.
|
||||
* If the prompt group was deleted successfully, the object will have a property 'promptGroup' with the message 'Prompt group deleted successfully' and id of the deleted group.
|
||||
* If there was an error deleting the prompt, the object will have a property 'message' with the value 'Error deleting prompt'.
|
||||
*/
|
||||
deletePrompt: async ({ promptId, groupId, author, role }) => {
|
||||
const query = { _id: promptId, groupId, author };
|
||||
if (role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
}
|
||||
const { deletedCount } = await Prompt.deleteOne(query);
|
||||
if (deletedCount === 0) {
|
||||
throw new Error('Failed to delete the prompt');
|
||||
}
|
||||
|
||||
const remainingPrompts = await Prompt.find({ groupId })
|
||||
.select('_id')
|
||||
.sort({ createdAt: 1 })
|
||||
.lean();
|
||||
|
||||
if (remainingPrompts.length === 0) {
|
||||
await PromptGroup.deleteOne({ _id: groupId });
|
||||
await removeGroupFromAllProjects(groupId);
|
||||
|
||||
return {
|
||||
prompt: 'Prompt deleted successfully',
|
||||
promptGroup: {
|
||||
message: 'Prompt group deleted successfully',
|
||||
id: groupId,
|
||||
},
|
||||
};
|
||||
} else {
|
||||
const promptGroup = await PromptGroup.findById(groupId).lean();
|
||||
if (promptGroup.productionId.toString() === promptId.toString()) {
|
||||
await PromptGroup.updateOne(
|
||||
{ _id: groupId },
|
||||
{ productionId: remainingPrompts[remainingPrompts.length - 1]._id },
|
||||
);
|
||||
}
|
||||
|
||||
return { prompt: 'Prompt deleted successfully' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Update prompt group
|
||||
* @param {Partial<MongoPromptGroup>} filter - Filter to find prompt group
|
||||
* @param {Partial<MongoPromptGroup>} data - Data to update
|
||||
* @returns {Promise<TUpdatePromptGroupResponse>}
|
||||
*/
|
||||
updatePromptGroup: async (filter, data) => {
|
||||
try {
|
||||
const updateOps = {};
|
||||
if (data.removeProjectIds) {
|
||||
for (const projectId of data.removeProjectIds) {
|
||||
await removeGroupIdsFromProject(projectId, [filter._id]);
|
||||
}
|
||||
|
||||
updateOps.$pull = { projectIds: { $in: data.removeProjectIds } };
|
||||
delete data.removeProjectIds;
|
||||
}
|
||||
|
||||
if (data.projectIds) {
|
||||
for (const projectId of data.projectIds) {
|
||||
await addGroupIdsToProject(projectId, [filter._id]);
|
||||
}
|
||||
|
||||
updateOps.$addToSet = { projectIds: { $each: data.projectIds } };
|
||||
delete data.projectIds;
|
||||
}
|
||||
|
||||
const updateData = { ...data, ...updateOps };
|
||||
const updatedDoc = await PromptGroup.findOneAndUpdate(filter, updateData, {
|
||||
new: true,
|
||||
upsert: false,
|
||||
});
|
||||
|
||||
if (!updatedDoc) {
|
||||
throw new Error('Prompt group not found');
|
||||
}
|
||||
|
||||
return updatedDoc;
|
||||
} catch (error) {
|
||||
logger.error('Error updating prompt group', error);
|
||||
return { message: 'Error updating prompt group' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Function to make a prompt production based on its ID.
|
||||
* @param {String} promptId - The ID of the prompt to make production.
|
||||
* @returns {Object} The result of the production operation.
|
||||
*/
|
||||
makePromptProduction: async (promptId) => {
|
||||
try {
|
||||
const prompt = await Prompt.findById(promptId).lean();
|
||||
|
||||
if (!prompt) {
|
||||
throw new Error('Prompt not found');
|
||||
}
|
||||
|
||||
await PromptGroup.findByIdAndUpdate(
|
||||
prompt.groupId,
|
||||
{ productionId: prompt._id },
|
||||
{ new: true },
|
||||
)
|
||||
.lean()
|
||||
.exec();
|
||||
|
||||
return {
|
||||
message: 'Prompt production made successfully',
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('Error making prompt production', error);
|
||||
return { message: 'Error making prompt production' };
|
||||
}
|
||||
},
|
||||
updatePromptLabels: async (_id, labels) => {
|
||||
try {
|
||||
const response = await Prompt.updateOne({ _id }, { $set: { labels } });
|
||||
if (response.matchedCount === 0) {
|
||||
return { message: 'Prompt not found' };
|
||||
}
|
||||
return { message: 'Prompt labels updated successfully' };
|
||||
} catch (error) {
|
||||
logger.error('Error updating prompt labels', error);
|
||||
return { message: 'Error updating prompt labels' };
|
||||
}
|
||||
},
|
||||
deletePromptGroup: async (_id) => {
|
||||
try {
|
||||
const response = await PromptGroup.deleteOne({ _id });
|
||||
|
||||
if (response.deletedCount === 0) {
|
||||
return { promptGroup: 'Prompt group not found' };
|
||||
}
|
||||
|
||||
await Prompt.deleteMany({ groupId: new ObjectId(_id) });
|
||||
await removeGroupFromAllProjects(_id);
|
||||
return { promptGroup: 'Prompt group deleted successfully' };
|
||||
} catch (error) {
|
||||
logger.error('Error deleting prompt group', error);
|
||||
return { message: 'Error deleting prompt group' };
|
||||
}
|
||||
},
|
||||
};
|
||||
|
|
|
|||
86
api/models/Role.js
Normal file
86
api/models/Role.js
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
const { SystemRoles, CacheKeys, roleDefaults } = require('librechat-data-provider');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const Role = require('~/models/schema/roleSchema');
|
||||
|
||||
/**
|
||||
* Retrieve a role by name and convert the found role document to a plain object.
|
||||
* If the role with the given name doesn't exist and the name is a system defined role, create it and return the lean version.
|
||||
*
|
||||
* @param {string} roleName - The name of the role to find or create.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<Object>} A plain object representing the role document.
|
||||
*/
|
||||
const getRoleByName = async function (roleName, fieldsToSelect = null) {
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.ROLES);
|
||||
const cachedRole = await cache.get(roleName);
|
||||
if (cachedRole) {
|
||||
return cachedRole;
|
||||
}
|
||||
let query = Role.findOne({ name: roleName });
|
||||
if (fieldsToSelect) {
|
||||
query = query.select(fieldsToSelect);
|
||||
}
|
||||
let role = await query.lean().exec();
|
||||
|
||||
if (!role && SystemRoles[roleName]) {
|
||||
role = roleDefaults[roleName];
|
||||
role = await new Role(role).save();
|
||||
await cache.set(roleName, role);
|
||||
return role.toObject();
|
||||
}
|
||||
await cache.set(roleName, role);
|
||||
return role;
|
||||
} catch (error) {
|
||||
throw new Error(`Failed to retrieve or create role: ${error.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Update role values by name.
|
||||
*
|
||||
* @param {string} roleName - The name of the role to update.
|
||||
* @param {Partial<TRole>} updates - The fields to update.
|
||||
* @returns {Promise<TRole>} Updated role document.
|
||||
*/
|
||||
const updateRoleByName = async function (roleName, updates) {
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.ROLES);
|
||||
const role = await Role.findOneAndUpdate(
|
||||
{ name: roleName },
|
||||
{ $set: updates },
|
||||
{ new: true, lean: true },
|
||||
)
|
||||
.select('-__v')
|
||||
.lean()
|
||||
.exec();
|
||||
await cache.set(roleName, role);
|
||||
return role;
|
||||
} catch (error) {
|
||||
throw new Error(`Failed to update role: ${error.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Initialize default roles in the system.
|
||||
* Creates the default roles (ADMIN, USER) if they don't exist in the database.
|
||||
*
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const initializeRoles = async function () {
|
||||
const defaultRoles = [SystemRoles.ADMIN, SystemRoles.USER];
|
||||
|
||||
for (const roleName of defaultRoles) {
|
||||
let role = await Role.findOne({ name: roleName }).select('name').lean();
|
||||
if (!role) {
|
||||
role = new Role(roleDefaults[roleName]);
|
||||
await role.save();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getRoleByName,
|
||||
initializeRoles,
|
||||
updateRoleByName,
|
||||
};
|
||||
|
|
@ -22,7 +22,7 @@ module.exports = {
|
|||
return share;
|
||||
} catch (error) {
|
||||
logger.error('[getShare] Error getting share link', error);
|
||||
return { message: 'Error getting share link' };
|
||||
throw new Error('Error getting share link');
|
||||
}
|
||||
},
|
||||
|
||||
|
|
@ -41,17 +41,17 @@ module.exports = {
|
|||
return { sharedLinks: shares, pages: totalPages, pageNumber, pageSize };
|
||||
} catch (error) {
|
||||
logger.error('[getShareByPage] Error getting shares', error);
|
||||
return { message: 'Error getting shares' };
|
||||
throw new Error('Error getting shares');
|
||||
}
|
||||
},
|
||||
|
||||
createSharedLink: async (user, { conversationId, ...shareData }) => {
|
||||
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
|
||||
if (share) {
|
||||
return share;
|
||||
}
|
||||
|
||||
try {
|
||||
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
|
||||
if (share) {
|
||||
return share;
|
||||
}
|
||||
|
||||
const shareId = crypto.randomUUID();
|
||||
const messages = await getMessages({ conversationId });
|
||||
const update = { ...shareData, shareId, messages, user };
|
||||
|
|
@ -60,31 +60,42 @@ module.exports = {
|
|||
upsert: true,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[saveShareMessage] Error saving conversation', error);
|
||||
return { message: 'Error saving conversation' };
|
||||
logger.error('[createSharedLink] Error creating shared link', error);
|
||||
throw new Error('Error creating shared link');
|
||||
}
|
||||
},
|
||||
|
||||
updateSharedLink: async (user, { conversationId, ...shareData }) => {
|
||||
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
|
||||
if (!share) {
|
||||
return { message: 'Share not found' };
|
||||
try {
|
||||
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
|
||||
if (!share) {
|
||||
return { message: 'Share not found' };
|
||||
}
|
||||
|
||||
// update messages to the latest
|
||||
const messages = await getMessages({ conversationId });
|
||||
const update = { ...shareData, messages, user };
|
||||
return await SharedLink.findOneAndUpdate({ conversationId: conversationId, user }, update, {
|
||||
new: true,
|
||||
upsert: false,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[updateSharedLink] Error updating shared link', error);
|
||||
throw new Error('Error updating shared link');
|
||||
}
|
||||
// update messages to the latest
|
||||
const messages = await getMessages({ conversationId });
|
||||
const update = { ...shareData, messages, user };
|
||||
return await SharedLink.findOneAndUpdate({ conversationId: conversationId, user }, update, {
|
||||
new: true,
|
||||
upsert: false,
|
||||
});
|
||||
},
|
||||
|
||||
deleteSharedLink: async (user, { shareId }) => {
|
||||
const share = await SharedLink.findOne({ shareId, user });
|
||||
if (!share) {
|
||||
return { message: 'Share not found' };
|
||||
try {
|
||||
const share = await SharedLink.findOne({ shareId, user });
|
||||
if (!share) {
|
||||
return { message: 'Share not found' };
|
||||
}
|
||||
return await SharedLink.findOneAndDelete({ shareId, user });
|
||||
} catch (error) {
|
||||
logger.error('[deleteSharedLink] Error deleting shared link', error);
|
||||
throw new Error('Error deleting shared link');
|
||||
}
|
||||
return await SharedLink.findOneAndDelete({ shareId, user });
|
||||
},
|
||||
/**
|
||||
* Deletes all shared links for a specific user.
|
||||
|
|
@ -100,7 +111,7 @@ module.exports = {
|
|||
};
|
||||
} catch (error) {
|
||||
logger.error('[deleteAllSharedLinks] Error deleting shared links', error);
|
||||
return { message: 'Error deleting shared links' };
|
||||
throw new Error('Error deleting shared links');
|
||||
}
|
||||
},
|
||||
};
|
||||
|
|
|
|||
19
api/models/schema/categories.js
Normal file
19
api/models/schema/categories.js
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
const mongoose = require('mongoose');
|
||||
const Schema = mongoose.Schema;
|
||||
|
||||
const categoriesSchema = new Schema({
|
||||
label: {
|
||||
type: String,
|
||||
required: true,
|
||||
unique: true,
|
||||
},
|
||||
value: {
|
||||
type: String,
|
||||
required: true,
|
||||
unique: true,
|
||||
},
|
||||
});
|
||||
|
||||
const categories = mongoose.model('categories', categoriesSchema);
|
||||
|
||||
module.exports = { Categories: categories };
|
||||
30
api/models/schema/projectSchema.js
Normal file
30
api/models/schema/projectSchema.js
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
const { Schema } = require('mongoose');
|
||||
|
||||
/**
|
||||
* @typedef {Object} MongoProject
|
||||
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||
* @property {string} name - The name of the project
|
||||
* @property {ObjectId[]} promptGroupIds - Array of PromptGroup IDs associated with the project
|
||||
* @property {Date} [createdAt] - Date when the project was created (added by timestamps)
|
||||
* @property {Date} [updatedAt] - Date when the project was last updated (added by timestamps)
|
||||
*/
|
||||
|
||||
const projectSchema = new Schema(
|
||||
{
|
||||
name: {
|
||||
type: String,
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
promptGroupIds: {
|
||||
type: [Schema.Types.ObjectId],
|
||||
ref: 'PromptGroup',
|
||||
default: [],
|
||||
},
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = projectSchema;
|
||||
118
api/models/schema/promptSchema.js
Normal file
118
api/models/schema/promptSchema.js
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const Schema = mongoose.Schema;
|
||||
|
||||
/**
|
||||
* @typedef {Object} MongoPromptGroup
|
||||
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||
* @property {string} name - The name of the prompt group
|
||||
* @property {ObjectId} author - The author of the prompt group
|
||||
* @property {ObjectId} [projectId=null] - The project ID of the prompt group
|
||||
* @property {ObjectId} [productionId=null] - The project ID of the prompt group
|
||||
* @property {string} authorName - The name of the author of the prompt group
|
||||
* @property {number} [numberOfGenerations=0] - Number of generations the prompt group has
|
||||
* @property {string} [oneliner=''] - Oneliner description of the prompt group
|
||||
* @property {string} [category=''] - Category of the prompt group
|
||||
* @property {string} [command] - Command for the prompt group
|
||||
* @property {Date} [createdAt] - Date when the prompt group was created (added by timestamps)
|
||||
* @property {Date} [updatedAt] - Date when the prompt group was last updated (added by timestamps)
|
||||
*/
|
||||
|
||||
const promptGroupSchema = new Schema(
|
||||
{
|
||||
name: {
|
||||
type: String,
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
numberOfGenerations: {
|
||||
type: Number,
|
||||
default: 0,
|
||||
},
|
||||
oneliner: {
|
||||
type: String,
|
||||
default: '',
|
||||
},
|
||||
category: {
|
||||
type: String,
|
||||
default: '',
|
||||
index: true,
|
||||
},
|
||||
projectIds: {
|
||||
type: [Schema.Types.ObjectId],
|
||||
ref: 'Project',
|
||||
index: true,
|
||||
},
|
||||
productionId: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'Prompt',
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
author: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'User',
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
authorName: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
command: {
|
||||
type: String,
|
||||
index: true,
|
||||
validate: {
|
||||
validator: function (v) {
|
||||
return v === undefined || v === null || v === '' || /^[a-z0-9-]+$/.test(v);
|
||||
},
|
||||
message: (props) =>
|
||||
`${props.value} is not a valid command. Only lowercase alphanumeric characters and highfins (') are allowed.`,
|
||||
},
|
||||
maxlength: [
|
||||
Constants.COMMANDS_MAX_LENGTH,
|
||||
`Command cannot be longer than ${Constants.COMMANDS_MAX_LENGTH} characters`,
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
},
|
||||
);
|
||||
|
||||
const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema);
|
||||
|
||||
const promptSchema = new Schema(
|
||||
{
|
||||
groupId: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'PromptGroup',
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
author: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'User',
|
||||
required: true,
|
||||
},
|
||||
prompt: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
type: {
|
||||
type: String,
|
||||
enum: ['text', 'chat'],
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
},
|
||||
);
|
||||
|
||||
const Prompt = mongoose.model('Prompt', promptSchema);
|
||||
|
||||
promptSchema.index({ createdAt: 1, updatedAt: 1 });
|
||||
promptGroupSchema.index({ createdAt: 1, updatedAt: 1 });
|
||||
|
||||
module.exports = { Prompt, PromptGroup };
|
||||
29
api/models/schema/roleSchema.js
Normal file
29
api/models/schema/roleSchema.js
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const mongoose = require('mongoose');
|
||||
|
||||
const roleSchema = new mongoose.Schema({
|
||||
name: {
|
||||
type: String,
|
||||
required: true,
|
||||
unique: true,
|
||||
index: true,
|
||||
},
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.SHARED_GLOBAL]: {
|
||||
type: Boolean,
|
||||
default: false,
|
||||
},
|
||||
[Permissions.USE]: {
|
||||
type: Boolean,
|
||||
default: true,
|
||||
},
|
||||
[Permissions.CREATE]: {
|
||||
type: Boolean,
|
||||
default: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const Role = mongoose.model('Role', roleSchema);
|
||||
|
||||
module.exports = Role;
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* @typedef {Object} MongoSession
|
||||
|
|
@ -78,7 +79,7 @@ const userSchema = mongoose.Schema(
|
|||
},
|
||||
role: {
|
||||
type: String,
|
||||
default: 'USER',
|
||||
default: SystemRoles.USER,
|
||||
},
|
||||
googleId: {
|
||||
type: String,
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ const tokenValues = {
|
|||
'gpt-3.5-turbo-0125': { prompt: 0.5, completion: 1.5 },
|
||||
'claude-3-opus': { prompt: 15, completion: 75 },
|
||||
'claude-3-sonnet': { prompt: 3, completion: 15 },
|
||||
'claude-3-5-sonnet': { prompt: 3, completion: 15 },
|
||||
'claude-3-haiku': { prompt: 0.25, completion: 1.25 },
|
||||
'claude-2.1': { prompt: 8, completion: 24 },
|
||||
'claude-2': { prompt: 8, completion: 24 },
|
||||
|
|
|
|||
|
|
@ -48,6 +48,13 @@ describe('getValueKey', () => {
|
|||
expect(getValueKey('gpt-4o-turbo')).toBe('gpt-4o');
|
||||
expect(getValueKey('gpt-4o-0125')).toBe('gpt-4o');
|
||||
});
|
||||
|
||||
it('should return "claude-3-5-sonnet" for model type of "claude-3-5-sonnet-"', () => {
|
||||
expect(getValueKey('claude-3-5-sonnet-20240620')).toBe('claude-3-5-sonnet');
|
||||
expect(getValueKey('anthropic/claude-3-5-sonnet')).toBe('claude-3-5-sonnet');
|
||||
expect(getValueKey('claude-3-5-sonnet-turbo')).toBe('claude-3-5-sonnet');
|
||||
expect(getValueKey('claude-3-5-sonnet-0125')).toBe('claude-3-5-sonnet');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getMultiplier', () => {
|
||||
|
|
|
|||
|
|
@ -70,20 +70,11 @@ const createUser = async (data, disableTTL = true, returnUser = false) => {
|
|||
delete userData.expiresAt;
|
||||
}
|
||||
|
||||
try {
|
||||
const user = await User.create(userData);
|
||||
if (returnUser) {
|
||||
return user.toObject();
|
||||
}
|
||||
return user._id;
|
||||
} catch (error) {
|
||||
if (error.code === 11000) {
|
||||
// Duplicate key error code
|
||||
throw new Error(`User with \`_id\` ${data._id} already exists.`);
|
||||
} else {
|
||||
throw error;
|
||||
}
|
||||
const user = await User.create(userData);
|
||||
if (returnUser) {
|
||||
return user.toObject();
|
||||
}
|
||||
return user._id;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "@librechat/backend",
|
||||
"version": "0.7.3",
|
||||
"version": "0.7.4-rc1",
|
||||
"description": "",
|
||||
"scripts": {
|
||||
"start": "echo 'please run this from the root directory'",
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ const throttle = require('lodash/throttle');
|
|||
const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider');
|
||||
const { createAbortController, handleAbortError } = require('~/server/middleware');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { saveMessage, getConvo } = require('~/models');
|
||||
const { saveMessage } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
|
|
@ -18,6 +18,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||
logger.debug('[AskController]', { text, conversationId, ...endpointOption });
|
||||
|
||||
let userMessage;
|
||||
let userMessagePromise;
|
||||
let promptTokens;
|
||||
let userMessageId;
|
||||
let responseMessageId;
|
||||
|
|
@ -34,6 +35,8 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'userMessagePromise') {
|
||||
userMessagePromise = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
|
|
@ -74,6 +77,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||
const getAbortData = () => ({
|
||||
sender,
|
||||
conversationId,
|
||||
userMessagePromise,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: getPartialText(),
|
||||
|
|
@ -81,7 +85,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||
promptTokens,
|
||||
});
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
res.on('close', () => {
|
||||
logger.debug('[AskController] Request closed');
|
||||
|
|
@ -108,7 +112,6 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
text,
|
||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||
},
|
||||
};
|
||||
|
|
@ -121,7 +124,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||
|
||||
response.endpoint = endpointOption.endpoint;
|
||||
|
||||
const conversation = await getConvo(user, conversationId);
|
||||
const { conversation = {} } = await client.responsePromise;
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
|
|
@ -144,7 +147,9 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||
await saveMessage({ ...response, user });
|
||||
}
|
||||
|
||||
await saveMessage(userMessage);
|
||||
if (!client.skipSaveUserMessage) {
|
||||
await saveMessage(userMessage);
|
||||
}
|
||||
|
||||
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
|
||||
addTitle(req, {
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ const throttle = require('lodash/throttle');
|
|||
const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
|
||||
const { createAbortController, handleAbortError } = require('~/server/middleware');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { saveMessage, getConvo } = require('~/models');
|
||||
const { saveMessage } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const EditController = async (req, res, next, initializeClient) => {
|
||||
|
|
@ -27,6 +27,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
|||
});
|
||||
|
||||
let userMessage;
|
||||
let userMessagePromise;
|
||||
let promptTokens;
|
||||
const sender = getResponseSender({
|
||||
...endpointOption,
|
||||
|
|
@ -40,6 +41,8 @@ const EditController = async (req, res, next, initializeClient) => {
|
|||
for (let key in data) {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
} else if (key === 'userMessagePromise') {
|
||||
userMessagePromise = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
|
|
@ -73,6 +76,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
|||
|
||||
const getAbortData = () => ({
|
||||
conversationId,
|
||||
userMessagePromise,
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
|
|
@ -81,7 +85,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
|||
promptTokens,
|
||||
});
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
res.on('close', () => {
|
||||
logger.debug('[EditController] Request closed');
|
||||
|
|
@ -115,12 +119,11 @@ const EditController = async (req, res, next, initializeClient) => {
|
|||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
text,
|
||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||
},
|
||||
});
|
||||
|
||||
const conversation = await getConvo(user, conversationId);
|
||||
const { conversation = {} } = await client.responsePromise;
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
const {
|
||||
EModelEndpoint,
|
||||
CacheKeys,
|
||||
defaultAssistantsVersion,
|
||||
SystemRoles,
|
||||
EModelEndpoint,
|
||||
defaultOrderQuery,
|
||||
defaultAssistantsVersion,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
initializeClient: initAzureClient,
|
||||
|
|
@ -227,7 +228,7 @@ const fetchAssistants = async ({ req, res, overrideEndpoint }) => {
|
|||
body = await listAssistantsForAzure({ req, res, version, azureConfig, query });
|
||||
}
|
||||
|
||||
if (req.user.role === 'ADMIN') {
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
return body;
|
||||
} else if (!req.app.locals[endpoint]) {
|
||||
return body;
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ const startServer = async () => {
|
|||
passport.use(passportLogin());
|
||||
|
||||
// LDAP Auth
|
||||
if (process.env.LDAP_URL && process.env.LDAP_BIND_DN && process.env.LDAP_USER_SEARCH_BASE) {
|
||||
if (process.env.LDAP_URL && process.env.LDAP_USER_SEARCH_BASE) {
|
||||
passport.use(ldapLogin);
|
||||
}
|
||||
|
||||
|
|
@ -81,6 +81,7 @@ const startServer = async () => {
|
|||
app.use('/api/convos', routes.convos);
|
||||
app.use('/api/presets', routes.presets);
|
||||
app.use('/api/prompts', routes.prompts);
|
||||
app.use('/api/categories', routes.categories);
|
||||
app.use('/api/tokenizer', routes.tokenizer);
|
||||
app.use('/api/endpoints', routes.endpoints);
|
||||
app.use('/api/balance', routes.balance);
|
||||
|
|
@ -91,6 +92,7 @@ const startServer = async () => {
|
|||
app.use('/api/files', await routes.files.initialize());
|
||||
app.use('/images/', validateImageRequest, routes.staticRoute);
|
||||
app.use('/api/share', routes.share);
|
||||
app.use('/api/roles', routes.roles);
|
||||
|
||||
app.use((req, res) => {
|
||||
res.sendFile(path.join(app.locals.paths.dist, 'index.html'));
|
||||
|
|
|
|||
|
|
@ -1,31 +1,36 @@
|
|||
const { isAssistantsEndpoint } = require('librechat-data-provider');
|
||||
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
|
||||
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
||||
const { saveMessage, getConvo, getConvoTitle } = require('~/models');
|
||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||
const abortControllers = require('./abortControllers');
|
||||
const { saveMessage, getConvo } = require('~/models');
|
||||
const spendTokens = require('~/models/spendTokens');
|
||||
const { abortRun } = require('./abortRun');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
async function abortMessage(req, res) {
|
||||
let { abortKey, conversationId, endpoint } = req.body;
|
||||
|
||||
if (!abortKey && conversationId) {
|
||||
abortKey = conversationId;
|
||||
}
|
||||
let { abortKey, endpoint } = req.body;
|
||||
|
||||
if (isAssistantsEndpoint(endpoint)) {
|
||||
return await abortRun(req, res);
|
||||
}
|
||||
|
||||
const conversationId = abortKey?.split(':')?.[0] ?? req.user.id;
|
||||
|
||||
if (!abortControllers.has(abortKey) && abortControllers.has(conversationId)) {
|
||||
abortKey = conversationId;
|
||||
}
|
||||
|
||||
if (!abortControllers.has(abortKey) && !res.headersSent) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
|
||||
const { abortController } = abortControllers.get(abortKey);
|
||||
const { abortController } = abortControllers.get(abortKey) ?? {};
|
||||
if (!abortController) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
const finalEvent = await abortController.abortCompletion();
|
||||
logger.debug('[abortMessage] Aborted request', { abortKey });
|
||||
logger.info('[abortMessage] Aborted request', { abortKey });
|
||||
abortControllers.delete(abortKey);
|
||||
|
||||
if (res.headersSent && finalEvent) {
|
||||
|
|
@ -50,12 +55,35 @@ const handleAbort = () => {
|
|||
};
|
||||
};
|
||||
|
||||
const createAbortController = (req, res, getAbortData) => {
|
||||
const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||
const abortController = new AbortController();
|
||||
const { endpointOption } = req.body;
|
||||
const onStart = (userMessage) => {
|
||||
|
||||
abortController.getAbortData = function () {
|
||||
return getAbortData();
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {TMessage} userMessage
|
||||
* @param {string} responseMessageId
|
||||
*/
|
||||
const onStart = (userMessage, responseMessageId) => {
|
||||
sendMessage(res, { message: userMessage, created: true });
|
||||
|
||||
const abortKey = userMessage?.conversationId ?? req.user.id;
|
||||
const prevRequest = abortControllers.get(abortKey);
|
||||
|
||||
if (prevRequest && prevRequest?.abortController) {
|
||||
const data = prevRequest.abortController.getAbortData();
|
||||
getReqData({ userMessage: data?.userMessage });
|
||||
const addedAbortKey = `${abortKey}:${responseMessageId}`;
|
||||
abortControllers.set(addedAbortKey, { abortController, ...endpointOption });
|
||||
res.on('finish', function () {
|
||||
abortControllers.delete(addedAbortKey);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
abortControllers.set(abortKey, { abortController, ...endpointOption });
|
||||
|
||||
res.on('finish', function () {
|
||||
|
|
@ -65,7 +93,8 @@ const createAbortController = (req, res, getAbortData) => {
|
|||
|
||||
abortController.abortCompletion = async function () {
|
||||
abortController.abort();
|
||||
const { conversationId, userMessage, promptTokens, ...responseData } = getAbortData();
|
||||
const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } =
|
||||
getAbortData();
|
||||
const completionTokens = await countTokens(responseData?.text ?? '');
|
||||
const user = req.user.id;
|
||||
|
||||
|
|
@ -89,10 +118,20 @@ const createAbortController = (req, res, getAbortData) => {
|
|||
|
||||
saveMessage({ ...responseMessage, user });
|
||||
|
||||
let conversation;
|
||||
if (userMessagePromise) {
|
||||
const resolved = await userMessagePromise;
|
||||
conversation = resolved?.conversation;
|
||||
}
|
||||
|
||||
if (!conversation) {
|
||||
conversation = await getConvo(req.user.id, conversationId);
|
||||
}
|
||||
|
||||
return {
|
||||
title: await getConvoTitle(user, conversationId),
|
||||
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
|
||||
final: true,
|
||||
conversation: await getConvo(user, conversationId),
|
||||
conversation,
|
||||
requestMessage: userMessage,
|
||||
responseMessage: responseMessage,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { getAssistant } = require('~/models/Assistant');
|
||||
|
||||
/**
|
||||
|
|
@ -11,7 +12,7 @@ const { getAssistant } = require('~/models/Assistant');
|
|||
* @returns {Promise<void>}
|
||||
*/
|
||||
const validateAuthor = async ({ req, openai, overrideEndpoint, overrideAssistantId }) => {
|
||||
if (req.user.role === 'ADMIN') {
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
|
|
@ -16,7 +17,7 @@ const { logger } = require('~/config');
|
|||
const canDeleteAccount = async (req, res, next = () => {}) => {
|
||||
const { user } = req;
|
||||
const { ALLOW_ACCOUNT_DELETION = true } = process.env;
|
||||
if (user?.role === 'ADMIN' || isEnabled(ALLOW_ACCOUNT_DELETION)) {
|
||||
if (user?.role === SystemRoles.ADMIN || isEnabled(ALLOW_ACCOUNT_DELETION)) {
|
||||
return next();
|
||||
} else {
|
||||
logger.error(`[User] [Delete Account] [User cannot delete account] [User: ${user?.id}]`);
|
||||
|
|
|
|||
|
|
@ -18,10 +18,12 @@ const limiters = require('./limiters');
|
|||
const uaParser = require('./uaParser');
|
||||
const checkBan = require('./checkBan');
|
||||
const noIndex = require('./noIndex');
|
||||
const roles = require('./roles');
|
||||
|
||||
module.exports = {
|
||||
...abortMiddleware,
|
||||
...limiters,
|
||||
...roles,
|
||||
noIndex,
|
||||
checkBan,
|
||||
uaParser,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ const requireLdapAuth = (req, res, next) => {
|
|||
console.log({
|
||||
title: '(requireLdapAuth) Error: No user',
|
||||
});
|
||||
return res.status(422).send(info);
|
||||
return res.status(404).send(info);
|
||||
}
|
||||
req.user = user;
|
||||
next();
|
||||
|
|
|
|||
14
api/server/middleware/roles/checkAdmin.js
Normal file
14
api/server/middleware/roles/checkAdmin.js
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
const { SystemRoles } = require('librechat-data-provider');
|
||||
|
||||
function checkAdmin(req, res, next) {
|
||||
try {
|
||||
if (req.user.role !== SystemRoles.ADMIN) {
|
||||
return res.status(403).json({ message: 'Forbidden' });
|
||||
}
|
||||
next();
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Internal Server Error' });
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = checkAdmin;
|
||||
52
api/server/middleware/roles/generateCheckAccess.js
Normal file
52
api/server/middleware/roles/generateCheckAccess.js
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
||||
/**
|
||||
* Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
|
||||
*
|
||||
* @param {PermissionTypes} permissionType - The type of permission to check.
|
||||
* @param {Permissions[]} permissions - The list of specific permissions to check.
|
||||
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check.
|
||||
* @returns {Function} Express middleware function.
|
||||
*/
|
||||
const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => {
|
||||
return async (req, res, next) => {
|
||||
try {
|
||||
const { user } = req;
|
||||
if (!user) {
|
||||
return res.status(401).json({ message: 'Authorization required' });
|
||||
}
|
||||
|
||||
if (user.role === SystemRoles.ADMIN) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const role = await getRoleByName(user.role);
|
||||
if (role && role[permissionType]) {
|
||||
const hasAnyPermission = permissions.some((permission) => {
|
||||
if (role[permissionType][permission]) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (bodyProps[permission] && req.body) {
|
||||
return bodyProps[permission].some((prop) =>
|
||||
Object.prototype.hasOwnProperty.call(req.body, prop),
|
||||
);
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
if (hasAnyPermission) {
|
||||
return next();
|
||||
}
|
||||
}
|
||||
|
||||
return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
|
||||
} catch (error) {
|
||||
return res.status(500).json({ message: `Server error: ${error.message}` });
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = generateCheckAccess;
|
||||
7
api/server/middleware/roles/index.js
Normal file
7
api/server/middleware/roles/index.js
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
const checkAdmin = require('./checkAdmin');
|
||||
const generateCheckAccess = require('./generateCheckAccess');
|
||||
|
||||
module.exports = {
|
||||
checkAdmin,
|
||||
generateCheckAccess,
|
||||
};
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
const { getConvo } = require('../../models');
|
||||
const { getConvo } = require('~/models');
|
||||
|
||||
// Middleware to validate conversationId and user relationship
|
||||
const validateMessageReq = async (req, res, next) => {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const express = require('express');
|
||||
const AskController = require('~/server/controllers/AskController');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/google');
|
||||
const { initializeClient, addTitle } = require('~/server/services/Endpoints/google');
|
||||
const {
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
|
|
@ -20,7 +20,7 @@ router.post(
|
|||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res, next) => {
|
||||
await AskController(req, res, next, initializeClient);
|
||||
await AskController(req, res, next, initializeClient, addTitle);
|
||||
},
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ const express = require('express');
|
|||
const throttle = require('lodash/throttle');
|
||||
const { getResponseSender, Constants } = require('librechat-data-provider');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
|
||||
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { addTitle } = require('~/server/services/Endpoints/openAI');
|
||||
const { saveMessage } = require('~/models');
|
||||
const {
|
||||
handleAbort,
|
||||
createAbortController,
|
||||
|
|
@ -41,6 +41,7 @@ router.post(
|
|||
logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption });
|
||||
|
||||
let userMessage;
|
||||
let userMessagePromise;
|
||||
let promptTokens;
|
||||
let userMessageId;
|
||||
let responseMessageId;
|
||||
|
|
@ -58,6 +59,8 @@ router.post(
|
|||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'userMessagePromise') {
|
||||
userMessagePromise = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
|
|
@ -148,18 +151,10 @@ router.post(
|
|||
}
|
||||
};
|
||||
|
||||
const onChainEnd = () => {
|
||||
saveMessage({ ...userMessage, user });
|
||||
sendIntermediateMessage(res, {
|
||||
plugins,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
};
|
||||
|
||||
const getAbortData = () => ({
|
||||
sender,
|
||||
conversationId,
|
||||
userMessagePromise,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: getPartialText(),
|
||||
|
|
@ -167,12 +162,23 @@ router.post(
|
|||
userMessage,
|
||||
promptTokens,
|
||||
});
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
try {
|
||||
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
const onChainEnd = () => {
|
||||
if (!client.skipSaveUserMessage) {
|
||||
saveMessage({ ...userMessage, user });
|
||||
}
|
||||
sendIntermediateMessage(res, {
|
||||
plugins,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
};
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
user,
|
||||
conversationId,
|
||||
|
|
@ -189,7 +195,6 @@ router.post(
|
|||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
text,
|
||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||
plugins,
|
||||
},
|
||||
|
|
@ -205,10 +210,14 @@ router.post(
|
|||
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
|
||||
await saveMessage({ ...response, user });
|
||||
|
||||
const { conversation = {} } = await client.responsePromise;
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
sendMessage(res, {
|
||||
title: await getConvoTitle(user, conversationId),
|
||||
title: conversation.title,
|
||||
final: true,
|
||||
conversation: await getConvo(user, conversationId),
|
||||
conversation,
|
||||
requestMessage: userMessage,
|
||||
responseMessage: response,
|
||||
});
|
||||
|
|
|
|||
|
|
@ -21,8 +21,7 @@ const {
|
|||
|
||||
const router = express.Router();
|
||||
|
||||
const ldapAuth =
|
||||
!!process.env.LDAP_URL && !!process.env.LDAP_BIND_DN && !!process.env.LDAP_USER_SEARCH_BASE;
|
||||
const ldapAuth = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE;
|
||||
//Local
|
||||
router.post('/logout', requireJwtAuth, logoutController);
|
||||
router.post(
|
||||
|
|
|
|||
15
api/server/routes/categories.js
Normal file
15
api/server/routes/categories.js
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const { getCategories } = require('~/models/Categories');
|
||||
|
||||
router.get('/', requireJwtAuth, async (req, res) => {
|
||||
try {
|
||||
const categories = await getCategories();
|
||||
res.status(200).send(categories);
|
||||
} catch (error) {
|
||||
res.status(500).send({ message: 'Failed to retrieve categories', error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
|
|
@ -1,6 +1,8 @@
|
|||
const express = require('express');
|
||||
const { defaultSocialLogins } = require('librechat-data-provider');
|
||||
const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
|
|
@ -17,13 +19,21 @@ const publicSharedLinksEnabled =
|
|||
isEnabled(process.env.ALLOW_SHARED_LINKS_PUBLIC));
|
||||
|
||||
router.get('/', async function (req, res) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG);
|
||||
if (cachedStartupConfig) {
|
||||
res.send(cachedStartupConfig);
|
||||
return;
|
||||
}
|
||||
|
||||
const isBirthday = () => {
|
||||
const today = new Date();
|
||||
return today.getMonth() === 1 && today.getDate() === 11;
|
||||
};
|
||||
|
||||
const ldapLoginEnabled =
|
||||
!!process.env.LDAP_URL && !!process.env.LDAP_BIND_DN && !!process.env.LDAP_USER_SEARCH_BASE;
|
||||
const instanceProject = await getProjectByName('instance', '_id');
|
||||
|
||||
const ldapLoginEnabled = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE;
|
||||
try {
|
||||
/** @type {TStartupConfig} */
|
||||
const payload = {
|
||||
|
|
@ -63,12 +73,14 @@ router.get('/', async function (req, res) {
|
|||
sharedLinksEnabled,
|
||||
publicSharedLinksEnabled,
|
||||
analyticsGtmId: process.env.ANALYTICS_GTM_ID,
|
||||
instanceProjectId: instanceProject._id.toString(),
|
||||
};
|
||||
|
||||
if (typeof process.env.CUSTOM_FOOTER === 'string') {
|
||||
payload.customFooter = process.env.CUSTOM_FOOTER;
|
||||
}
|
||||
|
||||
await cache.set(CacheKeys.STARTUP_CONFIG, payload);
|
||||
return res.status(200).send(payload);
|
||||
} catch (err) {
|
||||
logger.error('Error in startup config', err);
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ const {
|
|||
} = require('~/server/middleware');
|
||||
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
|
||||
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
|
||||
const { saveMessage } = require('~/models');
|
||||
const { validateTools } = require('~/app');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
|
|
@ -49,6 +49,7 @@ router.post(
|
|||
});
|
||||
|
||||
let userMessage;
|
||||
let userMessagePromise;
|
||||
let promptTokens;
|
||||
const sender = getResponseSender({
|
||||
...endpointOption,
|
||||
|
|
@ -68,6 +69,8 @@ router.post(
|
|||
for (let key in data) {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
} else if (key === 'userMessagePromise') {
|
||||
userMessagePromise = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
|
|
@ -103,21 +106,6 @@ router.post(
|
|||
},
|
||||
});
|
||||
|
||||
const onAgentAction = (action, start = false) => {
|
||||
const formattedAction = formatAction(action);
|
||||
plugin.inputs.push(formattedAction);
|
||||
plugin.latest = formattedAction.plugin;
|
||||
if (!start) {
|
||||
saveMessage({ ...userMessage, user });
|
||||
}
|
||||
sendIntermediateMessage(res, {
|
||||
plugin,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
// logger.debug('PLUGIN ACTION', formattedAction);
|
||||
};
|
||||
|
||||
const onChainEnd = (data) => {
|
||||
let { intermediateSteps: steps } = data;
|
||||
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
|
||||
|
|
@ -134,6 +122,7 @@ router.post(
|
|||
const getAbortData = () => ({
|
||||
sender,
|
||||
conversationId,
|
||||
userMessagePromise,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: getPartialText(),
|
||||
|
|
@ -141,12 +130,27 @@ router.post(
|
|||
userMessage,
|
||||
promptTokens,
|
||||
});
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
try {
|
||||
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
const onAgentAction = (action, start = false) => {
|
||||
const formattedAction = formatAction(action);
|
||||
plugin.inputs.push(formattedAction);
|
||||
plugin.latest = formattedAction.plugin;
|
||||
if (!start && !client.skipSaveUserMessage) {
|
||||
saveMessage({ ...userMessage, user });
|
||||
}
|
||||
sendIntermediateMessage(res, {
|
||||
plugin,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
// logger.debug('PLUGIN ACTION', formattedAction);
|
||||
};
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
user,
|
||||
generation,
|
||||
|
|
@ -164,7 +168,6 @@ router.post(
|
|||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
text,
|
||||
plugin,
|
||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||
},
|
||||
|
|
@ -179,10 +182,14 @@ router.post(
|
|||
response.plugin = { ...plugin, loading: false };
|
||||
await saveMessage({ ...response, user });
|
||||
|
||||
const { conversation = {} } = await client.responsePromise;
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
sendMessage(res, {
|
||||
title: await getConvoTitle(user, conversationId),
|
||||
title: conversation.title,
|
||||
final: true,
|
||||
conversation: await getConvo(user, conversationId),
|
||||
conversation,
|
||||
requestMessage: userMessage,
|
||||
responseMessage: response,
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,19 +1,11 @@
|
|||
const express = require('express');
|
||||
const {
|
||||
uaParser,
|
||||
checkBan,
|
||||
requireJwtAuth,
|
||||
createFileLimiters,
|
||||
createTTSLimiters,
|
||||
createSTTLimiters,
|
||||
} = require('~/server/middleware');
|
||||
const { uaParser, checkBan, requireJwtAuth, createFileLimiters } = require('~/server/middleware');
|
||||
const { createMulterInstance } = require('./multer');
|
||||
|
||||
const files = require('./files');
|
||||
const images = require('./images');
|
||||
const avatar = require('./avatar');
|
||||
const stt = require('./stt');
|
||||
const tts = require('./tts');
|
||||
const speech = require('./speech');
|
||||
|
||||
const initialize = async () => {
|
||||
const router = express.Router();
|
||||
|
|
@ -21,11 +13,8 @@ const initialize = async () => {
|
|||
router.use(checkBan);
|
||||
router.use(uaParser);
|
||||
|
||||
/* Important: stt/tts routes must be added before the upload limiters */
|
||||
const { sttIpLimiter, sttUserLimiter } = createSTTLimiters();
|
||||
const { ttsIpLimiter, ttsUserLimiter } = createTTSLimiters();
|
||||
router.use('/stt', sttIpLimiter, sttUserLimiter, stt);
|
||||
router.use('/tts', ttsIpLimiter, ttsUserLimiter, tts);
|
||||
/* Important: speech route must be added before the upload limiters */
|
||||
router.use('/speech', speech);
|
||||
|
||||
const upload = await createMulterInstance();
|
||||
const { fileUploadIpLimiter, fileUploadUserLimiter } = createFileLimiters();
|
||||
|
|
|
|||
10
api/server/routes/files/speech/customConfigSpeech.js
Normal file
10
api/server/routes/files/speech/customConfigSpeech.js
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
const express = require('express');
|
||||
const router = express.Router();
|
||||
|
||||
const { getCustomConfigSpeech } = require('~/server/services/Files/Audio');
|
||||
|
||||
router.get('/get', async (req, res) => {
|
||||
await getCustomConfigSpeech(req, res);
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
17
api/server/routes/files/speech/index.js
Normal file
17
api/server/routes/files/speech/index.js
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
const express = require('express');
|
||||
const { createTTSLimiters, createSTTLimiters } = require('~/server/middleware');
|
||||
|
||||
const stt = require('./stt');
|
||||
const tts = require('./tts');
|
||||
const customConfigSpeech = require('./customConfigSpeech');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const { sttIpLimiter, sttUserLimiter } = createSTTLimiters();
|
||||
const { ttsIpLimiter, ttsUserLimiter } = createTTSLimiters();
|
||||
router.use('/stt', sttIpLimiter, sttUserLimiter, stt);
|
||||
router.use('/tts', ttsIpLimiter, ttsUserLimiter, tts);
|
||||
|
||||
router.use('/config', customConfigSpeech);
|
||||
|
||||
module.exports = router;
|
||||
|
|
@ -19,6 +19,8 @@ const assistants = require('./assistants');
|
|||
const files = require('./files');
|
||||
const staticRoute = require('./static');
|
||||
const share = require('./share');
|
||||
const categories = require('./categories');
|
||||
const roles = require('./roles');
|
||||
|
||||
module.exports = {
|
||||
search,
|
||||
|
|
@ -42,4 +44,6 @@ module.exports = {
|
|||
files,
|
||||
staticRoute,
|
||||
share,
|
||||
categories,
|
||||
roles,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,14 +1,235 @@
|
|||
const express = require('express');
|
||||
const { PermissionTypes, Permissions, SystemRoles } = require('librechat-data-provider');
|
||||
const {
|
||||
getPrompt,
|
||||
getPrompts,
|
||||
savePrompt,
|
||||
deletePrompt,
|
||||
getPromptGroup,
|
||||
getPromptGroups,
|
||||
updatePromptGroup,
|
||||
deletePromptGroup,
|
||||
createPromptGroup,
|
||||
getAllPromptGroups,
|
||||
// updatePromptLabels,
|
||||
makePromptProduction,
|
||||
} = require('~/models/Prompt');
|
||||
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
const { getPrompts } = require('../../models/Prompt');
|
||||
|
||||
const checkPromptAccess = generateCheckAccess(PermissionTypes.PROMPTS, [Permissions.USE]);
|
||||
const checkPromptCreate = generateCheckAccess(PermissionTypes.PROMPTS, [
|
||||
Permissions.USE,
|
||||
Permissions.CREATE,
|
||||
]);
|
||||
const checkGlobalPromptShare = generateCheckAccess(
|
||||
PermissionTypes.PROMPTS,
|
||||
[Permissions.USE, Permissions.CREATE],
|
||||
{
|
||||
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
|
||||
},
|
||||
);
|
||||
|
||||
router.use(requireJwtAuth);
|
||||
router.use(checkPromptAccess);
|
||||
|
||||
/**
|
||||
* Route to get single prompt group by its ID
|
||||
* GET /groups/:groupId
|
||||
*/
|
||||
router.get('/groups/:groupId', async (req, res) => {
|
||||
let groupId = req.params.groupId;
|
||||
const author = req.user.id;
|
||||
|
||||
const query = {
|
||||
_id: groupId,
|
||||
$or: [{ projectIds: { $exists: true, $ne: [], $not: { $size: 0 } } }, { author }],
|
||||
};
|
||||
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
delete query.$or;
|
||||
}
|
||||
|
||||
try {
|
||||
const group = await getPromptGroup(query);
|
||||
|
||||
if (!group) {
|
||||
return res.status(404).send({ message: 'Prompt group not found' });
|
||||
}
|
||||
|
||||
res.status(200).send(group);
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt group', error);
|
||||
res.status(500).send({ message: 'Error getting prompt group' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Route to fetch all prompt groups
|
||||
* GET /groups
|
||||
*/
|
||||
router.get('/all', async (req, res) => {
|
||||
try {
|
||||
const groups = await getAllPromptGroups(req, {
|
||||
author: req.user._id,
|
||||
});
|
||||
res.status(200).send(groups);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error getting prompt groups' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Route to fetch paginated prompt groups with filters
|
||||
* GET /groups
|
||||
*/
|
||||
router.get('/groups', async (req, res) => {
|
||||
try {
|
||||
const filter = req.query;
|
||||
/* Note: The aggregation requires an ObjectId */
|
||||
filter.author = req.user._id;
|
||||
const groups = await getPromptGroups(req, filter);
|
||||
res.status(200).send(groups);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error getting prompt groups' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Updates or creates a prompt + promptGroup
|
||||
* @param {object} req
|
||||
* @param {TCreatePrompt} req.body
|
||||
* @param {Express.Response} res
|
||||
*/
|
||||
const createPrompt = async (req, res) => {
|
||||
try {
|
||||
const { prompt, group } = req.body;
|
||||
if (!prompt) {
|
||||
return res.status(400).send({ error: 'Prompt is required' });
|
||||
}
|
||||
|
||||
const saveData = {
|
||||
prompt,
|
||||
group,
|
||||
author: req.user.id,
|
||||
authorName: req.user.name,
|
||||
};
|
||||
|
||||
/** @type {TCreatePromptResponse} */
|
||||
let result;
|
||||
if (group && group.name) {
|
||||
result = await createPromptGroup(saveData);
|
||||
} else {
|
||||
result = await savePrompt(saveData);
|
||||
}
|
||||
res.status(200).send(result);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error saving prompt' });
|
||||
}
|
||||
};
|
||||
|
||||
router.post('/', createPrompt);
|
||||
|
||||
/**
|
||||
* Updates a prompt group
|
||||
* @param {object} req
|
||||
* @param {object} req.params - The request parameters
|
||||
* @param {string} req.params.groupId - The group ID
|
||||
* @param {TUpdatePromptGroupPayload} req.body - The request body
|
||||
* @param {Express.Response} res
|
||||
*/
|
||||
const patchPromptGroup = async (req, res) => {
|
||||
try {
|
||||
const { groupId } = req.params;
|
||||
const author = req.user.id;
|
||||
const filter = { _id: groupId, author };
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
delete filter.author;
|
||||
}
|
||||
const promptGroup = await updatePromptGroup(filter, req.body);
|
||||
res.status(200).send(promptGroup);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error updating prompt group' });
|
||||
}
|
||||
};
|
||||
|
||||
router.patch('/groups/:groupId', checkGlobalPromptShare, patchPromptGroup);
|
||||
|
||||
router.patch('/:promptId/tags/production', checkPromptCreate, async (req, res) => {
|
||||
try {
|
||||
const { promptId } = req.params;
|
||||
const result = await makePromptProduction(promptId);
|
||||
res.status(200).send(result);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error updating prompt production' });
|
||||
}
|
||||
});
|
||||
|
||||
router.get('/:promptId', async (req, res) => {
|
||||
const { promptId } = req.params;
|
||||
const author = req.user.id;
|
||||
const query = { _id: promptId, author };
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
}
|
||||
const prompt = await getPrompt(query);
|
||||
res.status(200).send(prompt);
|
||||
});
|
||||
|
||||
router.get('/', async (req, res) => {
|
||||
let filter = {};
|
||||
// const { search } = req.body.arg;
|
||||
// if (!!search) {
|
||||
// filter = { conversationId };
|
||||
// }
|
||||
res.status(200).send(await getPrompts(filter));
|
||||
try {
|
||||
const author = req.user.id;
|
||||
const { groupId } = req.query;
|
||||
const query = { groupId, author };
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
}
|
||||
const prompts = await getPrompts(query);
|
||||
res.status(200).send(prompts);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error getting prompts' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Deletes a prompt
|
||||
*
|
||||
* @param {Express.Request} req - The request object.
|
||||
* @param {TDeletePromptVariables} req.params - The request parameters
|
||||
* @param {import('mongoose').ObjectId} req.params.promptId - The prompt ID
|
||||
* @param {Express.Response} res - The response object.
|
||||
* @return {TDeletePromptResponse} A promise that resolves when the prompt is deleted.
|
||||
*/
|
||||
const deletePromptController = async (req, res) => {
|
||||
try {
|
||||
const { promptId } = req.params;
|
||||
const { groupId } = req.query;
|
||||
const author = req.user.id;
|
||||
const query = { promptId, groupId, author, role: req.user.role };
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
}
|
||||
const result = await deletePrompt(query);
|
||||
res.status(200).send(result);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error deleting prompt' });
|
||||
}
|
||||
};
|
||||
|
||||
router.delete('/:promptId', checkPromptCreate, deletePromptController);
|
||||
|
||||
router.delete('/groups/:groupId', checkPromptCreate, async (req, res) => {
|
||||
const { groupId } = req.params;
|
||||
res.status(200).send(await deletePromptGroup(groupId));
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
72
api/server/routes/roles.js
Normal file
72
api/server/routes/roles.js
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
const express = require('express');
|
||||
const {
|
||||
promptPermissionsSchema,
|
||||
PermissionTypes,
|
||||
roleDefaults,
|
||||
SystemRoles,
|
||||
} = require('librechat-data-provider');
|
||||
const { checkAdmin, requireJwtAuth } = require('~/server/middleware');
|
||||
const { updateRoleByName, getRoleByName } = require('~/models/Role');
|
||||
|
||||
const router = express.Router();
|
||||
router.use(requireJwtAuth);
|
||||
|
||||
/**
|
||||
* GET /api/roles/:roleName
|
||||
* Get a specific role by name
|
||||
*/
|
||||
router.get('/:roleName', async (req, res) => {
|
||||
const { roleName: _r } = req.params;
|
||||
// TODO: TEMP, use a better parsing for roleName
|
||||
const roleName = _r.toUpperCase();
|
||||
|
||||
if (req.user.role !== SystemRoles.ADMIN && !roleDefaults[roleName]) {
|
||||
return res.status(403).send({ message: 'Unauthorized' });
|
||||
}
|
||||
|
||||
try {
|
||||
const role = await getRoleByName(roleName, '-_id -__v');
|
||||
if (!role) {
|
||||
return res.status(404).send({ message: 'Role not found' });
|
||||
}
|
||||
|
||||
res.status(200).send(role);
|
||||
} catch (error) {
|
||||
return res.status(500).send({ message: 'Failed to retrieve role', error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* PUT /api/roles/:roleName/prompts
|
||||
* Update prompt permissions for a specific role
|
||||
*/
|
||||
router.put('/:roleName/prompts', checkAdmin, async (req, res) => {
|
||||
const { roleName: _r } = req.params;
|
||||
// TODO: TEMP, use a better parsing for roleName
|
||||
const roleName = _r.toUpperCase();
|
||||
/** @type {TRole['PROMPTS']} */
|
||||
const updates = req.body;
|
||||
|
||||
try {
|
||||
const parsedUpdates = promptPermissionsSchema.partial().parse(updates);
|
||||
|
||||
const role = await getRoleByName(roleName);
|
||||
if (!role) {
|
||||
return res.status(404).send({ message: 'Role not found' });
|
||||
}
|
||||
|
||||
const mergedUpdates = {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
...role[PermissionTypes.PROMPTS],
|
||||
...parsedUpdates,
|
||||
},
|
||||
};
|
||||
|
||||
const updatedRole = await updateRoleByName(roleName, mergedUpdates);
|
||||
res.status(200).send(updatedRole);
|
||||
} catch (error) {
|
||||
return res.status(400).send({ message: 'Invalid prompt permissions.', error: error.errors });
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
|
|
@ -25,12 +25,16 @@ if (allowSharedLinks) {
|
|||
'/:shareId',
|
||||
allowSharedLinksPublic ? (req, res, next) => next() : requireJwtAuth,
|
||||
async (req, res) => {
|
||||
const share = await getSharedMessages(req.params.shareId);
|
||||
try {
|
||||
const share = await getSharedMessages(req.params.shareId);
|
||||
|
||||
if (share) {
|
||||
res.status(200).json(share);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
if (share) {
|
||||
res.status(200).json(share);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
}
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Error getting shared messages' });
|
||||
}
|
||||
},
|
||||
);
|
||||
|
|
@ -40,47 +44,63 @@ if (allowSharedLinks) {
|
|||
* Shared links
|
||||
*/
|
||||
router.get('/', requireJwtAuth, async (req, res) => {
|
||||
let pageNumber = req.query.pageNumber || 1;
|
||||
pageNumber = parseInt(pageNumber, 10);
|
||||
try {
|
||||
let pageNumber = req.query.pageNumber || 1;
|
||||
pageNumber = parseInt(pageNumber, 10);
|
||||
|
||||
if (isNaN(pageNumber) || pageNumber < 1) {
|
||||
return res.status(400).json({ error: 'Invalid page number' });
|
||||
if (isNaN(pageNumber) || pageNumber < 1) {
|
||||
return res.status(400).json({ error: 'Invalid page number' });
|
||||
}
|
||||
|
||||
let pageSize = req.query.pageSize || 25;
|
||||
pageSize = parseInt(pageSize, 10);
|
||||
|
||||
if (isNaN(pageSize) || pageSize < 1) {
|
||||
return res.status(400).json({ error: 'Invalid page size' });
|
||||
}
|
||||
const isPublic = req.query.isPublic === 'true';
|
||||
res.status(200).send(await getSharedLinks(req.user.id, pageNumber, pageSize, isPublic));
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Error getting shared links' });
|
||||
}
|
||||
|
||||
let pageSize = req.query.pageSize || 25;
|
||||
pageSize = parseInt(pageSize, 10);
|
||||
|
||||
if (isNaN(pageSize) || pageSize < 1) {
|
||||
return res.status(400).json({ error: 'Invalid page size' });
|
||||
}
|
||||
const isPublic = req.query.isPublic === 'true';
|
||||
res.status(200).send(await getSharedLinks(req.user.id, pageNumber, pageSize, isPublic));
|
||||
});
|
||||
|
||||
router.post('/', requireJwtAuth, async (req, res) => {
|
||||
const created = await createSharedLink(req.user.id, req.body);
|
||||
if (created) {
|
||||
res.status(200).json(created);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
try {
|
||||
const created = await createSharedLink(req.user.id, req.body);
|
||||
if (created) {
|
||||
res.status(200).json(created);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
}
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Error creating shared link' });
|
||||
}
|
||||
});
|
||||
|
||||
router.patch('/', requireJwtAuth, async (req, res) => {
|
||||
const updated = await updateSharedLink(req.user.id, req.body);
|
||||
if (updated) {
|
||||
res.status(200).json(updated);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
try {
|
||||
const updated = await updateSharedLink(req.user.id, req.body);
|
||||
if (updated) {
|
||||
res.status(200).json(updated);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
}
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Error updating shared link' });
|
||||
}
|
||||
});
|
||||
|
||||
router.delete('/:shareId', requireJwtAuth, async (req, res) => {
|
||||
const deleted = await deleteSharedLink(req.user.id, { shareId: req.params.shareId });
|
||||
if (deleted) {
|
||||
res.status(200).json(deleted);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
try {
|
||||
const deleted = await deleteSharedLink(req.user.id, { shareId: req.params.shareId });
|
||||
if (deleted) {
|
||||
res.status(200).json(deleted);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
}
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Error deleting shared link' });
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ const handleRateLimits = require('./Config/handleRateLimits');
|
|||
const { loadDefaultInterface } = require('./start/interface');
|
||||
const { azureConfigSetup } = require('./start/azureOpenAI');
|
||||
const { loadAndFormatTools } = require('./ToolService');
|
||||
const { initializeRoles } = require('~/models/Role');
|
||||
const paths = require('~/config/paths');
|
||||
|
||||
/**
|
||||
|
|
@ -16,6 +17,7 @@ const paths = require('~/config/paths');
|
|||
* @param {Express.Application} app - The Express application object.
|
||||
*/
|
||||
const AppService = async (app) => {
|
||||
await initializeRoles();
|
||||
/** @type {TCustomConfig}*/
|
||||
const config = (await loadCustomConfig()) ?? {};
|
||||
const configDefaults = getConfigDefaults();
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ jest.mock('./Config/loadCustomConfig', () => {
|
|||
jest.mock('./Files/Firebase/initialize', () => ({
|
||||
initializeFirebase: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models/Role', () => ({
|
||||
initializeRoles: jest.fn(),
|
||||
}));
|
||||
jest.mock('./ToolService', () => ({
|
||||
loadAndFormatTools: jest.fn().mockReturnValue({
|
||||
ExampleTool: {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const crypto = require('crypto');
|
||||
const bcrypt = require('bcryptjs');
|
||||
const { errorsToString } = require('librechat-data-provider');
|
||||
const { SystemRoles, errorsToString } = require('librechat-data-provider');
|
||||
const {
|
||||
findUser,
|
||||
countUsers,
|
||||
|
|
@ -62,7 +62,9 @@ const sendVerificationEmail = async (user) => {
|
|||
let verifyToken = crypto.randomBytes(32).toString('hex');
|
||||
const hash = bcrypt.hashSync(verifyToken, 10);
|
||||
|
||||
const verificationLink = `${domains.client}/verify?token=${verifyToken}&email=${encodeURIComponent(user.email)}`;
|
||||
const verificationLink = `${
|
||||
domains.client
|
||||
}/verify?token=${verifyToken}&email=${encodeURIComponent(user.email)}`;
|
||||
await sendEmail({
|
||||
email: user.email,
|
||||
subject: 'Verify your email',
|
||||
|
|
@ -119,9 +121,10 @@ const verifyEmail = async (req) => {
|
|||
/**
|
||||
* Register a new user.
|
||||
* @param {MongoUser} user <email, password, name, username>
|
||||
* @param {Partial<MongoUser>} [additionalData={}]
|
||||
* @returns {Promise<{status: number, message: string, user?: MongoUser}>}
|
||||
*/
|
||||
const registerUser = async (user) => {
|
||||
const registerUser = async (user, additionalData = {}) => {
|
||||
const { error } = registerSchema.safeParse(user);
|
||||
if (error) {
|
||||
const errorMessage = errorsToString(error.errors);
|
||||
|
|
@ -169,13 +172,15 @@ const registerUser = async (user) => {
|
|||
username,
|
||||
name,
|
||||
avatar: null,
|
||||
role: isFirstRegisteredUser ? 'ADMIN' : 'USER',
|
||||
role: isFirstRegisteredUser ? SystemRoles.ADMIN : SystemRoles.USER,
|
||||
password: bcrypt.hashSync(password, salt),
|
||||
...additionalData,
|
||||
};
|
||||
|
||||
const emailEnabled = checkEmailConfig();
|
||||
newUserId = await createUser(newUserData, false);
|
||||
if (emailEnabled) {
|
||||
const newUser = await createUser(newUserData, false, true);
|
||||
newUserId = newUser._id;
|
||||
if (emailEnabled && !newUser.emailVerified) {
|
||||
await sendVerificationEmail({
|
||||
_id: newUserId,
|
||||
email,
|
||||
|
|
@ -363,7 +368,9 @@ const resendVerificationEmail = async (req) => {
|
|||
let verifyToken = crypto.randomBytes(32).toString('hex');
|
||||
const hash = bcrypt.hashSync(verifyToken, 10);
|
||||
|
||||
const verificationLink = `${domains.client}/verify?token=${verifyToken}&email=${encodeURIComponent(user.email)}`;
|
||||
const verificationLink = `${
|
||||
domains.client
|
||||
}/verify?token=${verifyToken}&email=${encodeURIComponent(user.email)}`;
|
||||
|
||||
await sendEmail({
|
||||
email: user.email,
|
||||
|
|
|
|||
|
|
@ -76,8 +76,28 @@ Please specify a correct \`imageOutputType\` value (case-sensitive).
|
|||
);
|
||||
}
|
||||
if (!result.success) {
|
||||
i === 0 && logger.error(`Invalid custom config file at ${configPath}`, result.error);
|
||||
i === 0 && i++;
|
||||
let errorMessage = `Invalid custom config file at ${configPath}:
|
||||
${JSON.stringify(result.error, null, 2)}`;
|
||||
|
||||
if (i === 0) {
|
||||
logger.error(errorMessage);
|
||||
const speechError = result.error.errors.find(
|
||||
(err) =>
|
||||
err.code === 'unrecognized_keys' &&
|
||||
(err.message?.includes('stt') || err.message?.includes('tts')),
|
||||
);
|
||||
|
||||
if (speechError) {
|
||||
logger.warn(`
|
||||
The Speech-to-text and Text-to-speech configuration format has recently changed.
|
||||
If you're getting this error, please refer to the latest documentation:
|
||||
|
||||
https://www.librechat.ai/docs/configuration/stt_tts`);
|
||||
}
|
||||
|
||||
i++;
|
||||
}
|
||||
|
||||
return null;
|
||||
} else {
|
||||
logger.info('Custom config file loaded:');
|
||||
|
|
|
|||
58
api/server/services/Endpoints/google/addTitle.js
Normal file
58
api/server/services/Endpoints/google/addTitle.js
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { saveConvo } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
const initializeClient = require('./initializeClient');
|
||||
|
||||
const addTitle = async (req, { text, response, client }) => {
|
||||
const { TITLE_CONVO = 'true' } = process.env ?? {};
|
||||
if (!isEnabled(TITLE_CONVO)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (client.options.titleConvo === false) {
|
||||
return;
|
||||
}
|
||||
|
||||
const DEFAULT_TITLE_MODEL = 'gemini-pro';
|
||||
const { GOOGLE_TITLE_MODEL } = process.env ?? {};
|
||||
|
||||
let model = GOOGLE_TITLE_MODEL ?? DEFAULT_TITLE_MODEL;
|
||||
|
||||
if (GOOGLE_TITLE_MODEL === Constants.CURRENT_MODEL) {
|
||||
model = client.options?.modelOptions.model;
|
||||
|
||||
if (client.isVisionModel) {
|
||||
logger.warn(
|
||||
`current_model was specified for Google title request, but the model ${model} cannot process a text-only conversation. Falling back to ${DEFAULT_TITLE_MODEL}`,
|
||||
);
|
||||
|
||||
model = DEFAULT_TITLE_MODEL;
|
||||
}
|
||||
}
|
||||
|
||||
const titleEndpointOptions = {
|
||||
...client.options,
|
||||
modelOptions: { ...client.options?.modelOptions, model: model },
|
||||
attachments: undefined, // After a response, this is set to an empty array which results in an error during setOptions
|
||||
};
|
||||
|
||||
const { client: titleClient } = await initializeClient({
|
||||
req,
|
||||
res: response,
|
||||
endpointOption: titleEndpointOptions,
|
||||
});
|
||||
|
||||
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
|
||||
const key = `${req.user.id}-${response.conversationId}`;
|
||||
|
||||
const title = await titleClient.titleConvo({ text, responseText: response?.text });
|
||||
await titleCache.set(key, title, 120000);
|
||||
await saveConvo(req.user.id, {
|
||||
conversationId: response.conversationId,
|
||||
title,
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = addTitle;
|
||||
|
|
@ -1,8 +1,9 @@
|
|||
const addTitle = require('./addTitle');
|
||||
const buildOptions = require('./buildOptions');
|
||||
const initializeClient = require('./initializeClient');
|
||||
|
||||
module.exports = {
|
||||
// addTitle, // todo
|
||||
addTitle,
|
||||
buildOptions,
|
||||
initializeClient,
|
||||
};
|
||||
|
|
|
|||
52
api/server/services/Files/Audio/getCustomConfigSpeech.js
Normal file
52
api/server/services/Files/Audio/getCustomConfigSpeech.js
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
|
||||
/**
|
||||
* This function retrieves the speechTab settings from the custom configuration
|
||||
* It first fetches the custom configuration
|
||||
* Then, it checks if the custom configuration and the speechTab schema exist
|
||||
* If they do, it sends the speechTab settings as a JSON response
|
||||
* If they don't, it throws an error
|
||||
*
|
||||
* @param {Object} req - The request object
|
||||
* @param {Object} res - The response object
|
||||
* @returns {Promise<void>}
|
||||
* @throws {Error} - If the custom configuration or the speechTab schema is missing, an error is thrown
|
||||
*/
|
||||
async function getCustomConfigSpeech(req, res) {
|
||||
try {
|
||||
const customConfig = await getCustomConfig();
|
||||
|
||||
if (!customConfig || !customConfig.speech?.speechTab) {
|
||||
throw new Error('Configuration or speechTab schema is missing');
|
||||
}
|
||||
|
||||
const ttsSchema = customConfig.speech?.speechTab;
|
||||
let settings = {};
|
||||
|
||||
if (ttsSchema.advancedMode !== undefined) {
|
||||
settings.advancedMode = ttsSchema.advancedMode;
|
||||
}
|
||||
|
||||
if (ttsSchema.speechToText) {
|
||||
for (const key in ttsSchema.speechToText) {
|
||||
if (ttsSchema.speechToText[key] !== undefined) {
|
||||
settings[key] = ttsSchema.speechToText[key];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (ttsSchema.textToSpeech) {
|
||||
for (const key in ttsSchema.textToSpeech) {
|
||||
if (ttsSchema.textToSpeech[key] !== undefined) {
|
||||
settings[key] = ttsSchema.textToSpeech[key];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res.status(200).send(settings);
|
||||
} catch (error) {
|
||||
res.status(200).send();
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = getCustomConfigSpeech;
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
const { logger } = require('~/config');
|
||||
const { TTSProviders } = require('librechat-data-provider');
|
||||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
const { getProvider } = require('./textToSpeech');
|
||||
|
||||
|
|
@ -16,22 +16,25 @@ async function getVoices(req, res) {
|
|||
try {
|
||||
const customConfig = await getCustomConfig();
|
||||
|
||||
if (!customConfig || !customConfig?.tts) {
|
||||
if (!customConfig || !customConfig?.speech?.tts) {
|
||||
throw new Error('Configuration or TTS schema is missing');
|
||||
}
|
||||
|
||||
const ttsSchema = customConfig?.tts;
|
||||
const ttsSchema = customConfig?.speech?.tts;
|
||||
const provider = getProvider(ttsSchema);
|
||||
let voices;
|
||||
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
case TTSProviders.OPENAI:
|
||||
voices = ttsSchema.openai?.voices;
|
||||
break;
|
||||
case 'elevenlabs':
|
||||
case TTSProviders.AZURE_OPENAI:
|
||||
voices = ttsSchema.azureOpenAI?.voices;
|
||||
break;
|
||||
case TTSProviders.ELEVENLABS:
|
||||
voices = ttsSchema.elevenlabs?.voices;
|
||||
break;
|
||||
case 'localai':
|
||||
case TTSProviders.LOCALAI:
|
||||
voices = ttsSchema.localai?.voices;
|
||||
break;
|
||||
default:
|
||||
|
|
@ -40,8 +43,7 @@ async function getVoices(req, res) {
|
|||
|
||||
res.json(voices);
|
||||
} catch (error) {
|
||||
logger.error(`Failed to get voices: ${error.message}`);
|
||||
res.status(500).json({ error: 'Failed to get voices' });
|
||||
res.status(500).json({ error: `Failed to get voices: ${error.message}` });
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
const getVoices = require('./getVoices');
|
||||
const getCustomConfigSpeech = require('./getCustomConfigSpeech');
|
||||
const textToSpeech = require('./textToSpeech');
|
||||
const speechToText = require('./speechToText');
|
||||
const { updateTokenWebsocket } = require('./webSocket');
|
||||
|
||||
module.exports = {
|
||||
getVoices,
|
||||
getCustomConfigSpeech,
|
||||
speechToText,
|
||||
...textToSpeech,
|
||||
updateTokenWebsocket,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
const axios = require('axios');
|
||||
const { Readable } = require('stream');
|
||||
const { logger } = require('~/config');
|
||||
const axios = require('axios');
|
||||
const { extractEnvVariable, STTProviders } = require('librechat-data-provider');
|
||||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
const { extractEnvVariable } = require('librechat-data-provider');
|
||||
const { genAzureEndpoint } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Handle the response from the STT API
|
||||
|
|
@ -24,12 +25,34 @@ async function handleResponse(response) {
|
|||
return response.data.text.trim();
|
||||
}
|
||||
|
||||
function getProvider(sttSchema) {
|
||||
if (sttSchema.openai) {
|
||||
return 'openai';
|
||||
/**
|
||||
* getProviderSchema function
|
||||
* This function takes the customConfig object and returns the name of the provider and its schema
|
||||
* If more than one provider is set or no provider is set, it throws an error
|
||||
*
|
||||
* @param {Object} customConfig - The custom configuration containing the STT schema
|
||||
* @returns {Promise<[string, Object]>} The name of the provider and its schema
|
||||
* @throws {Error} Throws an error if multiple providers are set or no provider is set
|
||||
*/
|
||||
async function getProviderSchema(customConfig) {
|
||||
const sttSchema = customConfig.speech.stt;
|
||||
|
||||
if (!sttSchema) {
|
||||
throw new Error(`No STT schema is set. Did you configure STT in the custom config (librechat.yaml)?
|
||||
|
||||
https://www.librechat.ai/docs/configuration/stt_tts#stt`);
|
||||
}
|
||||
|
||||
throw new Error('Invalid provider');
|
||||
const providers = Object.entries(sttSchema).filter(([, value]) => Object.keys(value).length > 0);
|
||||
|
||||
if (providers.length > 1) {
|
||||
throw new Error('Multiple providers are set. Please set only one provider.');
|
||||
} else if (providers.length === 0) {
|
||||
throw new Error('No provider is set. Please set a provider.');
|
||||
} else {
|
||||
const provider = providers[0][0];
|
||||
return [provider, sttSchema[provider]];
|
||||
}
|
||||
}
|
||||
|
||||
function removeUndefined(obj) {
|
||||
|
|
@ -83,72 +106,63 @@ function openAIProvider(sttSchema, audioReadStream) {
|
|||
}
|
||||
|
||||
/**
|
||||
* This function prepares the necessary data and headers for making a request to the Azure API
|
||||
* It uses the provided request and audio stream to create the request
|
||||
* Prepares the necessary data and headers for making a request to the Azure API.
|
||||
* It uses the provided Speech-to-Text (STT) schema and audio file to create the request.
|
||||
*
|
||||
* @param {Object} req - The request object, which should contain the endpoint in its body
|
||||
* @param {Stream} audioReadStream - The audio data to be transcribed
|
||||
* @param {Object} sttSchema - The STT schema object, which should contain instanceName, deploymentName, apiVersion, and apiKey.
|
||||
* @param {Buffer} audioBuffer - The audio data to be transcribed
|
||||
* @param {Object} audioFile - The audio file object, which should contain originalname, mimetype, and size.
|
||||
*
|
||||
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
|
||||
* If an error occurs, it returns an array with three null values and logs the error with logger
|
||||
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request.
|
||||
* If an error occurs, it logs the error with logger and returns an array with three null values.
|
||||
*/
|
||||
function azureProvider(req, audioReadStream) {
|
||||
function azureOpenAIProvider(sttSchema, audioBuffer, audioFile) {
|
||||
try {
|
||||
const { endpoint } = req.body;
|
||||
const azureConfig = req.app.locals[endpoint];
|
||||
const instanceName = sttSchema?.instanceName;
|
||||
const deploymentName = sttSchema?.deploymentName;
|
||||
const apiVersion = sttSchema?.apiVersion;
|
||||
|
||||
if (!azureConfig) {
|
||||
throw new Error(`No configuration found for endpoint: ${endpoint}`);
|
||||
const url =
|
||||
genAzureEndpoint({
|
||||
azureOpenAIApiInstanceName: instanceName,
|
||||
azureOpenAIApiDeploymentName: deploymentName,
|
||||
}) +
|
||||
'/audio/transcriptions?api-version=' +
|
||||
apiVersion;
|
||||
|
||||
const apiKey = sttSchema.apiKey ? extractEnvVariable(sttSchema.apiKey) : '';
|
||||
|
||||
if (audioBuffer.byteLength > 25 * 1024 * 1024) {
|
||||
throw new Error('The audio file size exceeds the limit of 25MB');
|
||||
}
|
||||
const acceptedFormats = ['flac', 'mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'ogg', 'wav', 'webm'];
|
||||
const fileFormat = audioFile.mimetype.split('/')[1];
|
||||
if (!acceptedFormats.includes(fileFormat)) {
|
||||
throw new Error(`The audio file format ${fileFormat} is not accepted`);
|
||||
}
|
||||
|
||||
const { apiKey, instanceName, whisperModel, apiVersion } = Object.entries(
|
||||
azureConfig.groupMap,
|
||||
).reduce((acc, [, value]) => {
|
||||
if (acc) {
|
||||
return acc;
|
||||
}
|
||||
const formData = new FormData();
|
||||
|
||||
const whisperKey = Object.keys(value.models).find((modelKey) =>
|
||||
modelKey.startsWith('whisper'),
|
||||
);
|
||||
const audioBlob = new Blob([audioBuffer], { type: audioFile.mimetype });
|
||||
|
||||
if (whisperKey) {
|
||||
return {
|
||||
apiVersion: value.version,
|
||||
apiKey: value.apiKey,
|
||||
instanceName: value.instanceName,
|
||||
whisperModel: value.models[whisperKey]['deploymentName'],
|
||||
};
|
||||
}
|
||||
formData.append('file', audioBlob, audioFile.originalname);
|
||||
|
||||
return null;
|
||||
}, null);
|
||||
let data = formData;
|
||||
|
||||
if (!apiKey || !instanceName || !whisperModel || !apiVersion) {
|
||||
throw new Error('Required Azure configuration values are missing');
|
||||
}
|
||||
|
||||
const baseURL = `https://${instanceName}.openai.azure.com`;
|
||||
|
||||
const url = `${baseURL}/openai/deployments/${whisperModel}/audio/transcriptions?api-version=${apiVersion}`;
|
||||
|
||||
let data = {
|
||||
file: audioReadStream,
|
||||
filename: 'audio.wav',
|
||||
contentType: 'audio/wav',
|
||||
knownLength: audioReadStream.length,
|
||||
};
|
||||
|
||||
const headers = {
|
||||
...data.getHeaders(),
|
||||
let headers = {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
'api-key': apiKey,
|
||||
};
|
||||
|
||||
[headers].forEach(removeUndefined);
|
||||
|
||||
if (apiKey) {
|
||||
headers['api-key'] = apiKey;
|
||||
}
|
||||
|
||||
return [url, data, headers];
|
||||
} catch (error) {
|
||||
logger.error('An error occurred while preparing the Azure API STT request: ', error);
|
||||
return [null, null, null];
|
||||
logger.error('An error occurred while preparing the Azure OpenAI API STT request: ', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -176,16 +190,16 @@ async function speechToText(req, res) {
|
|||
const audioReadStream = Readable.from(audioBuffer);
|
||||
audioReadStream.path = 'audio.wav';
|
||||
|
||||
const provider = getProvider(customConfig.stt);
|
||||
const [provider, sttSchema] = await getProviderSchema(customConfig);
|
||||
|
||||
let [url, data, headers] = [];
|
||||
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
[url, data, headers] = openAIProvider(customConfig.stt, audioReadStream);
|
||||
case STTProviders.OPENAI:
|
||||
[url, data, headers] = openAIProvider(sttSchema, audioReadStream);
|
||||
break;
|
||||
case 'azure':
|
||||
[url, data, headers] = azureProvider(req, audioReadStream);
|
||||
case STTProviders.AZURE_OPENAI:
|
||||
[url, data, headers] = azureOpenAIProvider(sttSchema, audioBuffer, req.file);
|
||||
break;
|
||||
default:
|
||||
throw new Error('Invalid provider');
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
const axios = require('axios');
|
||||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
|
||||
const { extractEnvVariable } = require('librechat-data-provider');
|
||||
const { extractEnvVariable, TTSProviders } = require('librechat-data-provider');
|
||||
const { logger } = require('~/config');
|
||||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
const { genAzureEndpoint } = require('~/utils');
|
||||
const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
|
||||
|
||||
/**
|
||||
* getProvider function
|
||||
|
|
@ -91,6 +92,59 @@ function openAIProvider(ttsSchema, input, voice) {
|
|||
return [url, data, headers];
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates the necessary parameters for making a request to Azure's OpenAI Text-to-Speech API.
|
||||
*
|
||||
* @param {TCustomConfig['tts']['azureOpenAI']} ttsSchema - The TTS schema containing the AzureOpenAI configuration
|
||||
* @param {string} input - The text to be converted to speech
|
||||
* @param {string} voice - The voice to be used for the speech
|
||||
*
|
||||
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
|
||||
* If an error occurs, it throws an error with a message indicating that the selected voice is not available
|
||||
*/
|
||||
function azureOpenAIProvider(ttsSchema, input, voice) {
|
||||
const instanceName = ttsSchema?.instanceName;
|
||||
const deploymentName = ttsSchema?.deploymentName;
|
||||
const apiVersion = ttsSchema?.apiVersion;
|
||||
|
||||
const url =
|
||||
genAzureEndpoint({
|
||||
azureOpenAIApiInstanceName: instanceName,
|
||||
azureOpenAIApiDeploymentName: deploymentName,
|
||||
}) +
|
||||
'/audio/speech?api-version=' +
|
||||
apiVersion;
|
||||
|
||||
const apiKey = ttsSchema.apiKey ? extractEnvVariable(ttsSchema.apiKey) : '';
|
||||
|
||||
if (
|
||||
ttsSchema?.voices &&
|
||||
ttsSchema.voices.length > 0 &&
|
||||
!ttsSchema.voices.includes(voice) &&
|
||||
!ttsSchema.voices.includes('ALL')
|
||||
) {
|
||||
throw new Error(`Voice ${voice} is not available.`);
|
||||
}
|
||||
|
||||
let data = {
|
||||
model: ttsSchema?.model,
|
||||
input,
|
||||
voice: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
|
||||
};
|
||||
|
||||
let headers = {
|
||||
'Content-Type': 'application/json',
|
||||
};
|
||||
|
||||
[data, headers].forEach(removeUndefined);
|
||||
|
||||
if (apiKey) {
|
||||
headers['api-key'] = apiKey;
|
||||
}
|
||||
|
||||
return [url, data, headers];
|
||||
}
|
||||
|
||||
/**
|
||||
* elevenLabsProvider function
|
||||
* This function prepares the necessary data and headers for making a request to the Eleven Labs TTS
|
||||
|
|
@ -191,8 +245,8 @@ function localAIProvider(ttsSchema, input, voice) {
|
|||
* @returns {Promise<[string, TProviderSchema]>}
|
||||
*/
|
||||
async function getProviderSchema(customConfig) {
|
||||
const provider = getProvider(customConfig.tts);
|
||||
return [provider, customConfig.tts[provider]];
|
||||
const provider = getProvider(customConfig.speech.tts);
|
||||
return [provider, customConfig.speech.tts[provider]];
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -225,13 +279,16 @@ async function getVoice(providerSchema, requestVoice) {
|
|||
async function ttsRequest(provider, ttsSchema, { input, voice, stream = true } = { stream: true }) {
|
||||
let [url, data, headers] = [];
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
case TTSProviders.OPENAI:
|
||||
[url, data, headers] = openAIProvider(ttsSchema, input, voice);
|
||||
break;
|
||||
case 'elevenlabs':
|
||||
case TTSProviders.AZURE_OPENAI:
|
||||
[url, data, headers] = azureOpenAIProvider(ttsSchema, input, voice);
|
||||
break;
|
||||
case TTSProviders.ELEVENLABS:
|
||||
[url, data, headers] = elevenLabsProvider(ttsSchema, input, voice, stream);
|
||||
break;
|
||||
case 'localai':
|
||||
case TTSProviders.LOCALAI:
|
||||
[url, data, headers] = localAIProvider(ttsSchema, input, voice);
|
||||
break;
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -1,31 +0,0 @@
|
|||
let token = '';
|
||||
|
||||
function updateTokenWebsocket(newToken) {
|
||||
console.log('Token:', newToken);
|
||||
token = newToken;
|
||||
}
|
||||
|
||||
function sendTextToWebsocket(ws, onDataReceived) {
|
||||
if (token === '[DONE]') {
|
||||
ws.send(' ');
|
||||
return;
|
||||
}
|
||||
|
||||
if (ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(token);
|
||||
|
||||
ws.onmessage = function (event) {
|
||||
console.log('Received:', event.data);
|
||||
if (onDataReceived) {
|
||||
onDataReceived(event.data); // Pass the received data to the callback function
|
||||
}
|
||||
};
|
||||
} else {
|
||||
console.error('WebSocket is not open. Ready state is: ' + ws.readyState);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
updateTokenWebsocket,
|
||||
sendTextToWebsocket,
|
||||
};
|
||||
|
|
@ -427,7 +427,7 @@ class StreamRunManager {
|
|||
|
||||
const toolCallDelta = toolCall[toolCall.type];
|
||||
const progressCallback = this.progressCallbacks.get(stepKey);
|
||||
await progressCallback(toolCallDelta);
|
||||
progressCallback(toolCallDelta);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -73,10 +73,6 @@ function loadAndFormatTools({ directory, adminFilter = [], adminIncluded = [] })
|
|||
continue;
|
||||
}
|
||||
|
||||
if (included.size > 0 && !included.has(file)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let toolInstance = null;
|
||||
try {
|
||||
toolInstance = new ToolClass({ override: true });
|
||||
|
|
@ -92,6 +88,14 @@ function loadAndFormatTools({ directory, adminFilter = [], adminIncluded = [] })
|
|||
continue;
|
||||
}
|
||||
|
||||
if (filter.has(toolInstance.name) && included.size === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (included.size > 0 && !included.has(file) && !included.has(toolInstance.name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const formattedTool = formatToOpenAIAssistantTool(toolInstance);
|
||||
tools.push(formattedTool);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ const getUserKeyExpiry = async ({ userId, name }) => {
|
|||
if (!keyValue) {
|
||||
return { expiresAt: null };
|
||||
}
|
||||
return { expiresAt: keyValue.expiresAt };
|
||||
return { expiresAt: keyValue.expiresAt || 'never' };
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -108,18 +108,23 @@ const getUserKeyExpiry = async ({ userId, name }) => {
|
|||
* @description This function either updates an existing user key or inserts a new one into the database,
|
||||
* after encrypting the provided value. It sets the provided expiry date for the key.
|
||||
*/
|
||||
const updateUserKey = async ({ userId, name, value, expiresAt }) => {
|
||||
const updateUserKey = async ({ userId, name, value, expiresAt = null }) => {
|
||||
const encryptedValue = encrypt(value);
|
||||
return await Key.findOneAndUpdate(
|
||||
{ userId, name },
|
||||
{
|
||||
userId,
|
||||
name,
|
||||
value: encryptedValue,
|
||||
expiresAt: new Date(expiresAt),
|
||||
},
|
||||
{ upsert: true, new: true },
|
||||
).lean();
|
||||
let updateObject = {
|
||||
userId,
|
||||
name,
|
||||
value: encryptedValue,
|
||||
};
|
||||
|
||||
// Only add expiresAt to the update object if it's not null
|
||||
if (expiresAt) {
|
||||
updateObject.expiresAt = new Date(expiresAt);
|
||||
}
|
||||
|
||||
return await Key.findOneAndUpdate({ userId, name }, updateObject, {
|
||||
upsert: true,
|
||||
new: true,
|
||||
}).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ function checkPasswordReset() {
|
|||
|
||||
Please configure email service for secure password reset functionality.
|
||||
|
||||
https://www.librechat.ai/docs/configuration/authentication/password_reset
|
||||
https://www.librechat.ai/docs/configuration/authentication/email
|
||||
|
||||
❗❗❗`,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -12,33 +12,34 @@ const citationRegex = /\[\^\d+?\^]/g;
|
|||
|
||||
const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? text + ' ' : text);
|
||||
|
||||
const base = { message: true, initial: true };
|
||||
const createOnProgress = ({ generation = '', onProgress: _onProgress }) => {
|
||||
let i = 0;
|
||||
let tokens = addSpaceIfNeeded(generation);
|
||||
|
||||
const progressCallback = async (partial, { res, text, bing = false, ...rest }) => {
|
||||
let chunk = partial === text ? '' : partial;
|
||||
tokens += chunk;
|
||||
tokens = tokens.replaceAll('[DONE]', '');
|
||||
const basePayload = Object.assign({}, base, { text: tokens || '' });
|
||||
|
||||
if (bing) {
|
||||
tokens = citeText(tokens, true);
|
||||
const progressCallback = (chunk, { res, ...rest }) => {
|
||||
basePayload.text = basePayload.text + chunk;
|
||||
|
||||
const payload = Object.assign({}, basePayload, rest);
|
||||
sendMessage(res, payload);
|
||||
if (_onProgress) {
|
||||
_onProgress(payload);
|
||||
}
|
||||
if (i === 0) {
|
||||
basePayload.initial = false;
|
||||
}
|
||||
|
||||
const payload = { text: tokens, message: true, initial: i === 0, ...rest };
|
||||
sendMessage(res, { ...payload, text: tokens });
|
||||
_onProgress && _onProgress(payload);
|
||||
i++;
|
||||
};
|
||||
|
||||
const sendIntermediateMessage = (res, payload, extraTokens = '') => {
|
||||
tokens += extraTokens;
|
||||
sendMessage(res, {
|
||||
text: tokens?.length === 0 ? '' : tokens,
|
||||
message: true,
|
||||
initial: i === 0,
|
||||
...payload,
|
||||
});
|
||||
basePayload.text = basePayload.text + extraTokens;
|
||||
const message = Object.assign({}, basePayload, payload);
|
||||
sendMessage(res, message);
|
||||
if (i === 0) {
|
||||
basePayload.initial = false;
|
||||
}
|
||||
i++;
|
||||
};
|
||||
|
||||
|
|
@ -47,7 +48,7 @@ const createOnProgress = ({ generation = '', onProgress: _onProgress }) => {
|
|||
};
|
||||
|
||||
const getPartialText = () => {
|
||||
return tokens;
|
||||
return basePayload.text;
|
||||
};
|
||||
|
||||
return { onProgress, getPartialText, sendIntermediateMessage };
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt');
|
||||
const { getUserById } = require('~/models');
|
||||
const { getUserById, updateUser } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
// JWT strategy
|
||||
|
|
@ -14,6 +15,10 @@ const jwtLogin = async () =>
|
|||
const user = await getUserById(payload?.id, '-password -__v');
|
||||
if (user) {
|
||||
user.id = user._id.toString();
|
||||
if (!user.role) {
|
||||
user.role = SystemRoles.USER;
|
||||
await updateUser(user.id, { role: user.role });
|
||||
}
|
||||
done(null, user);
|
||||
} else {
|
||||
logger.warn('[jwtLogin] JwtStrategy => no user found: ' + payload?.id);
|
||||
|
|
|
|||
|
|
@ -1,17 +1,66 @@
|
|||
const fs = require('fs');
|
||||
const LdapStrategy = require('passport-ldapauth');
|
||||
const { findUser, createUser, updateUser } = require('~/models/userMethods');
|
||||
const fs = require('fs');
|
||||
const logger = require('~/utils/logger');
|
||||
|
||||
const {
|
||||
LDAP_URL,
|
||||
LDAP_BIND_DN,
|
||||
LDAP_BIND_CREDENTIALS,
|
||||
LDAP_USER_SEARCH_BASE,
|
||||
LDAP_SEARCH_FILTER,
|
||||
LDAP_CA_CERT_PATH,
|
||||
LDAP_FULL_NAME,
|
||||
LDAP_ID,
|
||||
LDAP_USERNAME,
|
||||
} = process.env;
|
||||
|
||||
// Check required environment variables
|
||||
if (!LDAP_URL || !LDAP_USER_SEARCH_BASE) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const searchAttributes = [
|
||||
'displayName',
|
||||
'mail',
|
||||
'uid',
|
||||
'cn',
|
||||
'name',
|
||||
'commonname',
|
||||
'givenName',
|
||||
'sn',
|
||||
'sAMAccountName',
|
||||
];
|
||||
|
||||
if (LDAP_FULL_NAME) {
|
||||
searchAttributes.push(...LDAP_FULL_NAME.split(','));
|
||||
}
|
||||
if (LDAP_ID) {
|
||||
searchAttributes.push(LDAP_ID);
|
||||
}
|
||||
if (LDAP_USERNAME) {
|
||||
searchAttributes.push(LDAP_USERNAME);
|
||||
}
|
||||
|
||||
const ldapOptions = {
|
||||
server: {
|
||||
url: process.env.LDAP_URL,
|
||||
bindDN: process.env.LDAP_BIND_DN,
|
||||
bindCredentials: process.env.LDAP_BIND_CREDENTIALS,
|
||||
searchBase: process.env.LDAP_USER_SEARCH_BASE,
|
||||
searchFilter: process.env.LDAP_SEARCH_FILTER || 'mail={{username}}',
|
||||
searchAttributes: ['displayName', 'mail', 'uid', 'cn', 'name', 'commonname', 'givenName', 'sn'],
|
||||
...(process.env.LDAP_CA_CERT_PATH && {
|
||||
tlsOptions: { ca: [fs.readFileSync(process.env.LDAP_CA_CERT_PATH)] },
|
||||
url: LDAP_URL,
|
||||
bindDN: LDAP_BIND_DN,
|
||||
bindCredentials: LDAP_BIND_CREDENTIALS,
|
||||
searchBase: LDAP_USER_SEARCH_BASE,
|
||||
searchFilter: LDAP_SEARCH_FILTER || 'mail={{username}}',
|
||||
searchAttributes: [...new Set(searchAttributes)],
|
||||
...(LDAP_CA_CERT_PATH && {
|
||||
tlsOptions: {
|
||||
ca: (() => {
|
||||
try {
|
||||
return [fs.readFileSync(LDAP_CA_CERT_PATH)];
|
||||
} catch (err) {
|
||||
logger.error('[ldapStrategy]', 'Failed to read CA certificate', err);
|
||||
throw err;
|
||||
}
|
||||
})(),
|
||||
},
|
||||
}),
|
||||
},
|
||||
usernameField: 'email',
|
||||
|
|
@ -23,45 +72,55 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
|
|||
return done(null, false, { message: 'Invalid credentials' });
|
||||
}
|
||||
|
||||
try {
|
||||
const firstName = userinfo.givenName;
|
||||
const familyName = userinfo.surname || userinfo.sn;
|
||||
const fullName =
|
||||
firstName && familyName
|
||||
? `${firstName} ${familyName}`
|
||||
: userinfo.cn ||
|
||||
userinfo.name ||
|
||||
userinfo.commonname ||
|
||||
userinfo.displayName ||
|
||||
userinfo.mail;
|
||||
if (!userinfo.mail) {
|
||||
logger.warn(
|
||||
'[ldapStrategy]',
|
||||
'No email attributes found in userinfo',
|
||||
JSON.stringify(userinfo, null, 2),
|
||||
);
|
||||
return done(null, false, { message: 'Invalid credentials' });
|
||||
}
|
||||
|
||||
try {
|
||||
const ldapId =
|
||||
(LDAP_ID && userinfo[LDAP_ID]) || userinfo.uid || userinfo.sAMAccountName || userinfo.mail;
|
||||
|
||||
let user = await findUser({ ldapId });
|
||||
|
||||
const fullNameAttributes = LDAP_FULL_NAME && LDAP_FULL_NAME.split(',');
|
||||
const fullName =
|
||||
fullNameAttributes && fullNameAttributes.length > 0
|
||||
? fullNameAttributes.map((attr) => userinfo[attr]).join(' ')
|
||||
: userinfo.cn || userinfo.name || userinfo.commonname || userinfo.displayName;
|
||||
|
||||
const username =
|
||||
(LDAP_USERNAME && userinfo[LDAP_USERNAME]) || userinfo.givenName || userinfo.mail;
|
||||
|
||||
const username = userinfo.givenName || userinfo.mail;
|
||||
let user = await findUser({ email: userinfo.mail });
|
||||
if (user && user.provider !== 'ldap') {
|
||||
return done(null, false, { message: 'Invalid credentials' });
|
||||
}
|
||||
if (!user) {
|
||||
user = {
|
||||
provider: 'ldap',
|
||||
ldapId: userinfo.uid,
|
||||
ldapId,
|
||||
username,
|
||||
email: userinfo.mail || '',
|
||||
emailVerified: true,
|
||||
email: userinfo.mail,
|
||||
emailVerified: true, // The ldap server administrator should verify the email
|
||||
name: fullName,
|
||||
};
|
||||
const userId = await createUser(user);
|
||||
user._id = userId;
|
||||
} else {
|
||||
// Users registered in LDAP are assumed to have their user information managed in LDAP,
|
||||
// so update the user information with the values registered in LDAP
|
||||
user.provider = 'ldap';
|
||||
user.ldapId = userinfo.uid;
|
||||
user.ldapId = ldapId;
|
||||
user.email = userinfo.mail;
|
||||
user.username = username;
|
||||
user.name = fullName;
|
||||
}
|
||||
|
||||
user = await updateUser(user._id, user);
|
||||
|
||||
done(null, user);
|
||||
} catch (err) {
|
||||
logger.error('[ldapStrategy]', err);
|
||||
done(err);
|
||||
}
|
||||
});
|
||||
|
|
|
|||
|
|
@ -25,12 +25,18 @@ const downloadImage = async (url, accessToken) => {
|
|||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(url, {
|
||||
const options = {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
options.agent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
const response = await fetch(url, options);
|
||||
|
||||
if (response.ok) {
|
||||
const buffer = await response.buffer();
|
||||
|
|
|
|||
123
api/typedefs.js
123
api/typedefs.js
|
|
@ -248,6 +248,110 @@
|
|||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/** Prompts */
|
||||
/**
|
||||
* @exports TPrompt
|
||||
* @typedef {import('librechat-data-provider').TPrompt} TPrompt
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TPromptGroup
|
||||
* @typedef {import('librechat-data-provider').TPromptGroup} TPromptGroup
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TCreatePrompt
|
||||
* @typedef {import('librechat-data-provider').TCreatePrompt} TCreatePrompt
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TCreatePromptRecord
|
||||
* @typedef {import('librechat-data-provider').TCreatePromptRecord} TCreatePromptRecord
|
||||
* @memberof typedefs
|
||||
*/
|
||||
/**
|
||||
* @exports TCreatePromptResponse
|
||||
* @typedef {import('librechat-data-provider').TCreatePromptResponse} TCreatePromptResponse
|
||||
* @memberof typedefs
|
||||
*/
|
||||
/**
|
||||
* @exports TUpdatePromptGroupResponse
|
||||
* @typedef {import('librechat-data-provider').TUpdatePromptGroupResponse} TUpdatePromptGroupResponse
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TPromptGroupsWithFilterRequest
|
||||
* @typedef {import('librechat-data-provider').TPromptGroupsWithFilterRequest } TPromptGroupsWithFilterRequest
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports PromptGroupListResponse
|
||||
* @typedef {import('librechat-data-provider').PromptGroupListResponse } PromptGroupListResponse
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TGetCategoriesResponse
|
||||
* @typedef {import('librechat-data-provider').TGetCategoriesResponse } TGetCategoriesResponse
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TGetRandomPromptsResponse
|
||||
* @typedef {import('librechat-data-provider').TGetRandomPromptsResponse } TGetRandomPromptsResponse
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TGetRandomPromptsRequest
|
||||
* @typedef {import('librechat-data-provider').TGetRandomPromptsRequest } TGetRandomPromptsRequest
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TUpdatePromptGroupPayload
|
||||
* @typedef {import('librechat-data-provider').TUpdatePromptGroupPayload } TUpdatePromptGroupPayload
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TDeletePromptVariables
|
||||
* @typedef {import('librechat-data-provider').TDeletePromptVariables } TDeletePromptVariables
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TDeletePromptResponse
|
||||
* @typedef {import('librechat-data-provider').TDeletePromptResponse } TDeletePromptResponse
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/* Roles */
|
||||
|
||||
/**
|
||||
* @exports TRole
|
||||
* @typedef {import('librechat-data-provider').TRole } TRole
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports PermissionTypes
|
||||
* @typedef {import('librechat-data-provider').PermissionTypes } PermissionTypes
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports Permissions
|
||||
* @typedef {import('librechat-data-provider').Permissions } Permissions
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/** Assistants */
|
||||
/**
|
||||
* @exports Assistant
|
||||
* @typedef {import('librechat-data-provider').Assistant} Assistant
|
||||
|
|
@ -500,6 +604,18 @@
|
|||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports MongoProject
|
||||
* @typedef {import('~/models/schema/projectSchema.js').MongoProject} MongoProject
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports MongoPromptGroup
|
||||
* @typedef {import('~/models/schema/promptSchema.js').MongoPromptGroup} MongoPromptGroup
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports uploadImageBuffer
|
||||
* @typedef {import('~/server/services/Files/process').uploadImageBuffer} uploadImageBuffer
|
||||
|
|
@ -1326,3 +1442,10 @@
|
|||
* @typedef {import('librechat-data-provider').TForkConvoRequest} TForkConvoRequest
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/** Clients */
|
||||
|
||||
/**
|
||||
* @typedef {Promise<{ message: TMessage, conversation: TConversation }> | undefined} ClientDatabaseSavePromise
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -96,6 +96,7 @@ const anthropicModels = {
|
|||
'claude-3-haiku': 200000,
|
||||
'claude-3-sonnet': 200000,
|
||||
'claude-3-opus': 200000,
|
||||
'claude-3-5-sonnet': 200000,
|
||||
};
|
||||
|
||||
const aggregateModels = { ...openAIModels, ...googleModels, ...anthropicModels, ...cohereModels };
|
||||
|
|
|
|||
|
|
@ -124,12 +124,29 @@ describe('getModelMaxTokens', () => {
|
|||
'claude-1-100k',
|
||||
'claude-instant-1',
|
||||
'claude-instant-1-100k',
|
||||
'claude-3-haiku',
|
||||
'claude-3-sonnet',
|
||||
'claude-3-opus',
|
||||
'claude-3-5-sonnet',
|
||||
];
|
||||
|
||||
const claudeMaxTokens = maxTokensMap[EModelEndpoint.anthropic]['claude-'];
|
||||
const claude21MaxTokens = maxTokensMap[EModelEndpoint.anthropic]['claude-2.1'];
|
||||
const maxTokens = {
|
||||
'claude-': maxTokensMap[EModelEndpoint.anthropic]['claude-'],
|
||||
'claude-2.1': maxTokensMap[EModelEndpoint.anthropic]['claude-2.1'],
|
||||
'claude-3': maxTokensMap[EModelEndpoint.anthropic]['claude-3-sonnet'],
|
||||
};
|
||||
|
||||
models.forEach((model) => {
|
||||
const expectedTokens = model === 'claude-2.1' ? claude21MaxTokens : claudeMaxTokens;
|
||||
let expectedTokens;
|
||||
|
||||
if (model === 'claude-2.1') {
|
||||
expectedTokens = maxTokens['claude-2.1'];
|
||||
} else if (model.startsWith('claude-3')) {
|
||||
expectedTokens = maxTokens['claude-3'];
|
||||
} else {
|
||||
expectedTokens = maxTokens['claude-'];
|
||||
}
|
||||
|
||||
expect(getModelMaxTokens(model, EModelEndpoint.anthropic)).toEqual(expectedTokens);
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -10,6 +10,15 @@ server {
|
|||
######################################## Non-SSL ########################################
|
||||
server_name localhost;
|
||||
|
||||
# https://docs.nginx.com/nginx/admin-guide/web-server/compression/
|
||||
# gzip on;
|
||||
# gzip_vary on;
|
||||
# gzip_proxied any;
|
||||
# gzip_comp_level 6;
|
||||
# gzip_buffers 16 8k;
|
||||
# gzip_http_version 1.1;
|
||||
# gzip_types text/css application/javascript application/json application/octet-stream;
|
||||
|
||||
# Increase the client_max_body_size to allow larger file uploads
|
||||
# The default limits for image uploads as of 11/22/23 is 20MB/file, and 25MB/request
|
||||
client_max_body_size 25M;
|
||||
|
|
@ -33,6 +42,15 @@ server {
|
|||
# listen 443 ssl http2;
|
||||
# listen [::]:443 ssl http2;
|
||||
|
||||
# https://docs.nginx.com/nginx/admin-guide/web-server/compression/
|
||||
# gzip on;
|
||||
# gzip_vary on;
|
||||
# gzip_proxied any;
|
||||
# gzip_comp_level 6;
|
||||
# gzip_buffers 16 8k;
|
||||
# gzip_http_version 1.1;
|
||||
# gzip_types text/css application/javascript application/json application/octet-stream;
|
||||
|
||||
# ssl_certificate /etc/nginx/ssl/nginx.crt;
|
||||
# ssl_certificate_key /etc/nginx/ssl/nginx.key;
|
||||
# ssl_session_timeout 1d;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "@librechat/frontend",
|
||||
"version": "0.7.3",
|
||||
"version": "0.7.4-rc1",
|
||||
"description": "",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
|
|
@ -31,7 +31,7 @@
|
|||
"@ariakit/react": "^0.4.5",
|
||||
"@dicebear/collection": "^7.0.4",
|
||||
"@dicebear/core": "^7.0.4",
|
||||
"@headlessui/react": "^1.7.13",
|
||||
"@headlessui/react": "^2.1.2",
|
||||
"@radix-ui/react-accordion": "^1.1.2",
|
||||
"@radix-ui/react-alert-dialog": "^1.0.2",
|
||||
"@radix-ui/react-checkbox": "^1.0.3",
|
||||
|
|
@ -66,10 +66,11 @@
|
|||
"image-blob-reduce": "^4.1.0",
|
||||
"librechat-data-provider": "*",
|
||||
"lodash": "^4.17.21",
|
||||
"lucide-react": "^0.220.0",
|
||||
"lucide-react": "^0.394.0",
|
||||
"match-sorter": "^6.3.4",
|
||||
"rc-input-number": "^7.4.2",
|
||||
"react": "^18.2.0",
|
||||
"react-avatar-editor": "^13.0.2",
|
||||
"react-dnd": "^16.0.1",
|
||||
"react-dnd-html5-backend": "^16.0.1",
|
||||
"react-dom": "^18.2.0",
|
||||
|
|
|
|||
6
client/src/Providers/AddedChatContext.tsx
Normal file
6
client/src/Providers/AddedChatContext.tsx
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
import { createContext, useContext } from 'react';
|
||||
import useAddedResponse from '~/hooks/Chat/useAddedResponse';
|
||||
type TAddedChatContext = ReturnType<typeof useAddedResponse>;
|
||||
|
||||
export const AddedChatContext = createContext<TAddedChatContext>({} as TAddedChatContext);
|
||||
export const useAddedChatContext = () => useContext(AddedChatContext);
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
import { createContext, useContext } from 'react';
|
||||
import useChatHelpers from '~/hooks/useChatHelpers';
|
||||
import useChatHelpers from '~/hooks/Chat/useChatHelpers';
|
||||
type TChatContext = ReturnType<typeof useChatHelpers>;
|
||||
|
||||
export const ChatContext = createContext<TChatContext>({} as TChatContext);
|
||||
|
|
|
|||
6
client/src/Providers/ChatFormContext.tsx
Normal file
6
client/src/Providers/ChatFormContext.tsx
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
import { createFormContext } from './CustomFormContext';
|
||||
import type { ChatFormValues } from '~/common';
|
||||
|
||||
const { CustomFormProvider, useCustomFormContext } = createFormContext<ChatFormValues>();
|
||||
|
||||
export { CustomFormProvider as ChatFormProvider, useCustomFormContext as useChatFormContext };
|
||||
56
client/src/Providers/CustomFormContext.tsx
Normal file
56
client/src/Providers/CustomFormContext.tsx
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
import React, { createContext, PropsWithChildren, ReactElement, useContext, useMemo } from 'react';
|
||||
import type {
|
||||
Control,
|
||||
// FieldErrors,
|
||||
FieldValues,
|
||||
UseFormReset,
|
||||
UseFormRegister,
|
||||
UseFormGetValues,
|
||||
UseFormHandleSubmit,
|
||||
UseFormSetValue,
|
||||
} from 'react-hook-form';
|
||||
|
||||
interface FormContextValue<TFieldValues extends FieldValues> {
|
||||
register: UseFormRegister<TFieldValues>;
|
||||
control: Control<TFieldValues>;
|
||||
// errors: FieldErrors<TFieldValues>;
|
||||
getValues: UseFormGetValues<TFieldValues>;
|
||||
setValue: UseFormSetValue<TFieldValues>;
|
||||
handleSubmit: UseFormHandleSubmit<TFieldValues>;
|
||||
reset: UseFormReset<TFieldValues>;
|
||||
}
|
||||
|
||||
function createFormContext<TFieldValues extends FieldValues>() {
|
||||
const context = createContext<FormContextValue<TFieldValues> | undefined>(undefined);
|
||||
|
||||
const useCustomFormContext = (): FormContextValue<TFieldValues> => {
|
||||
const value = useContext(context);
|
||||
if (!value) {
|
||||
throw new Error('useCustomFormContext must be used within a CustomFormProvider');
|
||||
}
|
||||
return value;
|
||||
};
|
||||
|
||||
const CustomFormProvider = ({
|
||||
register,
|
||||
control,
|
||||
setValue,
|
||||
// errors,
|
||||
getValues,
|
||||
handleSubmit,
|
||||
reset,
|
||||
children,
|
||||
}: PropsWithChildren<FormContextValue<TFieldValues>>): ReactElement => {
|
||||
const value = useMemo(
|
||||
() => ({ register, control, getValues, setValue, handleSubmit, reset }),
|
||||
[register, control, setValue, getValues, handleSubmit, reset],
|
||||
);
|
||||
|
||||
return <context.Provider value={value}>{children}</context.Provider>;
|
||||
};
|
||||
|
||||
return { CustomFormProvider, useCustomFormContext };
|
||||
}
|
||||
|
||||
export type { FormContextValue };
|
||||
export { createFormContext };
|
||||
7
client/src/Providers/DashboardContext.tsx
Normal file
7
client/src/Providers/DashboardContext.tsx
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
import { createContext, useContext } from 'react';
|
||||
type TDashboardContext = {
|
||||
prevLocationPath: string;
|
||||
};
|
||||
|
||||
export const DashboardContext = createContext<TDashboardContext>({} as TDashboardContext);
|
||||
export const useDashboardContext = () => useContext(DashboardContext);
|
||||
|
|
@ -5,5 +5,8 @@ export * from './ShareContext';
|
|||
export * from './ToastContext';
|
||||
export * from './SearchContext';
|
||||
export * from './FileMapContext';
|
||||
export * from './AddedChatContext';
|
||||
export * from './ChatFormContext';
|
||||
export * from './DashboardContext';
|
||||
export * from './AssistantsContext';
|
||||
export * from './AssistantsMapContext';
|
||||
|
|
|
|||
|
|
@ -1,16 +1,20 @@
|
|||
import React from 'react';
|
||||
import { FileSources } from 'librechat-data-provider';
|
||||
import type * as InputNumberPrimitive from 'rc-input-number';
|
||||
import type { ColumnDef } from '@tanstack/react-table';
|
||||
import type { SetterOrUpdater } from 'recoil';
|
||||
import type {
|
||||
TRole,
|
||||
TUser,
|
||||
Action,
|
||||
TPreset,
|
||||
TPlugin,
|
||||
TMessage,
|
||||
Assistant,
|
||||
TResPlugin,
|
||||
TLoginUser,
|
||||
AuthTypeEnum,
|
||||
TModelsConfig,
|
||||
TConversation,
|
||||
TStartupConfig,
|
||||
EModelEndpoint,
|
||||
|
|
@ -52,6 +56,8 @@ export type LastSelectedModels = Record<EModelEndpoint, string>;
|
|||
|
||||
export type LocalizeFunction = (phraseKey: string, ...values: string[]) => string;
|
||||
|
||||
export type ChatFormValues = { text: string };
|
||||
|
||||
export const mainTextareaId = 'prompt-textarea';
|
||||
export const globalAudioId = 'global-audio';
|
||||
|
||||
|
|
@ -75,7 +81,7 @@ export type IconMapProps = {
|
|||
export type NavLink = {
|
||||
title: string;
|
||||
label?: string;
|
||||
icon: LucideIcon;
|
||||
icon: LucideIcon | React.FC;
|
||||
Component?: React.ComponentType;
|
||||
onClick?: () => void;
|
||||
variant?: 'default' | 'ghost';
|
||||
|
|
@ -225,6 +231,8 @@ export type TGenButtonProps = {
|
|||
|
||||
export type TAskProps = {
|
||||
text: string;
|
||||
overrideConvoId?: string;
|
||||
overrideUserMessageId?: string;
|
||||
parentMessageId?: string | null;
|
||||
conversationId?: string | null;
|
||||
messageId?: string | null;
|
||||
|
|
@ -237,6 +245,7 @@ export type TOptions = {
|
|||
isRegenerate?: boolean;
|
||||
isContinued?: boolean;
|
||||
isEdited?: boolean;
|
||||
overrideMessages?: TMessage[];
|
||||
};
|
||||
|
||||
export type TAskFunction = (props: TAskProps, options?: TOptions) => void;
|
||||
|
|
@ -299,6 +308,7 @@ export type TDangerButtonProps = {
|
|||
actionTextCode: string;
|
||||
dataTestIdInitial: string;
|
||||
dataTestIdConfirm: string;
|
||||
infoDescriptionCode?: string;
|
||||
confirmActionTextCode?: string;
|
||||
};
|
||||
|
||||
|
|
@ -325,6 +335,7 @@ export type TAuthContext = {
|
|||
login: (data: TLoginUser) => void;
|
||||
logout: () => void;
|
||||
setError: React.Dispatch<React.SetStateAction<string | undefined>>;
|
||||
roles?: Record<string, TRole | null | undefined>;
|
||||
};
|
||||
|
||||
export type TUserContext = {
|
||||
|
|
@ -364,6 +375,9 @@ export type MentionOption = OptionWithIcon & {
|
|||
value: string;
|
||||
description?: string;
|
||||
};
|
||||
export type PromptOption = MentionOption & {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type TOptionSettings = {
|
||||
showExamples?: boolean;
|
||||
|
|
@ -394,7 +408,6 @@ export interface SwitcherProps {
|
|||
endpointKeyProvided: boolean;
|
||||
isCollapsed: boolean;
|
||||
}
|
||||
|
||||
export type TLoginLayoutContext = {
|
||||
startupConfig: TStartupConfig | null;
|
||||
startupConfigError: unknown;
|
||||
|
|
@ -404,3 +417,42 @@ export type TLoginLayoutContext = {
|
|||
headerText: string;
|
||||
setHeaderText: React.Dispatch<React.SetStateAction<string>>;
|
||||
};
|
||||
|
||||
export type NewConversationParams = {
|
||||
template?: Partial<TConversation>;
|
||||
preset?: Partial<TPreset>;
|
||||
modelsData?: TModelsConfig;
|
||||
buildDefault?: boolean;
|
||||
keepLatestMessage?: boolean;
|
||||
keepAddedConvos?: boolean;
|
||||
};
|
||||
|
||||
export type ConvoGenerator = (params: NewConversationParams) => void | TConversation;
|
||||
|
||||
export type TResData = {
|
||||
plugin?: TResPlugin;
|
||||
final?: boolean;
|
||||
initial?: boolean;
|
||||
previousMessages?: TMessage[];
|
||||
requestMessage: TMessage;
|
||||
responseMessage: TMessage;
|
||||
conversation: TConversation;
|
||||
conversationId?: string;
|
||||
runMessages?: TMessage[];
|
||||
};
|
||||
export type TVectorStore = {
|
||||
_id: string;
|
||||
object: 'vector_store';
|
||||
created_at: string | Date;
|
||||
name: string;
|
||||
bytes?: number;
|
||||
file_counts?: {
|
||||
in_progress: number;
|
||||
completed: number;
|
||||
failed: number;
|
||||
cancelled: number;
|
||||
total: number;
|
||||
};
|
||||
};
|
||||
|
||||
export type TThread = { id: string; createdAt: string };
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import { ThemeSelector } from '~/components/ui';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { BlinkAnimation } from './BlinkAnimation';
|
||||
import { TStartupConfig } from 'librechat-data-provider';
|
||||
import SocialLoginRender from './SocialLoginRender';
|
||||
import { ThemeSelector } from '~/components/ui';
|
||||
import Footer from './Footer';
|
||||
|
||||
const ErrorRender = ({ children }: { children: React.ReactNode }) => (
|
||||
|
|
|
|||
47
client/src/components/Chat/AddMultiConvo.tsx
Normal file
47
client/src/components/Chat/AddMultiConvo.tsx
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
import { PlusCircle } from 'lucide-react';
|
||||
import { isAssistantsEndpoint } from 'librechat-data-provider';
|
||||
import type { TConversation } from 'librechat-data-provider';
|
||||
import { useChatContext, useAddedChatContext } from '~/Providers';
|
||||
import { mainTextareaId } from '~/common';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
function AddMultiConvo({ className = '' }: { className?: string }) {
|
||||
const { conversation } = useChatContext();
|
||||
const { setConversation: setAddedConvo } = useAddedChatContext();
|
||||
|
||||
const clickHandler = () => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
const { title: _t, ...convo } = conversation ?? ({} as TConversation);
|
||||
setAddedConvo({
|
||||
...convo,
|
||||
title: '',
|
||||
});
|
||||
|
||||
const textarea = document.getElementById(mainTextareaId);
|
||||
if (textarea) {
|
||||
textarea.focus();
|
||||
}
|
||||
};
|
||||
|
||||
if (!conversation) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (isAssistantsEndpoint(conversation.endpoint)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<button
|
||||
onClick={clickHandler}
|
||||
className={cn(
|
||||
'group m-1.5 flex w-fit cursor-pointer items-center rounded text-sm hover:bg-border-medium focus-visible:bg-border-medium focus-visible:outline-0',
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<PlusCircle size={16} />
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
export default AddMultiConvo;
|
||||
|
|
@ -1,10 +1,12 @@
|
|||
import { memo } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useParams } from 'react-router-dom';
|
||||
import { useGetMessagesByConvoId } from 'librechat-data-provider/react-query';
|
||||
import { ChatContext, useFileMapContext } from '~/Providers';
|
||||
import type { ChatFormValues } from '~/common';
|
||||
import { ChatContext, AddedChatContext, useFileMapContext, ChatFormProvider } from '~/Providers';
|
||||
import { useChatHelpers, useAddedResponse, useSSE } from '~/hooks';
|
||||
import MessagesView from './Messages/MessagesView';
|
||||
import { useChatHelpers, useSSE } from '~/hooks';
|
||||
import { Spinner } from '~/components/svg';
|
||||
import Presentation from './Presentation';
|
||||
import ChatForm from './Input/ChatForm';
|
||||
|
|
@ -16,8 +18,8 @@ import store from '~/store';
|
|||
|
||||
function ChatView({ index = 0 }: { index?: number }) {
|
||||
const { conversationId } = useParams();
|
||||
const submissionAtIndex = useRecoilValue(store.submissionByIndex(0));
|
||||
useSSE(submissionAtIndex);
|
||||
const rootSubmission = useRecoilValue(store.submissionByIndex(index));
|
||||
const addedSubmission = useRecoilValue(store.submissionByIndex(index + 1));
|
||||
|
||||
const fileMap = useFileMapContext();
|
||||
|
||||
|
|
@ -30,25 +32,37 @@ function ChatView({ index = 0 }: { index?: number }) {
|
|||
});
|
||||
|
||||
const chatHelpers = useChatHelpers(index, conversationId);
|
||||
const addedChatHelpers = useAddedResponse({ rootIndex: index });
|
||||
|
||||
useSSE(rootSubmission, chatHelpers, false);
|
||||
useSSE(addedSubmission, addedChatHelpers, true);
|
||||
|
||||
const methods = useForm<ChatFormValues>({
|
||||
defaultValues: { text: '' },
|
||||
});
|
||||
|
||||
return (
|
||||
<ChatContext.Provider value={chatHelpers}>
|
||||
<Presentation useSidePanel={true}>
|
||||
{isLoading && conversationId !== 'new' ? (
|
||||
<div className="flex h-screen items-center justify-center">
|
||||
<Spinner className="opacity-0" />
|
||||
</div>
|
||||
) : messagesTree && messagesTree.length !== 0 ? (
|
||||
<MessagesView messagesTree={messagesTree} Header={<Header />} />
|
||||
) : (
|
||||
<Landing Header={<Header />} />
|
||||
)}
|
||||
<div className="w-full border-t-0 pl-0 pt-2 dark:border-white/20 md:w-[calc(100%-.5rem)] md:border-t-0 md:border-transparent md:pl-0 md:pt-0 md:dark:border-transparent">
|
||||
<ChatForm index={index} />
|
||||
<Footer />
|
||||
</div>
|
||||
</Presentation>
|
||||
</ChatContext.Provider>
|
||||
<ChatFormProvider {...methods}>
|
||||
<ChatContext.Provider value={chatHelpers}>
|
||||
<AddedChatContext.Provider value={addedChatHelpers}>
|
||||
<Presentation useSidePanel={true}>
|
||||
{isLoading && conversationId !== 'new' ? (
|
||||
<div className="flex h-screen items-center justify-center">
|
||||
<Spinner className="opacity-0" />
|
||||
</div>
|
||||
) : messagesTree && messagesTree.length !== 0 ? (
|
||||
<MessagesView messagesTree={messagesTree} Header={<Header />} />
|
||||
) : (
|
||||
<Landing Header={<Header />} />
|
||||
)}
|
||||
<div className="w-full border-t-0 pl-0 pt-2 dark:border-white/20 md:w-[calc(100%-.5rem)] md:border-t-0 md:border-transparent md:pl-0 md:pt-0 md:dark:border-transparent">
|
||||
<ChatForm index={index} />
|
||||
<Footer />
|
||||
</div>
|
||||
</Presentation>
|
||||
</AddedChatContext.Provider>
|
||||
</ChatContext.Provider>
|
||||
</ChatFormProvider>
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ export default function Footer({ className }: { className?: string }) {
|
|||
: '[LibreChat ' +
|
||||
Constants.VERSION +
|
||||
'](https://librechat.ai) - ' +
|
||||
localize('com_ui_pay_per_call')
|
||||
localize('com_ui_latest_footer')
|
||||
).split('|');
|
||||
|
||||
const mainContentRender = mainContentParts.map((text, index) => (
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import type { ContextType } from '~/common';
|
|||
import { EndpointsMenu, ModelSpecsMenu, PresetsMenu, HeaderNewChat } from './Menus';
|
||||
import ExportAndShareMenu from './ExportAndShareMenu';
|
||||
import HeaderOptions from './Input/HeaderOptions';
|
||||
import AddMultiConvo from './AddMultiConvo';
|
||||
import { useMediaQuery } from '~/hooks';
|
||||
|
||||
const defaultInterface = getConfigDefaults().interface;
|
||||
|
|
@ -36,6 +37,7 @@ export default function Header() {
|
|||
className="pl-0"
|
||||
/>
|
||||
)}
|
||||
<AddMultiConvo />
|
||||
</div>
|
||||
{!isSmallScreen && (
|
||||
<ExportAndShareMenu isSharedButtonEnabled={startupConfig?.sharedLinksEnabled ?? false} />
|
||||
|
|
|
|||
66
client/src/components/Chat/Input/AddedConvo.tsx
Normal file
66
client/src/components/Chat/Input/AddedConvo.tsx
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
import { useMemo } from 'react';
|
||||
import { useGetEndpointsQuery } from 'librechat-data-provider/react-query';
|
||||
import type { TConversation, TEndpointOption, TPreset } from 'librechat-data-provider';
|
||||
import type { SetterOrUpdater } from 'recoil';
|
||||
import useGetSender from '~/hooks/Conversations/useGetSender';
|
||||
import { EndpointIcon } from '~/components/Endpoints';
|
||||
import { getPresetTitle } from '~/utils';
|
||||
|
||||
export default function AddedConvo({
|
||||
addedConvo,
|
||||
setAddedConvo,
|
||||
}: {
|
||||
addedConvo: TConversation | null;
|
||||
setAddedConvo: SetterOrUpdater<TConversation | null>;
|
||||
}) {
|
||||
const getSender = useGetSender();
|
||||
const { data: endpointsConfig } = useGetEndpointsQuery();
|
||||
const title = useMemo(() => {
|
||||
const sender = getSender(addedConvo as TEndpointOption);
|
||||
const title = getPresetTitle(addedConvo as TPreset);
|
||||
return `+ ${sender}: ${title}`;
|
||||
}, [addedConvo, getSender]);
|
||||
|
||||
if (!addedConvo) {
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<div className="flex items-start gap-4 py-2.5 pl-3 pr-1.5 text-sm">
|
||||
<span className="mt-0 flex h-6 w-6 flex-shrink-0 items-center justify-center">
|
||||
<div className="icon-md">
|
||||
<EndpointIcon
|
||||
conversation={addedConvo}
|
||||
endpointsConfig={endpointsConfig}
|
||||
containerClassName="shadow-stroke overflow-hidden rounded-full"
|
||||
context="menu-item"
|
||||
size={20}
|
||||
/>
|
||||
</div>
|
||||
</span>
|
||||
<span className="text-token-text-secondary line-clamp-3 flex-1 py-0.5 font-semibold">
|
||||
{title}
|
||||
</span>
|
||||
<button
|
||||
className="text-token-text-secondary flex-shrink-0"
|
||||
type="button"
|
||||
onClick={() => setAddedConvo(null)}
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="24"
|
||||
height="24"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
className="icon-lg"
|
||||
>
|
||||
<path
|
||||
fill="currentColor"
|
||||
fillRule="evenodd"
|
||||
d="M7.293 7.293a1 1 0 0 1 1.414 0L12 10.586l3.293-3.293a1 1 0 1 1 1.414 1.414L13.414 12l3.293 3.293a1 1 0 0 1-1.414 1.414L12 13.414l-3.293 3.293a1 1 0 0 1-1.414-1.414L10.586 12 7.293 8.707a1 1 0 0 1 0-1.414"
|
||||
clipRule="evenodd"
|
||||
></path>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
import { useEffect } from 'react';
|
||||
import type { UseFormReturn } from 'react-hook-form';
|
||||
import { TooltipProvider, Tooltip, TooltipTrigger, TooltipContent } from '~/components/ui/';
|
||||
import { TooltipProvider, Tooltip, TooltipTrigger, TooltipContent } from '~/components/ui';
|
||||
import { ListeningIcon, Spinner } from '~/components/svg';
|
||||
import { useLocalize, useSpeechToText } from '~/hooks';
|
||||
import { useChatFormContext } from '~/Providers';
|
||||
import { globalAudioId } from '~/common';
|
||||
|
||||
export default function AudioRecorder({
|
||||
|
|
@ -12,7 +12,7 @@ export default function AudioRecorder({
|
|||
disabled,
|
||||
}: {
|
||||
textAreaRef: React.RefObject<HTMLTextAreaElement>;
|
||||
methods: UseFormReturn<{ text: string }>;
|
||||
methods: ReturnType<typeof useChatFormContext>;
|
||||
ask: (data: { text: string }) => void;
|
||||
disabled: boolean;
|
||||
}) {
|
||||
|
|
@ -31,15 +31,26 @@ export default function AudioRecorder({
|
|||
}
|
||||
};
|
||||
|
||||
const { isListening, isLoading, startRecording, stopRecording, speechText, clearText } =
|
||||
useSpeechToText(handleTranscriptionComplete);
|
||||
const {
|
||||
isListening,
|
||||
isLoading,
|
||||
startRecording,
|
||||
stopRecording,
|
||||
interimTranscript,
|
||||
speechText,
|
||||
clearText,
|
||||
} = useSpeechToText(handleTranscriptionComplete);
|
||||
|
||||
useEffect(() => {
|
||||
if (textAreaRef.current) {
|
||||
if (isListening && textAreaRef.current) {
|
||||
methods.setValue('text', interimTranscript, {
|
||||
shouldValidate: true,
|
||||
});
|
||||
} else if (textAreaRef.current) {
|
||||
textAreaRef.current.value = speechText;
|
||||
methods.setValue('text', speechText, { shouldValidate: true });
|
||||
}
|
||||
}, [speechText, methods, textAreaRef]);
|
||||
}, [interimTranscript, speechText, methods, textAreaRef]);
|
||||
|
||||
const handleStartRecording = async () => {
|
||||
await startRecording();
|
||||
|
|
|
|||
|
|
@ -1,18 +1,29 @@
|
|||
import { useForm } from 'react-hook-form';
|
||||
import { memo, useRef, useMemo } from 'react';
|
||||
import { useRecoilState, useRecoilValue } from 'recoil';
|
||||
import { memo, useCallback, useRef, useMemo, useState, useEffect } from 'react';
|
||||
import {
|
||||
supportsFiles,
|
||||
mergeFileConfig,
|
||||
isAssistantsEndpoint,
|
||||
fileConfig as defaultFileConfig,
|
||||
} from 'librechat-data-provider';
|
||||
import { useChatContext, useAssistantsMapContext } from '~/Providers';
|
||||
import { useAutoSave } from '~/hooks/Input/useAutoSave';
|
||||
import { useRequiresKey, useTextarea } from '~/hooks';
|
||||
import {
|
||||
useChatContext,
|
||||
useAddedChatContext,
|
||||
useAssistantsMapContext,
|
||||
useChatFormContext,
|
||||
} from '~/Providers';
|
||||
import {
|
||||
useTextarea,
|
||||
useAutoSave,
|
||||
useRequiresKey,
|
||||
useHandleKeyUp,
|
||||
useSubmitMessage,
|
||||
} from '~/hooks';
|
||||
import { TextareaAutosize } from '~/components/ui';
|
||||
import { useGetFileConfig } from '~/data-provider';
|
||||
import { cn, removeFocusRings } from '~/utils';
|
||||
import TextareaHeader from './TextareaHeader';
|
||||
import PromptsCommand from './PromptsCommand';
|
||||
import AttachFile from './Files/AttachFile';
|
||||
import AudioRecorder from './AudioRecorder';
|
||||
import { mainTextareaId } from '~/common';
|
||||
|
|
@ -26,58 +37,59 @@ import store from '~/store';
|
|||
const ChatForm = ({ index = 0 }) => {
|
||||
const submitButtonRef = useRef<HTMLButtonElement>(null);
|
||||
const textAreaRef = useRef<HTMLTextAreaElement | null>(null);
|
||||
const SpeechToText = useRecoilValue(store.SpeechToText);
|
||||
const TextToSpeech = useRecoilValue(store.TextToSpeech);
|
||||
|
||||
const SpeechToText = useRecoilValue(store.speechToText);
|
||||
const TextToSpeech = useRecoilValue(store.textToSpeech);
|
||||
const automaticPlayback = useRecoilValue(store.automaticPlayback);
|
||||
|
||||
const [showStopButton, setShowStopButton] = useRecoilState(store.showStopButtonByIndex(index));
|
||||
const [showPlusPopover, setShowPlusPopover] = useRecoilState(store.showPlusPopoverFamily(index));
|
||||
const [showMentionPopover, setShowMentionPopover] = useRecoilState(
|
||||
store.showMentionPopoverFamily(index),
|
||||
);
|
||||
const { requiresKey } = useRequiresKey();
|
||||
|
||||
const methods = useForm<{ text: string }>({
|
||||
defaultValues: { text: '' },
|
||||
const { requiresKey } = useRequiresKey();
|
||||
const handleKeyUp = useHandleKeyUp({
|
||||
index,
|
||||
textAreaRef,
|
||||
setShowPlusPopover,
|
||||
setShowMentionPopover,
|
||||
});
|
||||
const { handlePaste, handleKeyDown, handleCompositionStart, handleCompositionEnd } = useTextarea({
|
||||
textAreaRef,
|
||||
submitButtonRef,
|
||||
disabled: !!requiresKey,
|
||||
});
|
||||
|
||||
const { handlePaste, handleKeyDown, handleKeyUp, handleCompositionStart, handleCompositionEnd } =
|
||||
useTextarea({
|
||||
textAreaRef,
|
||||
submitButtonRef,
|
||||
disabled: !!requiresKey,
|
||||
});
|
||||
|
||||
const {
|
||||
ask,
|
||||
files,
|
||||
setFiles,
|
||||
conversation,
|
||||
isSubmitting,
|
||||
filesLoading,
|
||||
setFilesLoading,
|
||||
newConversation,
|
||||
handleStopGenerating,
|
||||
} = useChatContext();
|
||||
const methods = useChatFormContext();
|
||||
const {
|
||||
addedIndex,
|
||||
generateConversation,
|
||||
conversation: addedConvo,
|
||||
setConversation: setAddedConvo,
|
||||
isSubmitting: isSubmittingAdded,
|
||||
} = useAddedChatContext();
|
||||
const showStopAdded = useRecoilValue(store.showStopButtonByIndex(addedIndex));
|
||||
|
||||
const { clearDraft } = useAutoSave({
|
||||
conversationId: useMemo(() => conversation?.conversationId, [conversation]),
|
||||
textAreaRef,
|
||||
setValue: methods.setValue,
|
||||
files,
|
||||
setFiles,
|
||||
});
|
||||
|
||||
const assistantMap = useAssistantsMapContext();
|
||||
|
||||
const submitMessage = useCallback(
|
||||
(data?: { text: string }) => {
|
||||
if (!data) {
|
||||
return console.warn('No data provided to submitMessage');
|
||||
}
|
||||
ask({ text: data.text });
|
||||
methods.reset();
|
||||
clearDraft();
|
||||
},
|
||||
[ask, methods, clearDraft],
|
||||
);
|
||||
const { submitMessage, submitPrompt } = useSubmitMessage({ clearDraft });
|
||||
|
||||
const { endpoint: _endpoint, endpointType } = conversation ?? { endpoint: null };
|
||||
const endpoint = endpointType ?? _endpoint;
|
||||
|
|
@ -113,10 +125,26 @@ const ChatForm = ({ index = 0 }) => {
|
|||
>
|
||||
<div className="relative flex h-full flex-1 items-stretch md:flex-col">
|
||||
<div className="flex w-full items-center">
|
||||
{showMentionPopover && (
|
||||
<Mention setShowMentionPopover={setShowMentionPopover} textAreaRef={textAreaRef} />
|
||||
{showPlusPopover && !isAssistantsEndpoint(endpoint) && (
|
||||
<Mention
|
||||
setShowMentionPopover={setShowPlusPopover}
|
||||
newConversation={generateConversation}
|
||||
textAreaRef={textAreaRef}
|
||||
commandChar="+"
|
||||
placeholder="com_ui_add"
|
||||
includeAssistants={false}
|
||||
/>
|
||||
)}
|
||||
{showMentionPopover && (
|
||||
<Mention
|
||||
setShowMentionPopover={setShowMentionPopover}
|
||||
newConversation={newConversation}
|
||||
textAreaRef={textAreaRef}
|
||||
/>
|
||||
)}
|
||||
<PromptsCommand index={index} textAreaRef={textAreaRef} submitPrompt={submitPrompt} />
|
||||
<div className="bg-token-main-surface-primary relative flex w-full flex-grow flex-col overflow-hidden rounded-2xl border dark:border-gray-600 dark:text-white [&:has(textarea:focus)]:border-gray-300 [&:has(textarea:focus)]:shadow-[0_2px_6px_rgba(0,0,0,.05)] dark:[&:has(textarea:focus)]:border-gray-500">
|
||||
<TextareaHeader addedConvo={addedConvo} setAddedConvo={setAddedConvo} />
|
||||
<FileRow
|
||||
files={files}
|
||||
setFiles={setFiles}
|
||||
|
|
@ -162,7 +190,7 @@ const ChatForm = ({ index = 0 }) => {
|
|||
endpointType={endpointType}
|
||||
disabled={disableInputs}
|
||||
/>
|
||||
{isSubmitting && showStopButton ? (
|
||||
{(isSubmitting || isSubmittingAdded) && (showStopButton || showStopAdded) ? (
|
||||
<StopButton stop={handleStopGenerating} setShowStopButton={setShowStopButton} />
|
||||
) : (
|
||||
endpoint && (
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ const AttachFile = ({
|
|||
<button
|
||||
disabled={!!disabled}
|
||||
type="button"
|
||||
tabIndex={1}
|
||||
className="btn relative p-0 text-black dark:text-white"
|
||||
aria-label="Attach files"
|
||||
style={{ padding: 0 }}
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ const FileUpload: React.FC<FileUploadProps> = ({
|
|||
<label
|
||||
htmlFor={`file-upload-${id}`}
|
||||
className={cn(
|
||||
'mr-1 flex h-auto cursor-pointer items-center rounded bg-transparent px-2 py-1 text-xs font-medium font-normal transition-colors hover:bg-gray-100 hover:text-green-600 dark:bg-transparent dark:text-gray-300 dark:hover:bg-gray-700 dark:hover:text-green-500',
|
||||
'mr-1 flex h-auto cursor-pointer items-center rounded bg-transparent px-2 py-1 text-xs font-normal transition-colors hover:bg-gray-100 hover:text-green-600 dark:bg-transparent dark:text-gray-300 dark:hover:bg-gray-700 dark:hover:text-green-500',
|
||||
statusColor,
|
||||
containerClassName,
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ const sourceToEndpoint = {
|
|||
[FileSources.azure]: EModelEndpoint.azureOpenAI,
|
||||
};
|
||||
const sourceToClassname = {
|
||||
[FileSources.openai]: 'bg-black/65',
|
||||
[FileSources.openai]: 'bg-white/75 dark:bg-black/65',
|
||||
[FileSources.azure]: 'azure-bg-color opacity-85',
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ export default function DataTable<TData, TValue>({ columns, data }: DataTablePro
|
|||
deleteFiles({ files: filesToDelete as TFile[] });
|
||||
setRowSelection({});
|
||||
}}
|
||||
className="ml-1 gap-2 dark:hover:bg-gray-750/25 sm:ml-0"
|
||||
className="ml-1 gap-2 dark:hover:bg-gray-850/25 sm:ml-0"
|
||||
disabled={!table.getFilteredSelectedRowModel().rows.length || isDeleting}
|
||||
>
|
||||
{isDeleting ? (
|
||||
|
|
@ -121,7 +121,7 @@ export default function DataTable<TData, TValue>({ columns, data }: DataTablePro
|
|||
{/* Filter Menu */}
|
||||
<DropdownMenuContent
|
||||
align="end"
|
||||
className="z-[1001] dark:border-gray-700 dark:bg-gray-750"
|
||||
className="z-[1001] dark:border-gray-700 dark:bg-gray-850"
|
||||
>
|
||||
{table
|
||||
.getAllColumns()
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ export function SortFilterHeader<TData, TValue>({
|
|||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent
|
||||
align="start"
|
||||
className="z-[1001] dark:border-gray-700 dark:bg-gray-750"
|
||||
className="z-[1001] dark:border-gray-700 dark:bg-gray-850"
|
||||
>
|
||||
<DropdownMenuItem
|
||||
onClick={() => column.toggleSorting(false)}
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue