mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-18 01:10:14 +01:00
🤖 feat: Private Assistants (#2881)
* feat: add configuration for user private assistants * filter private assistant message requests * add test for privateAssistants * add privateAssistants configuration to tests * fix: destructuring error when assistants config is not added * chore: revert chat controller changes * chore: add payload type, add metadata types * feat: validateAssistant * refactor(fetchAssistants): allow for flexibility * feat: validateAuthor * refactor: return all assistants to ADMIN role * feat: add assistant doc on assistant creation * refactor(listAssistants): use `listAllAssistants` to exhaustively fetch all assistants * chore: add suggestion to tts error * refactor(validateAuthor): attempt database check first * refactor: author validation when patching/deleting assistant --------- Co-authored-by: Leon Juenemann <leon.juenemann@maibornwolff.de>
This commit is contained in:
parent
9f2538fcd9
commit
5dc5d875ba
20 changed files with 308 additions and 109 deletions
|
|
@ -20,6 +20,7 @@ const {
|
|||
} = require('~/server/services/Threads');
|
||||
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
||||
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
|
||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||
|
|
@ -31,15 +32,14 @@ const { getModelMaxTokens } = require('~/utils');
|
|||
const { getOpenAIClient } = require('./helpers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const { handleAbortError } = require('~/server/middleware');
|
||||
|
||||
const ten_minutes = 1000 * 60 * 10;
|
||||
|
||||
/**
|
||||
* @route POST /
|
||||
* @desc Chat with an assistant
|
||||
* @access Public
|
||||
* @param {Express.Request} req - The request object, containing the request data.
|
||||
* @param {object} req - The request object, containing the request data.
|
||||
* @param {object} req.body - The request payload.
|
||||
* @param {Express.Response} res - The response object, used to send back a response.
|
||||
* @returns {void}
|
||||
*/
|
||||
|
|
@ -60,30 +60,6 @@ const chatV1 = async (req, res) => {
|
|||
parentMessageId: _parentId = Constants.NO_PARENT,
|
||||
} = req.body;
|
||||
|
||||
/** @type {Partial<TAssistantEndpoint>} */
|
||||
const assistantsConfig = req.app.locals?.[endpoint];
|
||||
|
||||
if (assistantsConfig) {
|
||||
const { supportedIds, excludedIds } = assistantsConfig;
|
||||
const error = { message: 'Assistant not supported' };
|
||||
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
|
||||
return await handleAbortError(res, req, error, {
|
||||
sender: 'System',
|
||||
conversationId: convoId,
|
||||
messageId: v4(),
|
||||
parentMessageId: _messageId,
|
||||
error,
|
||||
});
|
||||
} else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
|
||||
return await handleAbortError(res, req, error, {
|
||||
sender: 'System',
|
||||
conversationId: convoId,
|
||||
messageId: v4(),
|
||||
parentMessageId: _messageId,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {OpenAIClient} */
|
||||
let openai;
|
||||
/** @type {string|undefined} - the current thread id */
|
||||
|
|
@ -311,6 +287,7 @@ const chatV1 = async (req, res) => {
|
|||
});
|
||||
|
||||
openai = _openai;
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
if (previousMessages.length) {
|
||||
parentMessageId = previousMessages[previousMessages.length - 1].messageId;
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ const {
|
|||
} = require('~/server/services/Threads');
|
||||
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
||||
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||
const { getTransactions } = require('~/models/Transaction');
|
||||
|
|
@ -30,8 +31,6 @@ const { getModelMaxTokens } = require('~/utils');
|
|||
const { getOpenAIClient } = require('./helpers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const { handleAbortError } = require('~/server/middleware');
|
||||
|
||||
const ten_minutes = 1000 * 60 * 10;
|
||||
|
||||
/**
|
||||
|
|
@ -60,30 +59,6 @@ const chatV2 = async (req, res) => {
|
|||
parentMessageId: _parentId = Constants.NO_PARENT,
|
||||
} = req.body;
|
||||
|
||||
/** @type {Partial<TAssistantEndpoint>} */
|
||||
const assistantsConfig = req.app.locals?.[endpoint];
|
||||
|
||||
if (assistantsConfig) {
|
||||
const { supportedIds, excludedIds } = assistantsConfig;
|
||||
const error = { message: 'Assistant not supported' };
|
||||
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
|
||||
return await handleAbortError(res, req, error, {
|
||||
sender: 'System',
|
||||
conversationId: convoId,
|
||||
messageId: v4(),
|
||||
parentMessageId: _messageId,
|
||||
error,
|
||||
});
|
||||
} else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
|
||||
return await handleAbortError(res, req, error, {
|
||||
sender: 'System',
|
||||
conversationId: convoId,
|
||||
messageId: v4(),
|
||||
parentMessageId: _messageId,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {OpenAIClient} */
|
||||
let openai;
|
||||
/** @type {string|undefined} - the current thread id */
|
||||
|
|
@ -309,6 +284,7 @@ const chatV2 = async (req, res) => {
|
|||
});
|
||||
|
||||
openai = _openai;
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
if (previousMessages.length) {
|
||||
parentMessageId = previousMessages[previousMessages.length - 1].messageId;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,9 @@
|
|||
const { EModelEndpoint, CacheKeys, defaultAssistantsVersion } = require('librechat-data-provider');
|
||||
const {
|
||||
EModelEndpoint,
|
||||
CacheKeys,
|
||||
defaultAssistantsVersion,
|
||||
defaultOrderQuery,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
initializeClient: initAzureClient,
|
||||
} = require('~/server/services/Endpoints/azureAssistants');
|
||||
|
|
@ -35,6 +40,7 @@ const getCurrentVersion = async (req, endpoint) => {
|
|||
* Initializes the client with the current request and response objects and lists assistants
|
||||
* according to the query parameters. This function abstracts the logic for non-Azure paths.
|
||||
*
|
||||
* @deprecated
|
||||
* @async
|
||||
* @param {object} params - The parameters object.
|
||||
* @param {object} params.req - The request object, used for initializing the client.
|
||||
|
|
@ -43,11 +49,65 @@ const getCurrentVersion = async (req, endpoint) => {
|
|||
* @param {object} params.query - The query parameters to list assistants (e.g., limit, order).
|
||||
* @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
|
||||
*/
|
||||
const listAssistants = async ({ req, res, version, query }) => {
|
||||
const _listAssistants = async ({ req, res, version, query }) => {
|
||||
const { openai } = await getOpenAIClient({ req, res, version });
|
||||
return openai.beta.assistants.list(query);
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches all assistants based on provided query params, until `has_more` is `false`.
|
||||
*
|
||||
* @async
|
||||
* @param {object} params - The parameters object.
|
||||
* @param {object} params.req - The request object, used for initializing the client.
|
||||
* @param {object} params.res - The response object, used for initializing the client.
|
||||
* @param {string} params.version - The API version to use.
|
||||
* @param {Omit<AssistantListParams, 'endpoint'>} params.query - The query parameters to list assistants (e.g., limit, order).
|
||||
* @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
|
||||
*/
|
||||
const listAllAssistants = async ({ req, res, version, query }) => {
|
||||
/** @type {{ openai: OpenAIClient }} */
|
||||
const { openai } = await getOpenAIClient({ req, res, version });
|
||||
const allAssistants = [];
|
||||
|
||||
let first_id;
|
||||
let last_id;
|
||||
let afterToken = query.after;
|
||||
let hasMore = true;
|
||||
|
||||
while (hasMore) {
|
||||
const response = await openai.beta.assistants.list({
|
||||
...query,
|
||||
after: afterToken,
|
||||
});
|
||||
|
||||
const { body } = response;
|
||||
|
||||
allAssistants.push(...body.data);
|
||||
hasMore = body.has_more;
|
||||
|
||||
if (!first_id) {
|
||||
first_id = body.first_id;
|
||||
}
|
||||
|
||||
if (hasMore) {
|
||||
afterToken = body.last_id;
|
||||
} else {
|
||||
last_id = body.last_id;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
data: allAssistants,
|
||||
body: {
|
||||
data: allAssistants,
|
||||
has_more: false,
|
||||
first_id,
|
||||
last_id,
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Asynchronously lists assistants for Azure configured groups.
|
||||
*
|
||||
|
|
@ -82,7 +142,7 @@ const listAssistantsForAzure = async ({ req, res, version, azureConfig = {}, que
|
|||
/* The specified model is only necessary to
|
||||
fetch assistants for the shared instance */
|
||||
req.body.model = currentModelTuples[0][0];
|
||||
promises.push(listAssistants({ req, res, version, query }));
|
||||
promises.push(listAllAssistants({ req, res, version, query }));
|
||||
}
|
||||
|
||||
const resolvedQueries = await Promise.all(promises);
|
||||
|
|
@ -133,8 +193,27 @@ async function getOpenAIClient({ req, res, endpointOption, initAppClient, overri
|
|||
return result;
|
||||
}
|
||||
|
||||
const fetchAssistants = async (req, res) => {
|
||||
const { limit = 100, order = 'desc', after, before, endpoint } = req.query;
|
||||
/**
|
||||
* Returns a list of assistants.
|
||||
* @param {object} params
|
||||
* @param {object} params.req - Express Request
|
||||
* @param {AssistantListParams} [params.req.query] - The assistant list parameters for pagination and sorting.
|
||||
* @param {object} params.res - Express Response
|
||||
* @param {string} [params.overrideEndpoint] - The endpoint to override the request endpoint.
|
||||
* @returns {Promise<AssistantListResponse>} 200 - success response - application/json
|
||||
*/
|
||||
const fetchAssistants = async ({ req, res, overrideEndpoint }) => {
|
||||
const {
|
||||
limit = 100,
|
||||
order = 'desc',
|
||||
after,
|
||||
before,
|
||||
endpoint,
|
||||
} = req.query ?? {
|
||||
endpoint: overrideEndpoint,
|
||||
...defaultOrderQuery,
|
||||
};
|
||||
|
||||
const version = await getCurrentVersion(req, endpoint);
|
||||
const query = { limit, order, after, before };
|
||||
|
||||
|
|
@ -142,15 +221,47 @@ const fetchAssistants = async (req, res) => {
|
|||
let body;
|
||||
|
||||
if (endpoint === EModelEndpoint.assistants) {
|
||||
({ body } = await listAssistants({ req, res, version, query }));
|
||||
({ body } = await listAllAssistants({ req, res, version, query }));
|
||||
} else if (endpoint === EModelEndpoint.azureAssistants) {
|
||||
const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI];
|
||||
body = await listAssistantsForAzure({ req, res, version, azureConfig, query });
|
||||
}
|
||||
|
||||
if (req.user.role === 'ADMIN') {
|
||||
return body;
|
||||
} else if (!req.app.locals[endpoint]) {
|
||||
return body;
|
||||
}
|
||||
|
||||
body.data = filterAssistants({
|
||||
userId: req.user.id,
|
||||
assistants: body.data,
|
||||
assistantsConfig: req.app.locals[endpoint],
|
||||
});
|
||||
return body;
|
||||
};
|
||||
|
||||
/**
|
||||
* Filter assistants based on configuration.
|
||||
*
|
||||
* @param {object} params - The parameters object.
|
||||
* @param {string} params.userId - The user ID to filter private assistants.
|
||||
* @param {Assistant[]} params.assistants - The list of assistants to filter.
|
||||
* @param {Partial<TAssistantEndpoint>} params.assistantsConfig - The assistant configuration.
|
||||
* @returns {Assistant[]} - The filtered list of assistants.
|
||||
*/
|
||||
function filterAssistants({ assistants, userId, assistantsConfig }) {
|
||||
const { supportedIds, excludedIds, privateAssistants } = assistantsConfig;
|
||||
if (privateAssistants) {
|
||||
return assistants.filter((assistant) => userId === assistant.metadata?.author);
|
||||
} else if (supportedIds?.length) {
|
||||
return assistants.filter((assistant) => supportedIds.includes(assistant.id));
|
||||
} else if (excludedIds?.length) {
|
||||
return assistants.filter((assistant) => !excludedIds.includes(assistant.id));
|
||||
}
|
||||
return assistants;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getOpenAIClient,
|
||||
fetchAssistants,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
const { FileContext } = require('librechat-data-provider');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { deleteAssistantActions } = require('~/server/services/ActionService');
|
||||
const { updateAssistantDoc, getAssistants } = require('~/models/Assistant');
|
||||
const { uploadImageBuffer } = require('~/server/services/Files/process');
|
||||
const { updateAssistant, getAssistants } = require('~/models/Assistant');
|
||||
const { getOpenAIClient, fetchAssistants } = require('./helpers');
|
||||
const { deleteFileByFilter } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
|
|
@ -40,9 +41,11 @@ const createAssistant = async (req, res) => {
|
|||
};
|
||||
|
||||
const assistant = await openai.beta.assistants.create(assistantData);
|
||||
const promise = updateAssistantDoc({ assistant_id: assistant.id }, { user: req.user.id });
|
||||
if (azureModelIdentifier) {
|
||||
assistant.model = azureModelIdentifier;
|
||||
}
|
||||
await promise;
|
||||
logger.debug('/assistants/', assistant);
|
||||
res.status(201).json(assistant);
|
||||
} catch (error) {
|
||||
|
|
@ -61,7 +64,6 @@ const retrieveAssistant = async (req, res) => {
|
|||
try {
|
||||
/* NOTE: not actually being used right now */
|
||||
const { openai } = await getOpenAIClient({ req, res });
|
||||
|
||||
const assistant_id = req.params.id;
|
||||
const assistant = await openai.beta.assistants.retrieve(assistant_id);
|
||||
res.json(assistant);
|
||||
|
|
@ -83,6 +85,7 @@ const retrieveAssistant = async (req, res) => {
|
|||
const patchAssistant = async (req, res) => {
|
||||
try {
|
||||
const { openai } = await getOpenAIClient({ req, res });
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
const assistant_id = req.params.id;
|
||||
const { endpoint: _e, ...updateData } = req.body;
|
||||
|
|
@ -119,6 +122,7 @@ const patchAssistant = async (req, res) => {
|
|||
const deleteAssistant = async (req, res) => {
|
||||
try {
|
||||
const { openai } = await getOpenAIClient({ req, res });
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
const assistant_id = req.params.id;
|
||||
const deletionStatus = await openai.beta.assistants.del(assistant_id);
|
||||
|
|
@ -141,19 +145,7 @@ const deleteAssistant = async (req, res) => {
|
|||
*/
|
||||
const listAssistants = async (req, res) => {
|
||||
try {
|
||||
const body = await fetchAssistants(req, res);
|
||||
|
||||
if (req.app.locals?.[req.query.endpoint]) {
|
||||
/** @type {Partial<TAssistantEndpoint>} */
|
||||
const assistantsConfig = req.app.locals[req.query.endpoint];
|
||||
const { supportedIds, excludedIds } = assistantsConfig;
|
||||
if (supportedIds?.length) {
|
||||
body.data = body.data.filter((assistant) => supportedIds.includes(assistant.id));
|
||||
} else if (excludedIds?.length) {
|
||||
body.data = body.data.filter((assistant) => !excludedIds.includes(assistant.id));
|
||||
}
|
||||
}
|
||||
|
||||
const body = await fetchAssistants({ req, res });
|
||||
res.json(body);
|
||||
} catch (error) {
|
||||
logger.error('[/assistants] Error listing assistants', error);
|
||||
|
|
@ -195,6 +187,7 @@ const uploadAssistantAvatar = async (req, res) => {
|
|||
|
||||
let { metadata: _metadata = '{}' } = req.body;
|
||||
const { openai } = await getOpenAIClient({ req, res });
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
const image = await uploadImageBuffer({
|
||||
req,
|
||||
|
|
@ -229,7 +222,7 @@ const uploadAssistantAvatar = async (req, res) => {
|
|||
|
||||
const promises = [];
|
||||
promises.push(
|
||||
updateAssistant(
|
||||
updateAssistantDoc(
|
||||
{ assistant_id },
|
||||
{
|
||||
avatar: {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
const { ToolCallTypes } = require('librechat-data-provider');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { validateAndUpdateTool } = require('~/server/services/ActionService');
|
||||
const { updateAssistantDoc } = require('~/models/Assistant');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
|
|
@ -37,9 +39,11 @@ const createAssistant = async (req, res) => {
|
|||
};
|
||||
|
||||
const assistant = await openai.beta.assistants.create(assistantData);
|
||||
const promise = updateAssistantDoc({ assistant_id: assistant.id }, { user: req.user.id });
|
||||
if (azureModelIdentifier) {
|
||||
assistant.model = azureModelIdentifier;
|
||||
}
|
||||
await promise;
|
||||
logger.debug('/assistants/', assistant);
|
||||
res.status(201).json(assistant);
|
||||
} catch (error) {
|
||||
|
|
@ -58,6 +62,7 @@ const createAssistant = async (req, res) => {
|
|||
* @returns {Promise<Assistant>} The updated assistant.
|
||||
*/
|
||||
const updateAssistant = async ({ req, openai, assistant_id, updateData }) => {
|
||||
await validateAuthor({ req, openai });
|
||||
const tools = [];
|
||||
|
||||
let hasFileSearch = false;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue