diff --git a/client/src/components/Chat/Menus/Models/ModelSpecsMenu.tsx b/client/src/components/Chat/Menus/Models/ModelSpecsMenu.tsx index 400a514dd4..31c69a2259 100644 --- a/client/src/components/Chat/Menus/Models/ModelSpecsMenu.tsx +++ b/client/src/components/Chat/Menus/Models/ModelSpecsMenu.tsx @@ -19,9 +19,22 @@ export default function ModelSpecsMenu({ modelSpecs }: { modelSpecs?: TModelSpec const localize = useLocalize(); const { data: endpointsConfig = {} as TEndpointsConfig } = useGetEndpointsQuery(); const modularChat = useRecoilValue(store.modularChat); + const user = useRecoilValue(store.user); const getDefaultConversation = useDefaultConvo(); const assistantMap = useAssistantsMapContext(); + const allowedModelSpecs = useMemo(() => { + if (!modelSpecs) {return [];} + return modelSpecs.filter(spec => { + // If no groups defined for spec, allow it. + if (!spec.groups || spec.groups.length === 0) {return true;} + // Otherwise, check if the user exists and has groups. + if (!user || !user.groups || user.groups.length === 0) {return false;} + // Check if at least one of the spec's groups is in the user's groups. + return spec.groups.some(groupId => user.groups.includes(groupId)); + }); + }, [modelSpecs, user]); + const onSelectSpec = (spec: TModelSpec) => { const { preset } = spec; preset.iconURL = getModelSpecIconURL(spec); @@ -82,21 +95,15 @@ export default function ModelSpecsMenu({ modelSpecs }: { modelSpecs?: TModelSpec }; const selected = useMemo(() => { - const spec = modelSpecs?.find((spec) => spec.name === conversation?.spec); - if (!spec) { - return undefined; - } - return spec; - }, [modelSpecs, conversation?.spec]); + const spec = allowedModelSpecs.find((spec) => spec.name === conversation?.spec); + return spec || undefined; + }, [allowedModelSpecs, conversation?.spec]); const menuRef = useRef(null); const handleKeyDown = useCallback((event: KeyboardEvent) => { const menuItems = menuRef.current?.querySelectorAll('[role="option"]'); - if (!menuItems) { - return; - } - if (!menuItems.length) { + if (!menuItems || !menuItems.length) { return; } @@ -132,7 +139,7 @@ export default function ModelSpecsMenu({ modelSpecs }: { modelSpecs?: TModelSpec endpointsConfig={endpointsConfig} /> - {modelSpecs && modelSpecs.length && ( + {allowedModelSpecs && allowedModelSpecs.length > 0 && (
; // List of group ObjectIds allowed to access this model + // badgeIcon?: string; // URL to badge icon for visual categorization + // badgeTooltip?: string; // Tooltip text for the badge }; export const tModelSpecSchema = z.object({ @@ -32,6 +35,9 @@ export const tModelSpecSchema = z.object({ showIconInHeader: z.boolean().optional(), iconURL: z.union([z.string(), eModelEndpointSchema]).optional(), authType: authTypeSchema.optional(), + groups: z.array(z.string()).optional(), + // badgeIcon: z.string().url('Must be a valid URL').optional(), + // badgeTooltip: z.string().optional(), }); export const specsConfigSchema = z.object({ diff --git a/packages/data-provider/src/types.ts b/packages/data-provider/src/types.ts index 6771901267..12b58d19fb 100644 --- a/packages/data-provider/src/types.ts +++ b/packages/data-provider/src/types.ts @@ -106,6 +106,16 @@ export type TBackupCode = { usedAt: Date | null; }; +export type TGroup = { + id: string; + name: string; + description?: string; + externalId?: string; + provider: 'local' | 'openid'; + createdAt?: string; + updatedAt?: string; +}; + export type TUser = { id: string; username: string; @@ -116,6 +126,7 @@ export type TUser = { provider: string; plugins?: string[]; backupCodes?: TBackupCode[]; + groups: string[]; createdAt: string; updatedAt: string; };