refactor: Integrate Capabilities into Agent File Uploads and Tool Handling (#5048)

* refactor: support drag/drop files for agents, handle undefined tool_resource edge cases

* refactor: consolidate endpoints config logic to dedicated getter

* refactor: Enhance agent tools loading logic to respect capabilities and filter tools accordingly

* refactor: Integrate endpoint capabilities into file upload dropdown for dynamic resource handling

* refactor: Implement capability checks for agent file upload operations

* fix: non-image tool_resource check
This commit is contained in:
Danny Avila 2024-12-19 13:04:48 -05:00 committed by GitHub
parent d68c874db4
commit 3fbbcb1cfe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 449 additions and 189 deletions

View file

@ -228,6 +228,7 @@ const loadTools = async ({
const toolContextMap = {}; const toolContextMap = {};
const remainingTools = []; const remainingTools = [];
const appTools = options.req?.app?.locals?.availableTools ?? {};
for (const tool of tools) { for (const tool of tools) {
if (tool === Tools.execute_code) { if (tool === Tools.execute_code) {
@ -259,7 +260,7 @@ const loadTools = async ({
return createFileSearchTool({ req: options.req, files, entity_id: agent?.id }); return createFileSearchTool({ req: options.req, files, entity_id: agent?.id });
}; };
continue; continue;
} else if (mcpToolPattern.test(tool)) { } else if (tool && appTools[tool] && mcpToolPattern.test(tool)) {
requestedTools[tool] = async () => requestedTools[tool] = async () =>
createMCPTool({ createMCPTool({
req: options.req, req: options.req,

View file

@ -1,69 +1,7 @@
const { CacheKeys, EModelEndpoint, orderEndpointsConfig } = require('librechat-data-provider'); const { getEndpointsConfig } = require('~/server/services/Config');
const { loadDefaultEndpointsConfig, loadConfigEndpoints } = require('~/server/services/Config');
const { getLogStores } = require('~/cache');
async function endpointController(req, res) { async function endpointController(req, res) {
const cache = getLogStores(CacheKeys.CONFIG_STORE); const endpointsConfig = await getEndpointsConfig(req);
const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
if (cachedEndpointsConfig) {
res.send(cachedEndpointsConfig);
return;
}
const defaultEndpointsConfig = await loadDefaultEndpointsConfig(req);
const customConfigEndpoints = await loadConfigEndpoints(req);
/** @type {TEndpointsConfig} */
const mergedConfig = { ...defaultEndpointsConfig, ...customConfigEndpoints };
if (mergedConfig[EModelEndpoint.assistants] && req.app.locals?.[EModelEndpoint.assistants]) {
const { disableBuilder, retrievalModels, capabilities, version, ..._rest } =
req.app.locals[EModelEndpoint.assistants];
mergedConfig[EModelEndpoint.assistants] = {
...mergedConfig[EModelEndpoint.assistants],
version,
retrievalModels,
disableBuilder,
capabilities,
};
}
if (mergedConfig[EModelEndpoint.agents] && req.app.locals?.[EModelEndpoint.agents]) {
const { disableBuilder, capabilities, ..._rest } = req.app.locals[EModelEndpoint.agents];
mergedConfig[EModelEndpoint.agents] = {
...mergedConfig[EModelEndpoint.agents],
disableBuilder,
capabilities,
};
}
if (
mergedConfig[EModelEndpoint.azureAssistants] &&
req.app.locals?.[EModelEndpoint.azureAssistants]
) {
const { disableBuilder, retrievalModels, capabilities, version, ..._rest } =
req.app.locals[EModelEndpoint.azureAssistants];
mergedConfig[EModelEndpoint.azureAssistants] = {
...mergedConfig[EModelEndpoint.azureAssistants],
version,
retrievalModels,
disableBuilder,
capabilities,
};
}
if (mergedConfig[EModelEndpoint.bedrock] && req.app.locals?.[EModelEndpoint.bedrock]) {
const { availableRegions } = req.app.locals[EModelEndpoint.bedrock];
mergedConfig[EModelEndpoint.bedrock] = {
...mergedConfig[EModelEndpoint.bedrock],
availableRegions,
};
}
const endpointsConfig = orderEndpointsConfig(mergedConfig);
await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig);
res.send(JSON.stringify(endpointsConfig)); res.send(JSON.stringify(endpointsConfig));
} }

View file

@ -1,5 +1,4 @@
const { const {
CacheKeys,
SystemRoles, SystemRoles,
EModelEndpoint, EModelEndpoint,
defaultOrderQuery, defaultOrderQuery,
@ -9,7 +8,7 @@ const {
initializeClient: initAzureClient, initializeClient: initAzureClient,
} = require('~/server/services/Endpoints/azureAssistants'); } = require('~/server/services/Endpoints/azureAssistants');
const { initializeClient } = require('~/server/services/Endpoints/assistants'); const { initializeClient } = require('~/server/services/Endpoints/assistants');
const { getLogStores } = require('~/cache'); const { getEndpointsConfig } = require('~/server/services/Config');
/** /**
* @param {Express.Request} req * @param {Express.Request} req
@ -23,11 +22,8 @@ const getCurrentVersion = async (req, endpoint) => {
version = `v${req.body.version}`; version = `v${req.body.version}`;
} }
if (!version && endpoint) { if (!version && endpoint) {
const cache = getLogStores(CacheKeys.CONFIG_STORE); const endpointsConfig = await getEndpointsConfig(req);
const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG); version = `v${endpointsConfig?.[endpoint]?.version ?? defaultAssistantsVersion[endpoint]}`;
version = `v${
cachedEndpointsConfig?.[endpoint]?.version ?? defaultAssistantsVersion[endpoint]
}`;
} }
if (!version?.startsWith('v') && version.length !== 2) { if (!version?.startsWith('v') && version.length !== 2) {
throw new Error(`[${req.baseUrl}] Invalid version: ${version}`); throw new Error(`[${req.baseUrl}] Invalid version: ${version}`);

View file

@ -0,0 +1,75 @@
const { CacheKeys, EModelEndpoint, orderEndpointsConfig } = require('librechat-data-provider');
const loadDefaultEndpointsConfig = require('./loadDefaultEConfig');
const loadConfigEndpoints = require('./loadConfigEndpoints');
const getLogStores = require('~/cache/getLogStores');
/**
*
* @param {ServerRequest} req
* @returns {Promise<TEndpointsConfig>}
*/
async function getEndpointsConfig(req) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
if (cachedEndpointsConfig) {
return cachedEndpointsConfig;
}
const defaultEndpointsConfig = await loadDefaultEndpointsConfig(req);
const customConfigEndpoints = await loadConfigEndpoints(req);
/** @type {TEndpointsConfig} */
const mergedConfig = { ...defaultEndpointsConfig, ...customConfigEndpoints };
if (mergedConfig[EModelEndpoint.assistants] && req.app.locals?.[EModelEndpoint.assistants]) {
const { disableBuilder, retrievalModels, capabilities, version, ..._rest } =
req.app.locals[EModelEndpoint.assistants];
mergedConfig[EModelEndpoint.assistants] = {
...mergedConfig[EModelEndpoint.assistants],
version,
retrievalModels,
disableBuilder,
capabilities,
};
}
if (mergedConfig[EModelEndpoint.agents] && req.app.locals?.[EModelEndpoint.agents]) {
const { disableBuilder, capabilities, ..._rest } = req.app.locals[EModelEndpoint.agents];
mergedConfig[EModelEndpoint.agents] = {
...mergedConfig[EModelEndpoint.agents],
disableBuilder,
capabilities,
};
}
if (
mergedConfig[EModelEndpoint.azureAssistants] &&
req.app.locals?.[EModelEndpoint.azureAssistants]
) {
const { disableBuilder, retrievalModels, capabilities, version, ..._rest } =
req.app.locals[EModelEndpoint.azureAssistants];
mergedConfig[EModelEndpoint.azureAssistants] = {
...mergedConfig[EModelEndpoint.azureAssistants],
version,
retrievalModels,
disableBuilder,
capabilities,
};
}
if (mergedConfig[EModelEndpoint.bedrock] && req.app.locals?.[EModelEndpoint.bedrock]) {
const { availableRegions } = req.app.locals[EModelEndpoint.bedrock];
mergedConfig[EModelEndpoint.bedrock] = {
...mergedConfig[EModelEndpoint.bedrock],
availableRegions,
};
}
const endpointsConfig = orderEndpointsConfig(mergedConfig);
await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig);
return endpointsConfig;
}
module.exports = { getEndpointsConfig };

View file

@ -3,10 +3,9 @@ const getCustomConfig = require('./getCustomConfig');
const loadCustomConfig = require('./loadCustomConfig'); const loadCustomConfig = require('./loadCustomConfig');
const loadConfigModels = require('./loadConfigModels'); const loadConfigModels = require('./loadConfigModels');
const loadDefaultModels = require('./loadDefaultModels'); const loadDefaultModels = require('./loadDefaultModels');
const getEndpointsConfig = require('./getEndpointsConfig');
const loadOverrideConfig = require('./loadOverrideConfig'); const loadOverrideConfig = require('./loadOverrideConfig');
const loadAsyncEndpoints = require('./loadAsyncEndpoints'); const loadAsyncEndpoints = require('./loadAsyncEndpoints');
const loadConfigEndpoints = require('./loadConfigEndpoints');
const loadDefaultEndpointsConfig = require('./loadDefaultEConfig');
module.exports = { module.exports = {
config, config,
@ -16,6 +15,5 @@ module.exports = {
loadOverrideConfig, loadOverrideConfig,
loadAsyncEndpoints, loadAsyncEndpoints,
...getCustomConfig, ...getCustomConfig,
loadConfigEndpoints, ...getEndpointsConfig,
loadDefaultEndpointsConfig,
}; };

View file

@ -12,6 +12,7 @@ const {
EToolResources, EToolResources,
mergeFileConfig, mergeFileConfig,
hostImageIdSuffix, hostImageIdSuffix,
AgentCapabilities,
checkOpenAIStorage, checkOpenAIStorage,
removeNullishValues, removeNullishValues,
hostImageNamePrefix, hostImageNamePrefix,
@ -27,6 +28,7 @@ const { addResourceFileId, deleteResourceFileId } = require('~/server/controller
const { addAgentResourceFile, removeAgentResourceFiles } = require('~/models/Agent'); const { addAgentResourceFile, removeAgentResourceFiles } = require('~/models/Agent');
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers'); const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
const { createFile, updateFileUsage, deleteFiles } = require('~/models/File'); const { createFile, updateFileUsage, deleteFiles } = require('~/models/File');
const { getEndpointsConfig } = require('~/server/services/Config');
const { loadAuthValues } = require('~/app/clients/tools/util'); const { loadAuthValues } = require('~/app/clients/tools/util');
const { LB_QueueAsyncCall } = require('~/server/utils/queue'); const { LB_QueueAsyncCall } = require('~/server/utils/queue');
const { getStrategyFunctions } = require('./strategies'); const { getStrategyFunctions } = require('./strategies');
@ -451,6 +453,17 @@ const processFileUpload = async ({ req, res, metadata }) => {
res.status(200).json({ message: 'File uploaded and processed successfully', ...result }); res.status(200).json({ message: 'File uploaded and processed successfully', ...result });
}; };
/**
* @param {ServerRequest} req
* @param {AgentCapabilities} capability
* @returns {Promise<boolean>}
*/
const checkCapability = async (req, capability) => {
const endpointsConfig = await getEndpointsConfig(req);
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
return capabilities.includes(capability);
};
/** /**
* Applies the current strategy for file uploads. * Applies the current strategy for file uploads.
* Saves file metadata to the database with an expiry TTL. * Saves file metadata to the database with an expiry TTL.
@ -478,9 +491,20 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
throw new Error('No agent ID provided for agent file upload'); throw new Error('No agent ID provided for agent file upload');
} }
const isImage = file.mimetype.startsWith('image');
if (!isImage && !tool_resource) {
/** Note: this needs to be removed when we can support files to providers */
throw new Error('No tool resource provided for non-image agent file upload');
}
let fileInfoMetadata; let fileInfoMetadata;
const entity_id = messageAttachment === true ? undefined : agent_id; const entity_id = messageAttachment === true ? undefined : agent_id;
if (tool_resource === EToolResources.execute_code) { if (tool_resource === EToolResources.execute_code) {
const isCodeEnabled = await checkCapability(req, AgentCapabilities.execute_code);
if (!isCodeEnabled) {
throw new Error('Code execution is not enabled for Agents');
}
const { handleFileUpload: uploadCodeEnvFile } = getStrategyFunctions(FileSources.execute_code); const { handleFileUpload: uploadCodeEnvFile } = getStrategyFunctions(FileSources.execute_code);
const result = await loadAuthValues({ userId: req.user.id, authFields: [EnvVar.CODE_API_KEY] }); const result = await loadAuthValues({ userId: req.user.id, authFields: [EnvVar.CODE_API_KEY] });
const stream = fs.createReadStream(file.path); const stream = fs.createReadStream(file.path);
@ -492,6 +516,11 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
entity_id, entity_id,
}); });
fileInfoMetadata = { fileIdentifier }; fileInfoMetadata = { fileIdentifier };
} else if (tool_resource === EToolResources.file_search) {
const isFileSearchEnabled = await checkCapability(req, AgentCapabilities.file_search);
if (!isFileSearchEnabled) {
throw new Error('File search is not enabled for Agents');
}
} }
const source = const source =
@ -527,7 +556,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
}); });
} }
if (file.mimetype.startsWith('image')) { if (isImage) {
const result = await processImageFile({ const result = await processImageFile({
req, req,
file, file,

View file

@ -8,13 +8,16 @@ const {
ErrorTypes, ErrorTypes,
ContentTypes, ContentTypes,
imageGenTools, imageGenTools,
EModelEndpoint,
actionDelimiter, actionDelimiter,
ImageVisionTool, ImageVisionTool,
openapiToFunction, openapiToFunction,
AgentCapabilities,
validateAndParseOpenAPISpec, validateAndParseOpenAPISpec,
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process'); const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
const { loadActionSets, createActionTool, domainParser } = require('./ActionService'); const { loadActionSets, createActionTool, domainParser } = require('./ActionService');
const { getEndpointsConfig } = require('~/server/services/Config');
const { recordUsage } = require('~/server/services/Threads'); const { recordUsage } = require('~/server/services/Threads');
const { loadTools } = require('~/app/clients/tools/util'); const { loadTools } = require('~/app/clients/tools/util');
const { redactMessage } = require('~/config/parsers'); const { redactMessage } = require('~/config/parsers');
@ -383,11 +386,37 @@ async function loadAgentTools({ req, agent, tool_resources, openAIApiKey }) {
if (!agent.tools || agent.tools.length === 0) { if (!agent.tools || agent.tools.length === 0) {
return {}; return {};
} }
const endpointsConfig = await getEndpointsConfig(req);
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
const areToolsEnabled = capabilities.includes(AgentCapabilities.tools);
if (!areToolsEnabled) {
logger.debug('Tools are not enabled for this agent.');
return {};
}
const isFileSearchEnabled = capabilities.includes(AgentCapabilities.file_search);
const isCodeEnabled = capabilities.includes(AgentCapabilities.execute_code);
const areActionsEnabled = capabilities.includes(AgentCapabilities.actions);
const _agentTools = agent.tools?.filter((tool) => {
if (tool === Tools.file_search && !isFileSearchEnabled) {
return false;
} else if (tool === Tools.execute_code && !isCodeEnabled) {
return false;
}
return true;
});
if (!_agentTools || _agentTools.length === 0) {
return {};
}
const { loadedTools, toolContextMap } = await loadTools({ const { loadedTools, toolContextMap } = await loadTools({
agent, agent,
functions: true, functions: true,
user: req.user.id, user: req.user.id,
tools: agent.tools, tools: _agentTools,
options: { options: {
req, req,
openAIApiKey, openAIApiKey,
@ -434,62 +463,74 @@ async function loadAgentTools({ req, agent, tool_resources, openAIApiKey }) {
return map; return map;
}, {}); }, {});
if (!areActionsEnabled) {
return {
tools: agentTools,
toolContextMap,
};
}
let actionSets = []; let actionSets = [];
const ActionToolMap = {}; const ActionToolMap = {};
for (const toolName of agent.tools) { for (const toolName of _agentTools) {
if (!ToolMap[toolName]) { if (ToolMap[toolName]) {
if (!actionSets.length) { continue;
actionSets = (await loadActionSets({ agent_id: agent.id })) ?? []; }
}
let actionSet = null; if (!actionSets.length) {
let currentDomain = ''; actionSets = (await loadActionSets({ agent_id: agent.id })) ?? [];
for (let action of actionSets) { }
const domain = await domainParser(req, action.metadata.domain, true);
if (toolName.includes(domain)) {
currentDomain = domain;
actionSet = action;
break;
}
}
if (actionSet) { let actionSet = null;
const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec); let currentDomain = '';
if (validationResult.spec) { for (let action of actionSets) {
const { requestBuilders, functionSignatures, zodSchemas } = openapiToFunction( const domain = await domainParser(req, action.metadata.domain, true);
validationResult.spec, if (toolName.includes(domain)) {
true, currentDomain = domain;
actionSet = action;
break;
}
}
if (!actionSet) {
continue;
}
const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec);
if (validationResult.spec) {
const { requestBuilders, functionSignatures, zodSchemas } = openapiToFunction(
validationResult.spec,
true,
);
const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, '');
const functionSig = functionSignatures.find((sig) => sig.name === functionName);
const requestBuilder = requestBuilders[functionName];
const zodSchema = zodSchemas[functionName];
if (requestBuilder) {
const tool = await createActionTool({
action: actionSet,
requestBuilder,
zodSchema,
name: toolName,
description: functionSig.description,
});
if (!tool) {
logger.warn(
`Invalid action: user: ${req.user.id} | agent_id: ${agent.id} | toolName: ${toolName}`,
); );
const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, ''); throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
const functionSig = functionSignatures.find((sig) => sig.name === functionName);
const requestBuilder = requestBuilders[functionName];
const zodSchema = zodSchemas[functionName];
if (requestBuilder) {
const tool = await createActionTool({
action: actionSet,
requestBuilder,
zodSchema,
name: toolName,
description: functionSig.description,
});
if (!tool) {
logger.warn(
`Invalid action: user: ${req.user.id} | agent_id: ${agent.id} | toolName: ${toolName}`,
);
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
}
agentTools.push(tool);
ActionToolMap[toolName] = tool;
}
} }
agentTools.push(tool);
ActionToolMap[toolName] = tool;
} }
} }
} }
if (agent.tools.length > 0 && agentTools.length === 0) { if (_agentTools.length > 0 && agentTools.length === 0) {
throw new Error('No tools found for the specified tool calls.'); logger.warn(`No tools found for the specified tool calls: ${_agentTools.join(', ')}`);
return {};
} }
return { return {

View file

@ -464,6 +464,7 @@ export interface ExtendedFile {
source?: FileSources; source?: FileSources;
attached?: boolean; attached?: boolean;
embedded?: boolean; embedded?: boolean;
tool_resource?: string;
} }
export type ContextType = { navVisible: boolean; setNavVisible: (visible: boolean) => void }; export type ContextType = { navVisible: boolean; setNavVisible: (visible: boolean) => void };

View file

@ -1,7 +1,8 @@
import * as Ariakit from '@ariakit/react'; import * as Ariakit from '@ariakit/react';
import React, { useRef, useState } from 'react'; import React, { useRef, useState, useMemo } from 'react';
import { FileSearch, ImageUpIcon, TerminalSquareIcon } from 'lucide-react'; import { FileSearch, ImageUpIcon, TerminalSquareIcon } from 'lucide-react';
import { EToolResources } from 'librechat-data-provider'; import { EToolResources, EModelEndpoint } from 'librechat-data-provider';
import { useGetEndpointsQuery } from 'librechat-data-provider/react-query';
import { FileUpload, TooltipAnchor, DropdownPopup } from '~/components/ui'; import { FileUpload, TooltipAnchor, DropdownPopup } from '~/components/ui';
import { AttachmentIcon } from '~/components/svg'; import { AttachmentIcon } from '~/components/svg';
import { useLocalize } from '~/hooks'; import { useLocalize } from '~/hooks';
@ -19,6 +20,12 @@ const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: Atta
const isUploadDisabled = disabled ?? false; const isUploadDisabled = disabled ?? false;
const inputRef = useRef<HTMLInputElement>(null); const inputRef = useRef<HTMLInputElement>(null);
const [isPopoverActive, setIsPopoverActive] = useState(false); const [isPopoverActive, setIsPopoverActive] = useState(false);
const { data: endpointsConfig } = useGetEndpointsQuery();
const capabilities = useMemo(
() => endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [],
[endpointsConfig],
);
const handleUploadClick = (isImage?: boolean) => { const handleUploadClick = (isImage?: boolean) => {
if (!inputRef.current) { if (!inputRef.current) {
@ -30,32 +37,42 @@ const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: Atta
inputRef.current.accept = ''; inputRef.current.accept = '';
}; };
const dropdownItems = [ const dropdownItems = useMemo(() => {
{ const items = [
label: localize('com_ui_upload_image_input'), {
onClick: () => { label: localize('com_ui_upload_image_input'),
setToolResource?.(undefined); onClick: () => {
handleUploadClick(true); setToolResource?.(undefined);
handleUploadClick(true);
},
icon: <ImageUpIcon className="icon-md" />,
}, },
icon: <ImageUpIcon className="icon-md" />, ];
},
{ if (capabilities.includes(EToolResources.file_search)) {
label: localize('com_ui_upload_file_search'), items.push({
onClick: () => { label: localize('com_ui_upload_file_search'),
setToolResource?.(EToolResources.file_search); onClick: () => {
handleUploadClick(); setToolResource?.(EToolResources.file_search);
}, handleUploadClick();
icon: <FileSearch className="icon-md" />, },
}, icon: <FileSearch className="icon-md" />,
{ });
label: localize('com_ui_upload_code_files'), }
onClick: () => {
setToolResource?.(EToolResources.execute_code); if (capabilities.includes(EToolResources.execute_code)) {
handleUploadClick(); items.push({
}, label: localize('com_ui_upload_code_files'),
icon: <TerminalSquareIcon className="icon-md" />, onClick: () => {
}, setToolResource?.(EToolResources.execute_code);
]; handleUploadClick();
},
icon: <TerminalSquareIcon className="icon-md" />,
});
}
return items;
}, [capabilities, localize, setToolResource]);
const menuTrigger = ( const menuTrigger = (
<TooltipAnchor <TooltipAnchor

View file

@ -0,0 +1,90 @@
import React, { useMemo } from 'react';
import { EModelEndpoint, EToolResources } from 'librechat-data-provider';
import { useGetEndpointsQuery } from 'librechat-data-provider/react-query';
import { FileSearch, ImageUpIcon, TerminalSquareIcon } from 'lucide-react';
import OGDialogTemplate from '~/components/ui/OGDialogTemplate';
import useLocalize from '~/hooks/useLocalize';
import { OGDialog } from '~/components/ui';
interface DragDropModalProps {
onOptionSelect: (option: string | undefined) => void;
files: File[];
isVisible: boolean;
setShowModal: (showModal: boolean) => void;
}
interface FileOption {
label: string;
value?: EToolResources;
icon: React.JSX.Element;
condition?: boolean;
}
const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragDropModalProps) => {
const localize = useLocalize();
const { data: endpointsConfig } = useGetEndpointsQuery();
const capabilities = useMemo(
() => endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [],
[endpointsConfig],
);
const options = useMemo(() => {
const _options: FileOption[] = [
{
label: localize('com_ui_upload_image_input'),
value: undefined,
icon: <ImageUpIcon className="icon-md" />,
condition: files.every((file) => file.type.startsWith('image/')),
},
];
for (const capability of capabilities) {
if (capability === EToolResources.file_search) {
_options.push({
label: localize('com_ui_upload_file_search'),
value: EToolResources.file_search,
icon: <FileSearch className="icon-md" />,
});
} else if (capability === EToolResources.execute_code) {
_options.push({
label: localize('com_ui_upload_code_files'),
value: EToolResources.execute_code,
icon: <TerminalSquareIcon className="icon-md" />,
});
}
}
return _options;
}, [capabilities, files, localize]);
if (!isVisible) {
return null;
}
return (
<OGDialog open={isVisible} onOpenChange={setShowModal}>
<OGDialogTemplate
title={localize('com_ui_upload_type')}
className="w-11/12 sm:w-[440px] md:w-[400px] lg:w-[360px]"
main={
<div className="flex flex-col gap-2">
{options.map(
(option, index) =>
option.condition !== false && (
<button
key={index}
onClick={() => onOptionSelect(option.value)}
className="flex items-center gap-2 rounded-lg p-2 hover:bg-surface-active-alt"
>
{option.icon}
<span>{option.label}</span>
</button>
),
)}
</div>
}
/>
</OGDialog>
);
};
export default DragDropModal;

View file

@ -0,0 +1,29 @@
import { useDragHelpers } from '~/hooks';
import DragDropOverlay from '~/components/Chat/Input/Files/DragDropOverlay';
import DragDropModal from '~/components/Chat/Input/Files/DragDropModal';
import { cn } from '~/utils';
interface DragDropWrapperProps {
children: React.ReactNode;
className?: string;
}
export default function DragDropWrapper({ children, className }: DragDropWrapperProps) {
const { isOver, canDrop, drop, showModal, setShowModal, draggedFiles, handleOptionSelect } =
useDragHelpers();
const isActive = canDrop && isOver;
return (
<div ref={drop} className={cn('relative flex h-full w-full', className)}>
{children}
{isActive && <DragDropOverlay />}
<DragDropModal
files={draggedFiles}
isVisible={showModal}
setShowModal={setShowModal}
onOptionSelect={handleOptionSelect}
/>
</div>
);
}

View file

@ -3,11 +3,11 @@ import { useEffect, useMemo } from 'react';
import { useGetStartupConfig } from 'librechat-data-provider/react-query'; import { useGetStartupConfig } from 'librechat-data-provider/react-query';
import { FileSources, LocalStorageKeys, getConfigDefaults } from 'librechat-data-provider'; import { FileSources, LocalStorageKeys, getConfigDefaults } from 'librechat-data-provider';
import type { ExtendedFile } from '~/common'; import type { ExtendedFile } from '~/common';
import { useDragHelpers, useSetFilesToDelete } from '~/hooks'; import DragDropWrapper from '~/components/Chat/Input/Files/DragDropWrapper';
import DragDropOverlay from './Input/Files/DragDropOverlay';
import { useDeleteFilesMutation } from '~/data-provider'; import { useDeleteFilesMutation } from '~/data-provider';
import Artifacts from '~/components/Artifacts/Artifacts'; import Artifacts from '~/components/Artifacts/Artifacts';
import { SidePanel } from '~/components/SidePanel'; import { SidePanel } from '~/components/SidePanel';
import { useSetFilesToDelete } from '~/hooks';
import store from '~/store'; import store from '~/store';
const defaultInterface = getConfigDefaults().interface; const defaultInterface = getConfigDefaults().interface;
@ -33,7 +33,6 @@ export default function Presentation({
); );
const setFilesToDelete = useSetFilesToDelete(); const setFilesToDelete = useSetFilesToDelete();
const { isOver, canDrop, drop } = useDragHelpers();
const { mutateAsync } = useDeleteFilesMutation({ const { mutateAsync } = useDeleteFilesMutation({
onSuccess: () => { onSuccess: () => {
@ -66,8 +65,6 @@ export default function Presentation({
mutateAsync({ files }); mutateAsync({ files });
}, [mutateAsync]); }, [mutateAsync]);
const isActive = canDrop && isOver;
const defaultLayout = useMemo(() => { const defaultLayout = useMemo(() => {
const resizableLayout = localStorage.getItem('react-resizable-panels:layout'); const resizableLayout = localStorage.getItem('react-resizable-panels:layout');
return typeof resizableLayout === 'string' ? JSON.parse(resizableLayout) : undefined; return typeof resizableLayout === 'string' ? JSON.parse(resizableLayout) : undefined;
@ -79,20 +76,16 @@ export default function Presentation({
const fullCollapse = useMemo(() => localStorage.getItem('fullPanelCollapse') === 'true', []); const fullCollapse = useMemo(() => localStorage.getItem('fullPanelCollapse') === 'true', []);
const layout = () => ( const layout = () => (
<div className="transition-width relative flex h-full w-full flex-1 flex-col items-stretch overflow-hidden bg-white pt-0 dark:bg-gray-800"> <div className="transition-width relative flex h-full w-full flex-1 flex-col items-stretch overflow-hidden bg-presentation pt-0">
<div className="flex h-full flex-col" role="presentation"> <div className="flex h-full flex-col" role="presentation">
{children} {children}
{isActive && <DragDropOverlay />}
</div> </div>
</div> </div>
); );
if (useSidePanel && !hideSidePanel && interfaceConfig.sidePanel === true) { if (useSidePanel && !hideSidePanel && interfaceConfig.sidePanel === true) {
return ( return (
<div <DragDropWrapper className="relative flex w-full grow overflow-hidden bg-presentation">
ref={drop}
className="relative flex w-full grow overflow-hidden bg-white dark:bg-gray-800"
>
<SidePanel <SidePanel
defaultLayout={defaultLayout} defaultLayout={defaultLayout}
defaultCollapsed={defaultCollapsed} defaultCollapsed={defaultCollapsed}
@ -107,17 +100,16 @@ export default function Presentation({
> >
<main className="flex h-full flex-col" role="main"> <main className="flex h-full flex-col" role="main">
{children} {children}
{isActive && <DragDropOverlay />}
</main> </main>
</SidePanel> </SidePanel>
</div> </DragDropWrapper>
); );
} }
return ( return (
<div ref={drop} className="relative flex w-full grow overflow-hidden bg-white dark:bg-gray-800"> <DragDropWrapper className="relative flex w-full grow overflow-hidden bg-presentation">
{layout()} {layout()}
{panel != null && panel} {panel != null && panel}
</div> </DragDropWrapper>
); );
} }

View file

@ -1,42 +1,76 @@
import { useState, useMemo } from 'react';
import { useDrop } from 'react-dnd'; import { useDrop } from 'react-dnd';
import { useRecoilValue } from 'recoil';
import { NativeTypes } from 'react-dnd-html5-backend'; import { NativeTypes } from 'react-dnd-html5-backend';
import { useQueryClient } from '@tanstack/react-query';
import {
isAgentsEndpoint,
EModelEndpoint,
AgentCapabilities,
QueryKeys,
} from 'librechat-data-provider';
import type * as t from 'librechat-data-provider';
import type { DropTargetMonitor } from 'react-dnd'; import type { DropTargetMonitor } from 'react-dnd';
import useFileHandling from './useFileHandling'; import useFileHandling from './useFileHandling';
import store from '~/store';
export default function useDragHelpers() { export default function useDragHelpers() {
const { files, handleFiles } = useFileHandling(); const queryClient = useQueryClient();
const { handleFiles } = useFileHandling();
const [showModal, setShowModal] = useState(false);
const [draggedFiles, setDraggedFiles] = useState<File[]>([]);
const conversation = useRecoilValue(store.conversationByIndex(0)) || undefined;
const handleOptionSelect = (toolResource: string | undefined) => {
handleFiles(draggedFiles, toolResource);
setShowModal(false);
setDraggedFiles([]);
};
const isAgents = useMemo(
() => isAgentsEndpoint(conversation?.endpoint),
[conversation?.endpoint],
);
const [{ canDrop, isOver }, drop] = useDrop( const [{ canDrop, isOver }, drop] = useDrop(
() => ({ () => ({
accept: [NativeTypes.FILE], accept: [NativeTypes.FILE],
drop(item: { files: File[] }) { drop(item: { files: File[] }) {
console.log('drop', item.files); console.log('drop', item.files);
handleFiles(item.files); if (!isAgents) {
}, handleFiles(item.files);
canDrop() { return;
// console.log('canDrop', item.files, item.items); }
return true;
},
// hover() {
// // console.log('hover', item.files, item.items);
// },
collect: (monitor: DropTargetMonitor) => {
// const item = monitor.getItem() as File[];
// if (item) {
// console.log('collect', item.files, item.items);
// }
return { const endpointsConfig = queryClient.getQueryData<t.TEndpointsConfig>([QueryKeys.endpoints]);
isOver: monitor.isOver(), const agentsConfig = endpointsConfig?.[EModelEndpoint.agents];
canDrop: monitor.canDrop(), const codeEnabled =
}; agentsConfig?.capabilities?.includes(AgentCapabilities.execute_code) === true;
const fileSearchEnabled =
agentsConfig?.capabilities?.includes(AgentCapabilities.file_search) === true;
if (!codeEnabled && !fileSearchEnabled) {
handleFiles(item.files);
return;
}
setDraggedFiles(item.files);
setShowModal(true);
}, },
canDrop: () => true,
collect: (monitor: DropTargetMonitor) => ({
isOver: monitor.isOver(),
canDrop: monitor.canDrop(),
}),
}), }),
[files], [],
); );
return { return {
canDrop, canDrop,
isOver, isOver,
drop, drop,
showModal,
setShowModal,
draggedFiles,
handleOptionSelect,
}; };
} }

View file

@ -187,8 +187,9 @@ const useFileHandling = (params?: UseFileHandling) => {
if (!agent_id) { if (!agent_id) {
formData.append('message_file', 'true'); formData.append('message_file', 'true');
} }
if (toolResource != null) { const tool_resource = extendedFile.tool_resource ?? toolResource;
formData.append('tool_resource', toolResource); if (tool_resource != null) {
formData.append('tool_resource', tool_resource);
} }
if (conversation?.agent_id != null && formData.get('agent_id') == null) { if (conversation?.agent_id != null && formData.get('agent_id') == null) {
formData.append('agent_id', conversation.agent_id); formData.append('agent_id', conversation.agent_id);
@ -327,7 +328,7 @@ const useFileHandling = (params?: UseFileHandling) => {
img.src = preview; img.src = preview;
}; };
const handleFiles = async (_files: FileList | File[]) => { const handleFiles = async (_files: FileList | File[], _toolResource?: string) => {
abortControllerRef.current = new AbortController(); abortControllerRef.current = new AbortController();
const fileList = Array.from(_files); const fileList = Array.from(_files);
/* Validate files */ /* Validate files */
@ -358,9 +359,22 @@ const useFileHandling = (params?: UseFileHandling) => {
size: originalFile.size, size: originalFile.size,
}; };
if (_toolResource != null && _toolResource !== '') {
extendedFile.tool_resource = _toolResource;
}
const isImage = originalFile.type.split('/')[0] === 'image';
const tool_resource =
extendedFile.tool_resource ?? params?.additionalMetadata?.tool_resource ?? toolResource;
if (isAgentsEndpoint(endpoint) && !isImage && tool_resource == null) {
/** Note: this needs to be removed when we can support files to providers */
setError('com_error_files_unsupported_capability');
continue;
}
addFile(extendedFile); addFile(extendedFile);
if (originalFile.type.split('/')[0] === 'image') { if (isImage) {
loadImage(extendedFile, preview); loadImage(extendedFile, preview);
continue; continue;
} }

View file

@ -42,6 +42,7 @@ export default {
com_error_files_dupe: 'Duplicate file detected.', com_error_files_dupe: 'Duplicate file detected.',
com_error_files_validation: 'An error occurred while validating the file.', com_error_files_validation: 'An error occurred while validating the file.',
com_error_files_process: 'An error occurred while processing the file.', com_error_files_process: 'An error occurred while processing the file.',
com_error_files_unsupported_capability: 'No capabilities enabled that support this file type.',
com_error_files_upload: 'An error occurred while uploading the file.', com_error_files_upload: 'An error occurred while uploading the file.',
com_error_files_upload_canceled: com_error_files_upload_canceled:
'The file upload request was canceled. Note: the file upload may still be processing and will need to be manually deleted.', 'The file upload request was canceled. Note: the file upload may still be processing and will need to be manually deleted.',
@ -203,6 +204,7 @@ export default {
com_ui_next: 'Next', com_ui_next: 'Next',
com_ui_stop: 'Stop', com_ui_stop: 'Stop',
com_ui_upload_files: 'Upload files', com_ui_upload_files: 'Upload files',
com_ui_upload_type: 'Select Upload Type',
com_ui_upload_image_input: 'Upload Image', com_ui_upload_image_input: 'Upload Image',
com_ui_upload_file_search: 'Upload for File Search', com_ui_upload_file_search: 'Upload for File Search',
com_ui_upload_code_files: 'Upload for Code Interpreter', com_ui_upload_code_files: 'Upload for Code Interpreter',

View file

@ -39,6 +39,7 @@
--font-size-xl: 1.25rem; --font-size-xl: 1.25rem;
} }
html { html {
--presentation: var(--white);
--text-primary: var(--gray-800); --text-primary: var(--gray-800);
--text-secondary: var(--gray-600); --text-secondary: var(--gray-600);
--text-secondary-alt: var(--gray-500); --text-secondary-alt: var(--gray-500);
@ -92,6 +93,7 @@ html {
--switch-unchecked: 0 0% 58%; --switch-unchecked: 0 0% 58%;
} }
.dark { .dark {
--presentation: var(--gray-800);
--text-primary: var(--gray-100); --text-primary: var(--gray-100);
--text-secondary: var(--gray-300); --text-secondary: var(--gray-300);
--text-secondary-alt: var(--gray-400); --text-secondary-alt: var(--gray-400);

View file

@ -61,6 +61,7 @@ module.exports = {
900: '#031f29', 900: '#031f29',
}, },
'brand-purple': '#ab68ff', 'brand-purple': '#ab68ff',
'presentation': 'var(--presentation)',
'text-primary': 'var(--text-primary)', 'text-primary': 'var(--text-primary)',
'text-secondary': 'var(--text-secondary)', 'text-secondary': 'var(--text-secondary)',
'text-secondary-alt': 'var(--text-secondary-alt)', 'text-secondary-alt': 'var(--text-secondary-alt)',