mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-09-21 21:50:49 +02:00
feat: Implement paginated access to prompt groups with filtering and public visibility
This commit is contained in:
parent
1b14721e75
commit
a928115a84
6 changed files with 382 additions and 30 deletions
|
@ -247,10 +247,130 @@ const deletePromptGroup = async ({ _id, author, role }) => {
|
|||
return { message: 'Prompt group deleted successfully' };
|
||||
};
|
||||
|
||||
/**
|
||||
* Get prompt groups by accessible IDs with optional cursor-based pagination.
|
||||
* @param {Object} params - The parameters for getting accessible prompt groups.
|
||||
* @param {Array} [params.accessibleIds] - Array of prompt group ObjectIds the user has ACL access to.
|
||||
* @param {Object} [params.otherParams] - Additional query parameters (including author filter).
|
||||
* @param {number} [params.limit] - Number of prompt groups to return (max 100). If not provided, returns all prompt groups.
|
||||
* @param {string} [params.after] - Cursor for pagination - get prompt groups after this cursor. // base64 encoded JSON string with updatedAt and _id.
|
||||
* @returns {Promise<Object>} A promise that resolves to an object containing the prompt groups data and pagination info.
|
||||
*/
|
||||
async function getListPromptGroupsByAccess({
|
||||
accessibleIds = [],
|
||||
otherParams = {},
|
||||
limit = null,
|
||||
after = null,
|
||||
}) {
|
||||
const isPaginated = limit !== null && limit !== undefined;
|
||||
const normalizedLimit = isPaginated ? Math.min(Math.max(1, parseInt(limit) || 20), 100) : null;
|
||||
|
||||
// Build base query combining ACL accessible prompt groups with other filters
|
||||
const baseQuery = { ...otherParams, _id: { $in: accessibleIds } };
|
||||
|
||||
// Add cursor condition
|
||||
if (after) {
|
||||
try {
|
||||
const cursor = JSON.parse(Buffer.from(after, 'base64').toString('utf8'));
|
||||
const { updatedAt, _id } = cursor;
|
||||
|
||||
const cursorCondition = {
|
||||
$or: [
|
||||
{ updatedAt: { $lt: new Date(updatedAt) } },
|
||||
{ updatedAt: new Date(updatedAt), _id: { $gt: new ObjectId(_id) } },
|
||||
],
|
||||
};
|
||||
|
||||
// Merge cursor condition with base query
|
||||
if (Object.keys(baseQuery).length > 0) {
|
||||
baseQuery.$and = [{ ...baseQuery }, cursorCondition];
|
||||
// Remove the original conditions from baseQuery to avoid duplication
|
||||
Object.keys(baseQuery).forEach((key) => {
|
||||
if (key !== '$and') delete baseQuery[key];
|
||||
});
|
||||
} else {
|
||||
Object.assign(baseQuery, cursorCondition);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('Invalid cursor:', error.message);
|
||||
}
|
||||
}
|
||||
|
||||
// Build aggregation pipeline
|
||||
const pipeline = [{ $match: baseQuery }, { $sort: { updatedAt: -1, _id: 1 } }];
|
||||
|
||||
// Only apply limit if pagination is requested
|
||||
if (isPaginated) {
|
||||
pipeline.push({ $limit: normalizedLimit + 1 });
|
||||
}
|
||||
|
||||
// Add lookup for production prompt
|
||||
pipeline.push(
|
||||
{
|
||||
$lookup: {
|
||||
from: 'prompts',
|
||||
localField: 'productionId',
|
||||
foreignField: '_id',
|
||||
as: 'productionPrompt',
|
||||
},
|
||||
},
|
||||
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
|
||||
{
|
||||
$project: {
|
||||
name: 1,
|
||||
numberOfGenerations: 1,
|
||||
oneliner: 1,
|
||||
category: 1,
|
||||
projectIds: 1,
|
||||
productionId: 1,
|
||||
author: 1,
|
||||
authorName: 1,
|
||||
createdAt: 1,
|
||||
updatedAt: 1,
|
||||
'productionPrompt.prompt': 1,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const promptGroups = await PromptGroup.aggregate(pipeline).exec();
|
||||
|
||||
const hasMore = isPaginated ? promptGroups.length > normalizedLimit : false;
|
||||
const data = (isPaginated ? promptGroups.slice(0, normalizedLimit) : promptGroups).map(
|
||||
(group) => {
|
||||
if (group.author) {
|
||||
group.author = group.author.toString();
|
||||
}
|
||||
return group;
|
||||
},
|
||||
);
|
||||
|
||||
// Generate next cursor only if paginated
|
||||
let nextCursor = null;
|
||||
if (isPaginated && hasMore && data.length > 0) {
|
||||
const lastGroup = promptGroups[normalizedLimit - 1];
|
||||
nextCursor = Buffer.from(
|
||||
JSON.stringify({
|
||||
updatedAt: lastGroup.updatedAt.toISOString(),
|
||||
_id: lastGroup._id.toString(),
|
||||
}),
|
||||
).toString('base64');
|
||||
}
|
||||
|
||||
return {
|
||||
object: 'list',
|
||||
data,
|
||||
first_id: data.length > 0 ? data[0]._id.toString() : null,
|
||||
last_id: data.length > 0 ? data[data.length - 1]._id.toString() : null,
|
||||
has_more: hasMore,
|
||||
after: nextCursor,
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getPromptGroups,
|
||||
deletePromptGroup,
|
||||
getAllPromptGroups,
|
||||
getListPromptGroupsByAccess,
|
||||
/**
|
||||
* Create a prompt and its respective group
|
||||
* @param {TCreatePromptRecord} saveData
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
const express = require('express');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { generateCheckAccess } = require('@librechat/api');
|
||||
const {
|
||||
generateCheckAccess,
|
||||
markPublicPromptGroups,
|
||||
buildPromptGroupFilter,
|
||||
formatPromptGroupsResponse,
|
||||
createEmptyPromptGroupsResponse,
|
||||
filterAccessibleIdsBySharedLogic,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Permissions,
|
||||
SystemRoles,
|
||||
|
@ -11,12 +18,11 @@ const {
|
|||
PermissionTypes,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
getListPromptGroupsByAccess,
|
||||
makePromptProduction,
|
||||
getAllPromptGroups,
|
||||
updatePromptGroup,
|
||||
deletePromptGroup,
|
||||
createPromptGroup,
|
||||
getPromptGroups,
|
||||
getPromptGroup,
|
||||
deletePrompt,
|
||||
getPrompts,
|
||||
|
@ -95,23 +101,48 @@ router.get(
|
|||
router.get('/all', async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const { name, category, ...otherFilters } = req.query;
|
||||
const { filter, searchShared, searchSharedOnly } = buildPromptGroupFilter({
|
||||
name,
|
||||
category,
|
||||
...otherFilters,
|
||||
});
|
||||
|
||||
// Get promptGroup IDs the user has VIEW access to via ACL
|
||||
const accessibleIds = await findAccessibleResources({
|
||||
let accessibleIds = await findAccessibleResources({
|
||||
userId,
|
||||
role: req.user.role,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
requiredPermissions: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
const groups = await getAllPromptGroups(req, {});
|
||||
const publiclyAccessibleIds = await findPubliclyAccessibleResources({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
requiredPermissions: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
// Filter the results to only include accessible groups
|
||||
const accessibleGroups = groups.filter((group) =>
|
||||
accessibleIds.some((id) => id.toString() === group._id.toString()),
|
||||
);
|
||||
const filteredAccessibleIds = await filterAccessibleIdsBySharedLogic({
|
||||
accessibleIds,
|
||||
searchShared,
|
||||
searchSharedOnly,
|
||||
publicPromptGroupIds: publiclyAccessibleIds,
|
||||
});
|
||||
|
||||
res.status(200).send(accessibleGroups);
|
||||
const result = await getListPromptGroupsByAccess({
|
||||
accessibleIds: filteredAccessibleIds,
|
||||
otherParams: filter,
|
||||
});
|
||||
|
||||
if (!result) {
|
||||
return res.status(200).send([]);
|
||||
}
|
||||
|
||||
const { data: promptGroups = [] } = result;
|
||||
if (!promptGroups.length) {
|
||||
return res.status(200).send([]);
|
||||
}
|
||||
|
||||
const groupsWithPublicFlag = markPublicPromptGroups(promptGroups, publiclyAccessibleIds);
|
||||
res.status(200).send(groupsWithPublicFlag);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error getting prompt groups' });
|
||||
|
@ -125,40 +156,66 @@ router.get('/all', async (req, res) => {
|
|||
router.get('/groups', async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const filter = { ...req.query };
|
||||
delete filter.author; // Remove author filter as we'll use ACL
|
||||
const { pageSize, pageNumber, limit, cursor, name, category, ...otherFilters } = req.query;
|
||||
|
||||
// Get promptGroup IDs the user has VIEW access to via ACL
|
||||
const accessibleIds = await findAccessibleResources({
|
||||
const { filter, searchShared, searchSharedOnly } = buildPromptGroupFilter({
|
||||
name,
|
||||
category,
|
||||
...otherFilters,
|
||||
});
|
||||
|
||||
let actualLimit = limit;
|
||||
let actualCursor = cursor;
|
||||
|
||||
if (pageSize && !limit) {
|
||||
actualLimit = parseInt(pageSize, 10);
|
||||
}
|
||||
|
||||
let accessibleIds = await findAccessibleResources({
|
||||
userId,
|
||||
role: req.user.role,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
requiredPermissions: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
// Get publicly accessible promptGroups
|
||||
const publiclyAccessibleIds = await findPubliclyAccessibleResources({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
requiredPermissions: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
const groups = await getPromptGroups(req, filter);
|
||||
const filteredAccessibleIds = await filterAccessibleIdsBySharedLogic({
|
||||
accessibleIds,
|
||||
searchShared,
|
||||
searchSharedOnly,
|
||||
publicPromptGroupIds: publiclyAccessibleIds,
|
||||
});
|
||||
|
||||
if (groups.promptGroups && groups.promptGroups.length > 0) {
|
||||
groups.promptGroups = groups.promptGroups.filter((group) =>
|
||||
accessibleIds.some((id) => id.toString() === group._id.toString()),
|
||||
);
|
||||
const result = await getListPromptGroupsByAccess({
|
||||
accessibleIds: filteredAccessibleIds,
|
||||
otherParams: filter,
|
||||
limit: actualLimit,
|
||||
after: actualCursor,
|
||||
});
|
||||
|
||||
// Mark public groups
|
||||
groups.promptGroups = groups.promptGroups.map((group) => {
|
||||
if (publiclyAccessibleIds.some((id) => id.equals(group._id))) {
|
||||
group.isPublic = true;
|
||||
}
|
||||
return group;
|
||||
});
|
||||
if (!result) {
|
||||
const emptyResponse = createEmptyPromptGroupsResponse({ pageNumber, pageSize, actualLimit });
|
||||
return res.status(200).send(emptyResponse);
|
||||
}
|
||||
|
||||
res.status(200).send(groups);
|
||||
const { data: promptGroups = [], has_more = false, after = null } = result;
|
||||
|
||||
const groupsWithPublicFlag = markPublicPromptGroups(promptGroups, publiclyAccessibleIds);
|
||||
|
||||
const response = formatPromptGroupsResponse({
|
||||
promptGroups: groupsWithPublicFlag,
|
||||
pageNumber,
|
||||
pageSize,
|
||||
actualLimit,
|
||||
hasMore: has_more,
|
||||
after,
|
||||
});
|
||||
|
||||
res.status(200).send(response);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error getting prompt groups' });
|
||||
|
@ -188,7 +245,6 @@ const createNewPromptGroup = async (req, res) => {
|
|||
|
||||
const result = await createPromptGroup(saveData);
|
||||
|
||||
// Grant owner permissions to the creator on the new promptGroup
|
||||
if (result.prompt && result.prompt._id && result.prompt.groupId) {
|
||||
try {
|
||||
await grantPermission({
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
export * from './content';
|
||||
export * from './prompts';
|
||||
|
|
150
packages/api/src/format/prompts.ts
Normal file
150
packages/api/src/format/prompts.ts
Normal file
|
@ -0,0 +1,150 @@
|
|||
import { SystemCategories } from 'librechat-data-provider';
|
||||
import type { IPromptGroupDocument as IPromptGroup } from '@librechat/data-schemas';
|
||||
import type { Types } from 'mongoose';
|
||||
import type { PromptGroupsListResponse } from '~/types';
|
||||
|
||||
/**
|
||||
* Formats prompt groups for the paginated /groups endpoint response
|
||||
*/
|
||||
export function formatPromptGroupsResponse({
|
||||
promptGroups = [],
|
||||
pageNumber,
|
||||
pageSize,
|
||||
actualLimit,
|
||||
hasMore = false,
|
||||
after = null,
|
||||
}: {
|
||||
promptGroups: IPromptGroup[];
|
||||
pageNumber?: string;
|
||||
pageSize?: string;
|
||||
actualLimit?: string | number;
|
||||
hasMore?: boolean;
|
||||
after?: string | null;
|
||||
}): PromptGroupsListResponse {
|
||||
const effectivePageSize = parseInt(pageSize || '') || parseInt(String(actualLimit || '')) || 10;
|
||||
const totalPages =
|
||||
promptGroups.length > 0 ? Math.ceil(promptGroups.length / effectivePageSize).toString() : '0';
|
||||
|
||||
return {
|
||||
promptGroups,
|
||||
pageNumber: pageNumber || '1',
|
||||
pageSize: pageSize || String(actualLimit) || '10',
|
||||
pages: totalPages,
|
||||
has_more: hasMore,
|
||||
after,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an empty response for the paginated /groups endpoint
|
||||
*/
|
||||
export function createEmptyPromptGroupsResponse({
|
||||
pageNumber,
|
||||
pageSize,
|
||||
actualLimit,
|
||||
}: {
|
||||
pageNumber?: string;
|
||||
pageSize?: string;
|
||||
actualLimit?: string | number;
|
||||
}): PromptGroupsListResponse {
|
||||
return {
|
||||
promptGroups: [],
|
||||
pageNumber: pageNumber || '1',
|
||||
pageSize: pageSize || String(actualLimit) || '10',
|
||||
pages: '0',
|
||||
has_more: false,
|
||||
after: null,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Marks prompt groups as public based on the publicly accessible IDs
|
||||
*/
|
||||
export function markPublicPromptGroups(
|
||||
promptGroups: IPromptGroup[],
|
||||
publiclyAccessibleIds: Types.ObjectId[],
|
||||
): IPromptGroup[] {
|
||||
if (!promptGroups.length) {
|
||||
return [];
|
||||
}
|
||||
|
||||
return promptGroups.map((group) => {
|
||||
const isPublic = publiclyAccessibleIds.some((id) => id.equals(group._id?.toString()));
|
||||
return isPublic ? ({ ...group, isPublic: true } as IPromptGroup) : group;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds filter object for prompt group queries
|
||||
*/
|
||||
export function buildPromptGroupFilter({
|
||||
name,
|
||||
category,
|
||||
...otherFilters
|
||||
}: {
|
||||
name?: string;
|
||||
category?: string;
|
||||
[key: string]: string | number | boolean | RegExp | undefined;
|
||||
}): {
|
||||
filter: Record<string, string | number | boolean | RegExp | undefined>;
|
||||
searchShared: boolean;
|
||||
searchSharedOnly: boolean;
|
||||
} {
|
||||
const filter: Record<string, string | number | boolean | RegExp | undefined> = {
|
||||
...otherFilters,
|
||||
};
|
||||
let searchShared = true;
|
||||
let searchSharedOnly = false;
|
||||
|
||||
// Handle name filter - convert to regex for case-insensitive search
|
||||
if (name) {
|
||||
const escapeRegExp = (str: string) => str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
filter.name = new RegExp(escapeRegExp(name), 'i');
|
||||
}
|
||||
|
||||
// Handle category filters with special system categories
|
||||
if (category === SystemCategories.MY_PROMPTS) {
|
||||
searchShared = false;
|
||||
} else if (category === SystemCategories.NO_CATEGORY) {
|
||||
filter.category = '';
|
||||
} else if (category === SystemCategories.SHARED_PROMPTS) {
|
||||
searchSharedOnly = true;
|
||||
} else if (category) {
|
||||
filter.category = category;
|
||||
}
|
||||
|
||||
return { filter, searchShared, searchSharedOnly };
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters accessible IDs based on shared/public prompts logic
|
||||
*/
|
||||
export async function filterAccessibleIdsBySharedLogic({
|
||||
accessibleIds,
|
||||
searchShared,
|
||||
searchSharedOnly,
|
||||
publicPromptGroupIds,
|
||||
}: {
|
||||
accessibleIds: Types.ObjectId[];
|
||||
searchShared: boolean;
|
||||
searchSharedOnly: boolean;
|
||||
publicPromptGroupIds?: Types.ObjectId[];
|
||||
}): Promise<Types.ObjectId[]> {
|
||||
const publicIdStrings = new Set((publicPromptGroupIds || []).map((id) => id.toString()));
|
||||
|
||||
if (!searchShared) {
|
||||
// For MY_PROMPTS - exclude public prompts to show only user's own prompts
|
||||
return accessibleIds.filter((id) => !publicIdStrings.has(id.toString()));
|
||||
}
|
||||
|
||||
if (searchSharedOnly) {
|
||||
// Handle SHARED_PROMPTS filter - only return public prompts that user has access to
|
||||
if (!publicPromptGroupIds?.length) {
|
||||
return [];
|
||||
}
|
||||
const accessibleIdStrings = new Set(accessibleIds.map((id) => id.toString()));
|
||||
return publicPromptGroupIds.filter((id) => accessibleIdStrings.has(id.toString()));
|
||||
}
|
||||
|
||||
return [...accessibleIds, ...(publicPromptGroupIds || [])];
|
||||
}
|
|
@ -4,5 +4,6 @@ export * from './error';
|
|||
export * from './google';
|
||||
export * from './mistral';
|
||||
export * from './openai';
|
||||
export * from './prompts';
|
||||
export * from './run';
|
||||
export * from './zod';
|
||||
|
|
24
packages/api/src/types/prompts.ts
Normal file
24
packages/api/src/types/prompts.ts
Normal file
|
@ -0,0 +1,24 @@
|
|||
import type { IPromptGroup as IPromptGroup } from '@librechat/data-schemas';
|
||||
import type { Types } from 'mongoose';
|
||||
|
||||
export interface PromptGroupsListResponse {
|
||||
promptGroups: IPromptGroup[];
|
||||
pageNumber: string;
|
||||
pageSize: string;
|
||||
pages: string;
|
||||
has_more: boolean;
|
||||
after: string | null;
|
||||
}
|
||||
|
||||
export interface PromptGroupsAllResponse {
|
||||
data: IPromptGroup[];
|
||||
}
|
||||
|
||||
export interface AccessiblePromptGroupsResult {
|
||||
object: 'list';
|
||||
data: IPromptGroup[];
|
||||
first_id: Types.ObjectId | null;
|
||||
last_id: Types.ObjectId | null;
|
||||
has_more: boolean;
|
||||
after: string | null;
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue