mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-31 23:58:50 +01:00
Merge branch 'main' into feat/multi-lang-Terms-of-service
This commit is contained in:
commit
7c0324695a
258 changed files with 8260 additions and 3717 deletions
|
|
@ -142,7 +142,7 @@ GOOGLE_KEY=user_provided
|
|||
# GOOGLE_AUTH_HEADER=true
|
||||
|
||||
# Gemini API (AI Studio)
|
||||
# GOOGLE_MODELS=gemini-2.0-flash-exp,gemini-2.0-flash-thinking-exp-1219,gemini-exp-1121,gemini-exp-1114,gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision
|
||||
# GOOGLE_MODELS=gemini-2.5-pro-exp-03-25,gemini-2.0-flash-exp,gemini-2.0-flash-thinking-exp-1219,gemini-exp-1121,gemini-exp-1114,gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision
|
||||
|
||||
# Vertex AI
|
||||
# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0514,gemini-1.0-pro-vision-001,gemini-1.0-pro-002,gemini-1.0-pro-001,gemini-pro-vision,gemini-1.0-pro
|
||||
|
|
|
|||
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -37,6 +37,10 @@ client/public/main.js
|
|||
client/public/main.js.map
|
||||
client/public/main.js.LICENSE.txt
|
||||
|
||||
# Azure Blob Storage Emulator (Azurite)
|
||||
__azurite**
|
||||
__blobstorage__/**/*
|
||||
|
||||
# Dependency directorys
|
||||
# Deployed apps should consider commenting these lines out:
|
||||
# see https://npmjs.org/doc/faq.html#Should-I-check-my-node_modules-folder-into-git
|
||||
|
|
|
|||
|
|
@ -879,13 +879,14 @@ class BaseClient {
|
|||
: await getConvo(this.options.req?.user?.id, message.conversationId);
|
||||
|
||||
const unsetFields = {};
|
||||
const exceptions = new Set(['spec', 'iconURL']);
|
||||
if (existingConvo != null) {
|
||||
this.fetchedConvo = true;
|
||||
for (const key in existingConvo) {
|
||||
if (!key) {
|
||||
continue;
|
||||
}
|
||||
if (excludedKeys.has(key)) {
|
||||
if (excludedKeys.has(key) && !exceptions.has(key)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -198,7 +198,11 @@ class GoogleClient extends BaseClient {
|
|||
*/
|
||||
checkVisionRequest(attachments) {
|
||||
/* Validation vision request */
|
||||
this.defaultVisionModel = this.options.visionModel ?? 'gemini-pro-vision';
|
||||
this.defaultVisionModel =
|
||||
this.options.visionModel ??
|
||||
(!EXCLUDED_GENAI_MODELS.test(this.modelOptions.model)
|
||||
? this.modelOptions.model
|
||||
: 'gemini-pro-vision');
|
||||
const availableModels = this.options.modelsConfig?.[EModelEndpoint.google];
|
||||
this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
|
||||
|
||||
|
|
|
|||
|
|
@ -226,10 +226,6 @@ class OpenAIClient extends BaseClient {
|
|||
logger.debug('Using Azure endpoint');
|
||||
}
|
||||
|
||||
if (this.useOpenRouter) {
|
||||
this.completionsUrl = 'https://openrouter.ai/api/v1/chat/completions';
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
|
|
|
|||
5
api/cache/getLogStores.js
vendored
5
api/cache/getLogStores.js
vendored
|
|
@ -49,6 +49,10 @@ const genTitle = isRedisEnabled
|
|||
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });
|
||||
|
||||
const s3ExpiryInterval = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.S3_EXPIRY_INTERVAL, ttl: Time.THIRTY_MINUTES });
|
||||
|
||||
const modelQueries = isEnabled(process.env.USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.MODEL_QUERIES });
|
||||
|
|
@ -89,6 +93,7 @@ const namespaces = {
|
|||
[CacheKeys.ABORT_KEYS]: abortKeys,
|
||||
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
|
||||
[CacheKeys.GEN_TITLE]: genTitle,
|
||||
[CacheKeys.S3_EXPIRY_INTERVAL]: s3ExpiryInterval,
|
||||
[CacheKeys.MODEL_QUERIES]: modelQueries,
|
||||
[CacheKeys.AUDIO_RUNS]: audioRuns,
|
||||
[CacheKeys.MESSAGES]: messages,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,11 @@ require('winston-daily-rotate-file');
|
|||
|
||||
const logDir = path.join(__dirname, '..', 'logs');
|
||||
|
||||
const { NODE_ENV } = process.env;
|
||||
const { NODE_ENV, DEBUG_LOGGING = false } = process.env;
|
||||
|
||||
const useDebugLogging =
|
||||
(typeof DEBUG_LOGGING === 'string' && DEBUG_LOGGING?.toLowerCase() === 'true') ||
|
||||
DEBUG_LOGGING === true;
|
||||
|
||||
const levels = {
|
||||
error: 0,
|
||||
|
|
@ -36,9 +40,10 @@ const fileFormat = winston.format.combine(
|
|||
winston.format.splat(),
|
||||
);
|
||||
|
||||
const logLevel = useDebugLogging ? 'debug' : 'error';
|
||||
const transports = [
|
||||
new winston.transports.DailyRotateFile({
|
||||
level: 'debug',
|
||||
level: logLevel,
|
||||
filename: `${logDir}/meiliSync-%DATE%.log`,
|
||||
datePattern: 'YYYY-MM-DD',
|
||||
zippedArchive: true,
|
||||
|
|
@ -48,14 +53,6 @@ const transports = [
|
|||
}),
|
||||
];
|
||||
|
||||
// if (NODE_ENV !== 'production') {
|
||||
// transports.push(
|
||||
// new winston.transports.Console({
|
||||
// format: winston.format.combine(winston.format.colorize(), winston.format.simple()),
|
||||
// }),
|
||||
// );
|
||||
// }
|
||||
|
||||
const consoleFormat = winston.format.combine(
|
||||
winston.format.colorize({ all: true }),
|
||||
winston.format.timestamp({ format: 'YYYY-MM-DD HH:mm:ss' }),
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ const { redactFormat, redactMessage, debugTraverse, jsonTruncateFormat } = requi
|
|||
|
||||
const logDir = path.join(__dirname, '..', 'logs');
|
||||
|
||||
const { NODE_ENV, DEBUG_LOGGING = true, DEBUG_CONSOLE = false, CONSOLE_JSON = false } = process.env;
|
||||
const { NODE_ENV, DEBUG_LOGGING = true, CONSOLE_JSON = false, DEBUG_CONSOLE = false } = process.env;
|
||||
|
||||
const useConsoleJson =
|
||||
(typeof CONSOLE_JSON === 'string' && CONSOLE_JSON?.toLowerCase() === 'true') ||
|
||||
|
|
@ -15,6 +15,10 @@ const useDebugConsole =
|
|||
(typeof DEBUG_CONSOLE === 'string' && DEBUG_CONSOLE?.toLowerCase() === 'true') ||
|
||||
DEBUG_CONSOLE === true;
|
||||
|
||||
const useDebugLogging =
|
||||
(typeof DEBUG_LOGGING === 'string' && DEBUG_LOGGING?.toLowerCase() === 'true') ||
|
||||
DEBUG_LOGGING === true;
|
||||
|
||||
const levels = {
|
||||
error: 0,
|
||||
warn: 1,
|
||||
|
|
@ -57,28 +61,9 @@ const transports = [
|
|||
maxFiles: '14d',
|
||||
format: fileFormat,
|
||||
}),
|
||||
// new winston.transports.DailyRotateFile({
|
||||
// level: 'info',
|
||||
// filename: `${logDir}/info-%DATE%.log`,
|
||||
// datePattern: 'YYYY-MM-DD',
|
||||
// zippedArchive: true,
|
||||
// maxSize: '20m',
|
||||
// maxFiles: '14d',
|
||||
// }),
|
||||
];
|
||||
|
||||
// if (NODE_ENV !== 'production') {
|
||||
// transports.push(
|
||||
// new winston.transports.Console({
|
||||
// format: winston.format.combine(winston.format.colorize(), winston.format.simple()),
|
||||
// }),
|
||||
// );
|
||||
// }
|
||||
|
||||
if (
|
||||
(typeof DEBUG_LOGGING === 'string' && DEBUG_LOGGING?.toLowerCase() === 'true') ||
|
||||
DEBUG_LOGGING === true
|
||||
) {
|
||||
if (useDebugLogging) {
|
||||
transports.push(
|
||||
new winston.transports.DailyRotateFile({
|
||||
level: 'debug',
|
||||
|
|
@ -107,10 +92,16 @@ const consoleFormat = winston.format.combine(
|
|||
}),
|
||||
);
|
||||
|
||||
// Determine console log level
|
||||
let consoleLogLevel = 'info';
|
||||
if (useDebugConsole) {
|
||||
consoleLogLevel = 'debug';
|
||||
}
|
||||
|
||||
if (useDebugConsole) {
|
||||
transports.push(
|
||||
new winston.transports.Console({
|
||||
level: 'debug',
|
||||
level: consoleLogLevel,
|
||||
format: useConsoleJson
|
||||
? winston.format.combine(fileFormat, jsonTruncateFormat(), winston.format.json())
|
||||
: winston.format.combine(fileFormat, debugTraverse),
|
||||
|
|
@ -119,14 +110,14 @@ if (useDebugConsole) {
|
|||
} else if (useConsoleJson) {
|
||||
transports.push(
|
||||
new winston.transports.Console({
|
||||
level: 'info',
|
||||
level: consoleLogLevel,
|
||||
format: winston.format.combine(fileFormat, jsonTruncateFormat(), winston.format.json()),
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
transports.push(
|
||||
new winston.transports.Console({
|
||||
level: 'info',
|
||||
level: consoleLogLevel,
|
||||
format: consoleFormat,
|
||||
}),
|
||||
);
|
||||
|
|
|
|||
|
|
@ -46,6 +46,10 @@ const loadAgent = async ({ req, agent_id }) => {
|
|||
id: agent_id,
|
||||
});
|
||||
|
||||
if (!agent) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (agent.author.toString() === req.user.id) {
|
||||
return agent;
|
||||
}
|
||||
|
|
@ -122,16 +126,17 @@ const addAgentResourceFile = async ({ agent_id, tool_resource, file_id }) => {
|
|||
};
|
||||
|
||||
/**
|
||||
* Removes multiple resource files from an agent in a single update.
|
||||
* Removes multiple resource files from an agent using atomic operations.
|
||||
* @param {object} params
|
||||
* @param {string} params.agent_id
|
||||
* @param {Array<{tool_resource: string, file_id: string}>} params.files
|
||||
* @returns {Promise<Agent>} The updated agent.
|
||||
* @throws {Error} If the agent is not found or update fails.
|
||||
*/
|
||||
const removeAgentResourceFiles = async ({ agent_id, files }) => {
|
||||
const searchParameter = { id: agent_id };
|
||||
|
||||
// associate each tool resource with the respective file ids array
|
||||
// Group files to remove by resource
|
||||
const filesByResource = files.reduce((acc, { tool_resource, file_id }) => {
|
||||
if (!acc[tool_resource]) {
|
||||
acc[tool_resource] = [];
|
||||
|
|
@ -140,42 +145,35 @@ const removeAgentResourceFiles = async ({ agent_id, files }) => {
|
|||
return acc;
|
||||
}, {});
|
||||
|
||||
// build the update aggregation pipeline wich removes file ids from tool resources array
|
||||
// and eventually deletes empty tool resources
|
||||
const updateData = [];
|
||||
Object.entries(filesByResource).forEach(([resource, fileIds]) => {
|
||||
const toolResourcePath = `tool_resources.${resource}`;
|
||||
const fileIdsPath = `${toolResourcePath}.file_ids`;
|
||||
|
||||
// file ids removal stage
|
||||
updateData.push({
|
||||
$set: {
|
||||
[fileIdsPath]: {
|
||||
$filter: {
|
||||
input: `$${fileIdsPath}`,
|
||||
cond: { $not: [{ $in: ['$$this', fileIds] }] },
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// empty tool resource deletion stage
|
||||
updateData.push({
|
||||
$set: {
|
||||
[toolResourcePath]: {
|
||||
$cond: [{ $eq: [`$${fileIdsPath}`, []] }, '$$REMOVE', `$${toolResourcePath}`],
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
// return the updated agent or throw if no agent matches
|
||||
const updatedAgent = await updateAgent(searchParameter, updateData);
|
||||
if (updatedAgent) {
|
||||
return updatedAgent;
|
||||
} else {
|
||||
throw new Error('Agent not found for removing resource files');
|
||||
// Step 1: Atomically remove file IDs using $pull
|
||||
const pullOps = {};
|
||||
const resourcesToCheck = new Set();
|
||||
for (const [resource, fileIds] of Object.entries(filesByResource)) {
|
||||
const fileIdsPath = `tool_resources.${resource}.file_ids`;
|
||||
pullOps[fileIdsPath] = { $in: fileIds };
|
||||
resourcesToCheck.add(resource);
|
||||
}
|
||||
|
||||
const updatePullData = { $pull: pullOps };
|
||||
const agentAfterPull = await Agent.findOneAndUpdate(searchParameter, updatePullData, {
|
||||
new: true,
|
||||
}).lean();
|
||||
|
||||
if (!agentAfterPull) {
|
||||
// Agent might have been deleted concurrently, or never existed.
|
||||
// Check if it existed before trying to throw.
|
||||
const agentExists = await getAgent(searchParameter);
|
||||
if (!agentExists) {
|
||||
throw new Error('Agent not found for removing resource files');
|
||||
}
|
||||
// If it existed but findOneAndUpdate returned null, something else went wrong.
|
||||
throw new Error('Failed to update agent during file removal (pull step)');
|
||||
}
|
||||
|
||||
// Return the agent state directly after the $pull operation.
|
||||
// Skipping the $unset step for now to simplify and test core $pull atomicity.
|
||||
// Empty arrays might remain, but the removal itself should be correct.
|
||||
return agentAfterPull;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -157,4 +157,134 @@ describe('Agent Resource File Operations', () => {
|
|||
expect(updatedAgent.tool_resources[tool].file_ids).toHaveLength(5);
|
||||
});
|
||||
});
|
||||
|
||||
test('should handle concurrent duplicate additions', async () => {
|
||||
const agent = await createBasicAgent();
|
||||
const fileId = uuidv4();
|
||||
|
||||
// Concurrent additions of the same file
|
||||
const additionPromises = Array.from({ length: 5 }).map(() =>
|
||||
addAgentResourceFile({
|
||||
agent_id: agent.id,
|
||||
tool_resource: 'test_tool',
|
||||
file_id: fileId,
|
||||
}),
|
||||
);
|
||||
|
||||
await Promise.all(additionPromises);
|
||||
|
||||
const updatedAgent = await Agent.findOne({ id: agent.id });
|
||||
expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined();
|
||||
// Should only contain one instance of the fileId
|
||||
expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(1);
|
||||
expect(updatedAgent.tool_resources.test_tool.file_ids[0]).toBe(fileId);
|
||||
});
|
||||
|
||||
test('should handle concurrent add and remove of the same file', async () => {
|
||||
const agent = await createBasicAgent();
|
||||
const fileId = uuidv4();
|
||||
|
||||
// First, ensure the file exists (or test might be trivial if remove runs first)
|
||||
await addAgentResourceFile({
|
||||
agent_id: agent.id,
|
||||
tool_resource: 'test_tool',
|
||||
file_id: fileId,
|
||||
});
|
||||
|
||||
// Concurrent add (which should be ignored) and remove
|
||||
const operations = [
|
||||
addAgentResourceFile({
|
||||
agent_id: agent.id,
|
||||
tool_resource: 'test_tool',
|
||||
file_id: fileId,
|
||||
}),
|
||||
removeAgentResourceFiles({
|
||||
agent_id: agent.id,
|
||||
files: [{ tool_resource: 'test_tool', file_id: fileId }],
|
||||
}),
|
||||
];
|
||||
|
||||
await Promise.all(operations);
|
||||
|
||||
const updatedAgent = await Agent.findOne({ id: agent.id });
|
||||
// The final state should ideally be that the file is removed,
|
||||
// but the key point is consistency (not duplicated or error state).
|
||||
// Depending on execution order, the file might remain if the add operation's
|
||||
// findOneAndUpdate runs after the remove operation completes.
|
||||
// A more robust check might be that the length is <= 1.
|
||||
// Given the remove uses an update pipeline, it might be more likely to win.
|
||||
// The final state depends on race condition timing (add or remove might "win").
|
||||
// The critical part is that the state is consistent (no duplicates, no errors).
|
||||
// Assert that the fileId is either present exactly once or not present at all.
|
||||
expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined();
|
||||
const finalFileIds = updatedAgent.tool_resources.test_tool.file_ids;
|
||||
const count = finalFileIds.filter((id) => id === fileId).length;
|
||||
expect(count).toBeLessThanOrEqual(1); // Should be 0 or 1, never more
|
||||
// Optional: Check overall length is consistent with the count
|
||||
if (count === 0) {
|
||||
expect(finalFileIds).toHaveLength(0);
|
||||
} else {
|
||||
expect(finalFileIds).toHaveLength(1);
|
||||
expect(finalFileIds[0]).toBe(fileId);
|
||||
}
|
||||
});
|
||||
|
||||
test('should handle concurrent duplicate removals', async () => {
|
||||
const agent = await createBasicAgent();
|
||||
const fileId = uuidv4();
|
||||
|
||||
// Add the file first
|
||||
await addAgentResourceFile({
|
||||
agent_id: agent.id,
|
||||
tool_resource: 'test_tool',
|
||||
file_id: fileId,
|
||||
});
|
||||
|
||||
// Concurrent removals of the same file
|
||||
const removalPromises = Array.from({ length: 5 }).map(() =>
|
||||
removeAgentResourceFiles({
|
||||
agent_id: agent.id,
|
||||
files: [{ tool_resource: 'test_tool', file_id: fileId }],
|
||||
}),
|
||||
);
|
||||
|
||||
await Promise.all(removalPromises);
|
||||
|
||||
const updatedAgent = await Agent.findOne({ id: agent.id });
|
||||
// Check if the array is empty or the tool resource itself is removed
|
||||
const fileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? [];
|
||||
expect(fileIds).toHaveLength(0);
|
||||
expect(fileIds).not.toContain(fileId);
|
||||
});
|
||||
|
||||
test('should handle concurrent removals of different files', async () => {
|
||||
const agent = await createBasicAgent();
|
||||
const fileIds = Array.from({ length: 10 }, () => uuidv4());
|
||||
|
||||
// Add all files first
|
||||
await Promise.all(
|
||||
fileIds.map((fileId) =>
|
||||
addAgentResourceFile({
|
||||
agent_id: agent.id,
|
||||
tool_resource: 'test_tool',
|
||||
file_id: fileId,
|
||||
}),
|
||||
),
|
||||
);
|
||||
|
||||
// Concurrently remove all files
|
||||
const removalPromises = fileIds.map((fileId) =>
|
||||
removeAgentResourceFiles({
|
||||
agent_id: agent.id,
|
||||
files: [{ tool_resource: 'test_tool', file_id: fileId }],
|
||||
}),
|
||||
);
|
||||
|
||||
await Promise.all(removalPromises);
|
||||
|
||||
const updatedAgent = await Agent.findOne({ id: agent.id });
|
||||
// Check if the array is empty or the tool resource itself is removed
|
||||
const finalFileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? [];
|
||||
expect(finalFileIds).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -134,6 +134,28 @@ const deleteFiles = async (file_ids, user) => {
|
|||
return await File.deleteMany(deleteQuery);
|
||||
};
|
||||
|
||||
/**
|
||||
* Batch updates files with new signed URLs in MongoDB
|
||||
*
|
||||
* @param {MongoFile[]} updates - Array of updates in the format { file_id, filepath }
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function batchUpdateFiles(updates) {
|
||||
if (!updates || updates.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const bulkOperations = updates.map((update) => ({
|
||||
updateOne: {
|
||||
filter: { file_id: update.file_id },
|
||||
update: { $set: { filepath: update.filepath } },
|
||||
},
|
||||
}));
|
||||
|
||||
const result = await File.bulkWrite(bulkOperations);
|
||||
logger.info(`Updated ${result.modifiedCount} files with new S3 URLs`);
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
File,
|
||||
findFileById,
|
||||
|
|
@ -145,4 +167,5 @@ module.exports = {
|
|||
deleteFile,
|
||||
deleteFiles,
|
||||
deleteFileByFilter,
|
||||
batchUpdateFiles,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -4,13 +4,8 @@ const {
|
|||
SystemRoles,
|
||||
roleDefaults,
|
||||
PermissionTypes,
|
||||
permissionsSchema,
|
||||
removeNullishValues,
|
||||
agentPermissionsSchema,
|
||||
promptPermissionsSchema,
|
||||
runCodePermissionsSchema,
|
||||
bookmarkPermissionsSchema,
|
||||
multiConvoPermissionsSchema,
|
||||
temporaryChatPermissionsSchema,
|
||||
} = require('librechat-data-provider');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { roleSchema } = require('@librechat/data-schemas');
|
||||
|
|
@ -20,15 +15,16 @@ const Role = mongoose.model('Role', roleSchema);
|
|||
|
||||
/**
|
||||
* Retrieve a role by name and convert the found role document to a plain object.
|
||||
* If the role with the given name doesn't exist and the name is a system defined role, create it and return the lean version.
|
||||
* If the role with the given name doesn't exist and the name is a system defined role,
|
||||
* create it and return the lean version.
|
||||
*
|
||||
* @param {string} roleName - The name of the role to find or create.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<Object>} A plain object representing the role document.
|
||||
*/
|
||||
const getRoleByName = async function (roleName, fieldsToSelect = null) {
|
||||
const cache = getLogStores(CacheKeys.ROLES);
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.ROLES);
|
||||
const cachedRole = await cache.get(roleName);
|
||||
if (cachedRole) {
|
||||
return cachedRole;
|
||||
|
|
@ -40,8 +36,7 @@ const getRoleByName = async function (roleName, fieldsToSelect = null) {
|
|||
let role = await query.lean().exec();
|
||||
|
||||
if (!role && SystemRoles[roleName]) {
|
||||
role = roleDefaults[roleName];
|
||||
role = await new Role(role).save();
|
||||
role = await new Role(roleDefaults[roleName]).save();
|
||||
await cache.set(roleName, role);
|
||||
return role.toObject();
|
||||
}
|
||||
|
|
@ -60,8 +55,8 @@ const getRoleByName = async function (roleName, fieldsToSelect = null) {
|
|||
* @returns {Promise<TRole>} Updated role document.
|
||||
*/
|
||||
const updateRoleByName = async function (roleName, updates) {
|
||||
const cache = getLogStores(CacheKeys.ROLES);
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.ROLES);
|
||||
const role = await Role.findOneAndUpdate(
|
||||
{ name: roleName },
|
||||
{ $set: updates },
|
||||
|
|
@ -77,29 +72,20 @@ const updateRoleByName = async function (roleName, updates) {
|
|||
}
|
||||
};
|
||||
|
||||
const permissionSchemas = {
|
||||
[PermissionTypes.AGENTS]: agentPermissionsSchema,
|
||||
[PermissionTypes.PROMPTS]: promptPermissionsSchema,
|
||||
[PermissionTypes.BOOKMARKS]: bookmarkPermissionsSchema,
|
||||
[PermissionTypes.MULTI_CONVO]: multiConvoPermissionsSchema,
|
||||
[PermissionTypes.TEMPORARY_CHAT]: temporaryChatPermissionsSchema,
|
||||
[PermissionTypes.RUN_CODE]: runCodePermissionsSchema,
|
||||
};
|
||||
|
||||
/**
|
||||
* Updates access permissions for a specific role and multiple permission types.
|
||||
* @param {SystemRoles} roleName - The role to update.
|
||||
* @param {string} roleName - The role to update.
|
||||
* @param {Object.<PermissionTypes, Object.<Permissions, boolean>>} permissionsUpdate - Permissions to update and their values.
|
||||
*/
|
||||
async function updateAccessPermissions(roleName, permissionsUpdate) {
|
||||
// Filter and clean the permission updates based on our schema definition.
|
||||
const updates = {};
|
||||
for (const [permissionType, permissions] of Object.entries(permissionsUpdate)) {
|
||||
if (permissionSchemas[permissionType]) {
|
||||
if (permissionsSchema.shape && permissionsSchema.shape[permissionType]) {
|
||||
updates[permissionType] = removeNullishValues(permissions);
|
||||
}
|
||||
}
|
||||
|
||||
if (Object.keys(updates).length === 0) {
|
||||
if (!Object.keys(updates).length) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -109,26 +95,28 @@ async function updateAccessPermissions(roleName, permissionsUpdate) {
|
|||
return;
|
||||
}
|
||||
|
||||
const updatedPermissions = {};
|
||||
const currentPermissions = role.permissions || {};
|
||||
const updatedPermissions = { ...currentPermissions };
|
||||
let hasChanges = false;
|
||||
|
||||
for (const [permissionType, permissions] of Object.entries(updates)) {
|
||||
const currentPermissions = role[permissionType] || {};
|
||||
updatedPermissions[permissionType] = { ...currentPermissions };
|
||||
const currentTypePermissions = currentPermissions[permissionType] || {};
|
||||
updatedPermissions[permissionType] = { ...currentTypePermissions };
|
||||
|
||||
for (const [permission, value] of Object.entries(permissions)) {
|
||||
if (currentPermissions[permission] !== value) {
|
||||
if (currentTypePermissions[permission] !== value) {
|
||||
updatedPermissions[permissionType][permission] = value;
|
||||
hasChanges = true;
|
||||
logger.info(
|
||||
`Updating '${roleName}' role ${permissionType} '${permission}' permission from ${currentPermissions[permission]} to: ${value}`,
|
||||
`Updating '${roleName}' role permission '${permissionType}' '${permission}' from ${currentTypePermissions[permission]} to: ${value}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (hasChanges) {
|
||||
await updateRoleByName(roleName, updatedPermissions);
|
||||
// Update only the permissions field.
|
||||
await updateRoleByName(roleName, { permissions: updatedPermissions });
|
||||
logger.info(`Updated '${roleName}' role permissions`);
|
||||
} else {
|
||||
logger.info(`No changes needed for '${roleName}' role permissions`);
|
||||
|
|
@ -146,30 +134,27 @@ async function updateAccessPermissions(roleName, permissionsUpdate) {
|
|||
* @returns {Promise<void>}
|
||||
*/
|
||||
const initializeRoles = async function () {
|
||||
const defaultRoles = [SystemRoles.ADMIN, SystemRoles.USER];
|
||||
|
||||
for (const roleName of defaultRoles) {
|
||||
for (const roleName of [SystemRoles.ADMIN, SystemRoles.USER]) {
|
||||
let role = await Role.findOne({ name: roleName });
|
||||
const defaultPerms = roleDefaults[roleName].permissions;
|
||||
|
||||
if (!role) {
|
||||
// Create new role if it doesn't exist
|
||||
// Create new role if it doesn't exist.
|
||||
role = new Role(roleDefaults[roleName]);
|
||||
} else {
|
||||
// Add missing permission types
|
||||
let isUpdated = false;
|
||||
for (const permType of Object.values(PermissionTypes)) {
|
||||
if (!role[permType]) {
|
||||
role[permType] = roleDefaults[roleName][permType];
|
||||
isUpdated = true;
|
||||
// Ensure role.permissions is defined.
|
||||
role.permissions = role.permissions || {};
|
||||
// For each permission type in defaults, add it if missing.
|
||||
for (const permType of Object.keys(defaultPerms)) {
|
||||
if (role.permissions[permType] == null) {
|
||||
role.permissions[permType] = defaultPerms[permType];
|
||||
}
|
||||
}
|
||||
if (isUpdated) {
|
||||
await role.save();
|
||||
}
|
||||
}
|
||||
await role.save();
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
Role,
|
||||
getRoleByName,
|
||||
|
|
|
|||
|
|
@ -2,22 +2,21 @@ const mongoose = require('mongoose');
|
|||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const {
|
||||
SystemRoles,
|
||||
PermissionTypes,
|
||||
roleDefaults,
|
||||
Permissions,
|
||||
roleDefaults,
|
||||
PermissionTypes,
|
||||
} = require('librechat-data-provider');
|
||||
const { updateAccessPermissions, initializeRoles } = require('~/models/Role');
|
||||
const { Role, getRoleByName, updateAccessPermissions, initializeRoles } = require('~/models/Role');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { Role } = require('~/models/Role');
|
||||
|
||||
// Mock the cache
|
||||
jest.mock('~/cache/getLogStores', () => {
|
||||
return jest.fn().mockReturnValue({
|
||||
jest.mock('~/cache/getLogStores', () =>
|
||||
jest.fn().mockReturnValue({
|
||||
get: jest.fn(),
|
||||
set: jest.fn(),
|
||||
del: jest.fn(),
|
||||
});
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
let mongoServer;
|
||||
|
||||
|
|
@ -41,10 +40,12 @@ describe('updateAccessPermissions', () => {
|
|||
it('should update permissions when changes are needed', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
},
|
||||
}).save();
|
||||
|
||||
|
|
@ -56,8 +57,8 @@ describe('updateAccessPermissions', () => {
|
|||
},
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: true,
|
||||
|
|
@ -67,10 +68,12 @@ describe('updateAccessPermissions', () => {
|
|||
it('should not update permissions when no changes are needed', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
},
|
||||
}).save();
|
||||
|
||||
|
|
@ -82,8 +85,8 @@ describe('updateAccessPermissions', () => {
|
|||
},
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
|
|
@ -92,11 +95,8 @@ describe('updateAccessPermissions', () => {
|
|||
|
||||
it('should handle non-existent roles', async () => {
|
||||
await updateAccessPermissions('NON_EXISTENT_ROLE', {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
},
|
||||
[PermissionTypes.PROMPTS]: { CREATE: true },
|
||||
});
|
||||
|
||||
const role = await Role.findOne({ name: 'NON_EXISTENT_ROLE' });
|
||||
expect(role).toBeNull();
|
||||
});
|
||||
|
|
@ -104,21 +104,21 @@ describe('updateAccessPermissions', () => {
|
|||
it('should update only specified permissions', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
SHARED_GLOBAL: true,
|
||||
},
|
||||
[PermissionTypes.PROMPTS]: { SHARED_GLOBAL: true },
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: true,
|
||||
|
|
@ -128,21 +128,21 @@ describe('updateAccessPermissions', () => {
|
|||
it('should handle partial updates', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
USE: false,
|
||||
},
|
||||
[PermissionTypes.PROMPTS]: { USE: false },
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: false,
|
||||
SHARED_GLOBAL: false,
|
||||
|
|
@ -152,13 +152,9 @@ describe('updateAccessPermissions', () => {
|
|||
it('should update multiple permission types at once', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
[PermissionTypes.BOOKMARKS]: {
|
||||
USE: true,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARED_GLOBAL: false },
|
||||
[PermissionTypes.BOOKMARKS]: { USE: true },
|
||||
},
|
||||
}).save();
|
||||
|
||||
|
|
@ -167,24 +163,20 @@ describe('updateAccessPermissions', () => {
|
|||
[PermissionTypes.BOOKMARKS]: { USE: false },
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: false,
|
||||
SHARED_GLOBAL: true,
|
||||
});
|
||||
expect(updatedRole[PermissionTypes.BOOKMARKS]).toEqual({
|
||||
USE: false,
|
||||
});
|
||||
expect(updatedRole.permissions[PermissionTypes.BOOKMARKS]).toEqual({ USE: false });
|
||||
});
|
||||
|
||||
it('should handle updates for a single permission type', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARED_GLOBAL: false },
|
||||
},
|
||||
}).save();
|
||||
|
||||
|
|
@ -192,8 +184,8 @@ describe('updateAccessPermissions', () => {
|
|||
[PermissionTypes.PROMPTS]: { USE: false, SHARED_GLOBAL: true },
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: false,
|
||||
SHARED_GLOBAL: true,
|
||||
|
|
@ -203,33 +195,25 @@ describe('updateAccessPermissions', () => {
|
|||
it('should update MULTI_CONVO permissions', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.MULTI_CONVO]: {
|
||||
USE: false,
|
||||
permissions: {
|
||||
[PermissionTypes.MULTI_CONVO]: { USE: false },
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.MULTI_CONVO]: {
|
||||
USE: true,
|
||||
},
|
||||
[PermissionTypes.MULTI_CONVO]: { USE: true },
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.MULTI_CONVO]).toEqual({
|
||||
USE: true,
|
||||
});
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]).toEqual({ USE: true });
|
||||
});
|
||||
|
||||
it('should update MULTI_CONVO permissions along with other permission types', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
[PermissionTypes.MULTI_CONVO]: {
|
||||
USE: false,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARED_GLOBAL: false },
|
||||
[PermissionTypes.MULTI_CONVO]: { USE: false },
|
||||
},
|
||||
}).save();
|
||||
|
||||
|
|
@ -238,35 +222,29 @@ describe('updateAccessPermissions', () => {
|
|||
[PermissionTypes.MULTI_CONVO]: { USE: true },
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: true,
|
||||
});
|
||||
expect(updatedRole[PermissionTypes.MULTI_CONVO]).toEqual({
|
||||
USE: true,
|
||||
});
|
||||
expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]).toEqual({ USE: true });
|
||||
});
|
||||
|
||||
it('should not update MULTI_CONVO permissions when no changes are needed', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.MULTI_CONVO]: {
|
||||
USE: true,
|
||||
permissions: {
|
||||
[PermissionTypes.MULTI_CONVO]: { USE: true },
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.MULTI_CONVO]: {
|
||||
USE: true,
|
||||
},
|
||||
[PermissionTypes.MULTI_CONVO]: { USE: true },
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.MULTI_CONVO]).toEqual({
|
||||
USE: true,
|
||||
});
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]).toEqual({ USE: true });
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -278,65 +256,69 @@ describe('initializeRoles', () => {
|
|||
it('should create default roles if they do not exist', async () => {
|
||||
await initializeRoles();
|
||||
|
||||
const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean();
|
||||
const userRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
const adminRole = await getRoleByName(SystemRoles.ADMIN);
|
||||
const userRole = await getRoleByName(SystemRoles.USER);
|
||||
|
||||
expect(adminRole).toBeTruthy();
|
||||
expect(userRole).toBeTruthy();
|
||||
|
||||
// Check if all permission types exist
|
||||
// Check if all permission types exist in the permissions field
|
||||
Object.values(PermissionTypes).forEach((permType) => {
|
||||
expect(adminRole[permType]).toBeDefined();
|
||||
expect(userRole[permType]).toBeDefined();
|
||||
expect(adminRole.permissions[permType]).toBeDefined();
|
||||
expect(userRole.permissions[permType]).toBeDefined();
|
||||
});
|
||||
|
||||
// Check if permissions match defaults (example for ADMIN role)
|
||||
expect(adminRole[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBe(true);
|
||||
expect(adminRole[PermissionTypes.BOOKMARKS].USE).toBe(true);
|
||||
expect(adminRole[PermissionTypes.AGENTS].CREATE).toBe(true);
|
||||
// Example: Check default values for ADMIN role
|
||||
expect(adminRole.permissions[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBe(true);
|
||||
expect(adminRole.permissions[PermissionTypes.BOOKMARKS].USE).toBe(true);
|
||||
expect(adminRole.permissions[PermissionTypes.AGENTS].CREATE).toBe(true);
|
||||
});
|
||||
|
||||
it('should not modify existing permissions for existing roles', async () => {
|
||||
const customUserRole = {
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: true,
|
||||
[Permissions.SHARED_GLOBAL]: true,
|
||||
},
|
||||
[PermissionTypes.BOOKMARKS]: {
|
||||
[Permissions.USE]: false,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: true,
|
||||
[Permissions.SHARED_GLOBAL]: true,
|
||||
},
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false },
|
||||
},
|
||||
};
|
||||
|
||||
await new Role(customUserRole).save();
|
||||
|
||||
await initializeRoles();
|
||||
|
||||
const userRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
|
||||
expect(userRole[PermissionTypes.PROMPTS]).toEqual(customUserRole[PermissionTypes.PROMPTS]);
|
||||
expect(userRole[PermissionTypes.BOOKMARKS]).toEqual(customUserRole[PermissionTypes.BOOKMARKS]);
|
||||
expect(userRole[PermissionTypes.AGENTS]).toBeDefined();
|
||||
const userRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(userRole.permissions[PermissionTypes.PROMPTS]).toEqual(
|
||||
customUserRole.permissions[PermissionTypes.PROMPTS],
|
||||
);
|
||||
expect(userRole.permissions[PermissionTypes.BOOKMARKS]).toEqual(
|
||||
customUserRole.permissions[PermissionTypes.BOOKMARKS],
|
||||
);
|
||||
expect(userRole.permissions[PermissionTypes.AGENTS]).toBeDefined();
|
||||
});
|
||||
|
||||
it('should add new permission types to existing roles', async () => {
|
||||
const partialUserRole = {
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: roleDefaults[SystemRoles.USER][PermissionTypes.PROMPTS],
|
||||
[PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.USER][PermissionTypes.BOOKMARKS],
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]:
|
||||
roleDefaults[SystemRoles.USER].permissions[PermissionTypes.PROMPTS],
|
||||
[PermissionTypes.BOOKMARKS]:
|
||||
roleDefaults[SystemRoles.USER].permissions[PermissionTypes.BOOKMARKS],
|
||||
},
|
||||
};
|
||||
|
||||
await new Role(partialUserRole).save();
|
||||
|
||||
await initializeRoles();
|
||||
|
||||
const userRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
|
||||
expect(userRole[PermissionTypes.AGENTS]).toBeDefined();
|
||||
expect(userRole[PermissionTypes.AGENTS].CREATE).toBeDefined();
|
||||
expect(userRole[PermissionTypes.AGENTS].USE).toBeDefined();
|
||||
expect(userRole[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined();
|
||||
const userRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(userRole.permissions[PermissionTypes.AGENTS]).toBeDefined();
|
||||
expect(userRole.permissions[PermissionTypes.AGENTS].CREATE).toBeDefined();
|
||||
expect(userRole.permissions[PermissionTypes.AGENTS].USE).toBeDefined();
|
||||
expect(userRole.permissions[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle multiple runs without duplicating or modifying data', async () => {
|
||||
|
|
@ -349,72 +331,73 @@ describe('initializeRoles', () => {
|
|||
expect(adminRoles).toHaveLength(1);
|
||||
expect(userRoles).toHaveLength(1);
|
||||
|
||||
const adminRole = adminRoles[0].toObject();
|
||||
const userRole = userRoles[0].toObject();
|
||||
|
||||
// Check if all permission types exist
|
||||
const adminPerms = adminRoles[0].toObject().permissions;
|
||||
const userPerms = userRoles[0].toObject().permissions;
|
||||
Object.values(PermissionTypes).forEach((permType) => {
|
||||
expect(adminRole[permType]).toBeDefined();
|
||||
expect(userRole[permType]).toBeDefined();
|
||||
expect(adminPerms[permType]).toBeDefined();
|
||||
expect(userPerms[permType]).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
it('should update roles with missing permission types from roleDefaults', async () => {
|
||||
const partialAdminRole = {
|
||||
name: SystemRoles.ADMIN,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: false,
|
||||
[Permissions.SHARED_GLOBAL]: false,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: false,
|
||||
[Permissions.SHARED_GLOBAL]: false,
|
||||
},
|
||||
[PermissionTypes.BOOKMARKS]:
|
||||
roleDefaults[SystemRoles.ADMIN].permissions[PermissionTypes.BOOKMARKS],
|
||||
},
|
||||
[PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.ADMIN][PermissionTypes.BOOKMARKS],
|
||||
};
|
||||
|
||||
await new Role(partialAdminRole).save();
|
||||
|
||||
await initializeRoles();
|
||||
|
||||
const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean();
|
||||
|
||||
expect(adminRole[PermissionTypes.PROMPTS]).toEqual(partialAdminRole[PermissionTypes.PROMPTS]);
|
||||
expect(adminRole[PermissionTypes.AGENTS]).toBeDefined();
|
||||
expect(adminRole[PermissionTypes.AGENTS].CREATE).toBeDefined();
|
||||
expect(adminRole[PermissionTypes.AGENTS].USE).toBeDefined();
|
||||
expect(adminRole[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined();
|
||||
const adminRole = await getRoleByName(SystemRoles.ADMIN);
|
||||
expect(adminRole.permissions[PermissionTypes.PROMPTS]).toEqual(
|
||||
partialAdminRole.permissions[PermissionTypes.PROMPTS],
|
||||
);
|
||||
expect(adminRole.permissions[PermissionTypes.AGENTS]).toBeDefined();
|
||||
expect(adminRole.permissions[PermissionTypes.AGENTS].CREATE).toBeDefined();
|
||||
expect(adminRole.permissions[PermissionTypes.AGENTS].USE).toBeDefined();
|
||||
expect(adminRole.permissions[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined();
|
||||
});
|
||||
|
||||
it('should include MULTI_CONVO permissions when creating default roles', async () => {
|
||||
await initializeRoles();
|
||||
|
||||
const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean();
|
||||
const userRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
const adminRole = await getRoleByName(SystemRoles.ADMIN);
|
||||
const userRole = await getRoleByName(SystemRoles.USER);
|
||||
|
||||
expect(adminRole[PermissionTypes.MULTI_CONVO]).toBeDefined();
|
||||
expect(userRole[PermissionTypes.MULTI_CONVO]).toBeDefined();
|
||||
|
||||
// Check if MULTI_CONVO permissions match defaults
|
||||
expect(adminRole[PermissionTypes.MULTI_CONVO].USE).toBe(
|
||||
roleDefaults[SystemRoles.ADMIN][PermissionTypes.MULTI_CONVO].USE,
|
||||
expect(adminRole.permissions[PermissionTypes.MULTI_CONVO]).toBeDefined();
|
||||
expect(userRole.permissions[PermissionTypes.MULTI_CONVO]).toBeDefined();
|
||||
expect(adminRole.permissions[PermissionTypes.MULTI_CONVO].USE).toBe(
|
||||
roleDefaults[SystemRoles.ADMIN].permissions[PermissionTypes.MULTI_CONVO].USE,
|
||||
);
|
||||
expect(userRole[PermissionTypes.MULTI_CONVO].USE).toBe(
|
||||
roleDefaults[SystemRoles.USER][PermissionTypes.MULTI_CONVO].USE,
|
||||
expect(userRole.permissions[PermissionTypes.MULTI_CONVO].USE).toBe(
|
||||
roleDefaults[SystemRoles.USER].permissions[PermissionTypes.MULTI_CONVO].USE,
|
||||
);
|
||||
});
|
||||
|
||||
it('should add MULTI_CONVO permissions to existing roles without them', async () => {
|
||||
const partialUserRole = {
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: roleDefaults[SystemRoles.USER][PermissionTypes.PROMPTS],
|
||||
[PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.USER][PermissionTypes.BOOKMARKS],
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]:
|
||||
roleDefaults[SystemRoles.USER].permissions[PermissionTypes.PROMPTS],
|
||||
[PermissionTypes.BOOKMARKS]:
|
||||
roleDefaults[SystemRoles.USER].permissions[PermissionTypes.BOOKMARKS],
|
||||
},
|
||||
};
|
||||
|
||||
await new Role(partialUserRole).save();
|
||||
|
||||
await initializeRoles();
|
||||
|
||||
const userRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
|
||||
expect(userRole[PermissionTypes.MULTI_CONVO]).toBeDefined();
|
||||
expect(userRole[PermissionTypes.MULTI_CONVO].USE).toBeDefined();
|
||||
const userRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(userRole.permissions[PermissionTypes.MULTI_CONVO]).toBeDefined();
|
||||
expect(userRole.permissions[PermissionTypes.MULTI_CONVO].USE).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -8,39 +8,135 @@ const Balance = require('./Balance');
|
|||
const cancelRate = 1.15;
|
||||
|
||||
/**
|
||||
* Updates a user's token balance based on a transaction.
|
||||
*
|
||||
* Updates a user's token balance based on a transaction using optimistic concurrency control
|
||||
* without schema changes. Compatible with DocumentDB.
|
||||
* @async
|
||||
* @function
|
||||
* @param {Object} params - The function parameters.
|
||||
* @param {string} params.user - The user ID.
|
||||
* @param {string|mongoose.Types.ObjectId} params.user - The user ID.
|
||||
* @param {number} params.incrementValue - The value to increment the balance by (can be negative).
|
||||
* @param {import('mongoose').UpdateQuery<import('@librechat/data-schemas').IBalance>['$set']} params.setValues
|
||||
* @returns {Promise<Object>} Returns the updated balance response.
|
||||
* @param {import('mongoose').UpdateQuery<import('@librechat/data-schemas').IBalance>['$set']} [params.setValues] - Optional additional fields to set.
|
||||
* @returns {Promise<Object>} Returns the updated balance document (lean).
|
||||
* @throws {Error} Throws an error if the update fails after multiple retries.
|
||||
*/
|
||||
const updateBalance = async ({ user, incrementValue, setValues }) => {
|
||||
// Use findOneAndUpdate with a conditional update to make the balance update atomic
|
||||
// This prevents race conditions when multiple transactions are processed concurrently
|
||||
const balanceResponse = await Balance.findOneAndUpdate(
|
||||
{ user },
|
||||
[
|
||||
{
|
||||
$set: {
|
||||
tokenCredits: {
|
||||
$cond: {
|
||||
if: { $lt: [{ $add: ['$tokenCredits', incrementValue] }, 0] },
|
||||
then: 0,
|
||||
else: { $add: ['$tokenCredits', incrementValue] },
|
||||
},
|
||||
},
|
||||
...setValues,
|
||||
},
|
||||
},
|
||||
],
|
||||
{ upsert: true, new: true },
|
||||
).lean();
|
||||
let maxRetries = 10; // Number of times to retry on conflict
|
||||
let delay = 50; // Initial retry delay in ms
|
||||
let lastError = null;
|
||||
|
||||
return balanceResponse;
|
||||
for (let attempt = 1; attempt <= maxRetries; attempt++) {
|
||||
let currentBalanceDoc;
|
||||
try {
|
||||
// 1. Read the current document state
|
||||
currentBalanceDoc = await Balance.findOne({ user }).lean();
|
||||
const currentCredits = currentBalanceDoc ? currentBalanceDoc.tokenCredits : 0;
|
||||
|
||||
// 2. Calculate the desired new state
|
||||
const potentialNewCredits = currentCredits + incrementValue;
|
||||
const newCredits = Math.max(0, potentialNewCredits); // Ensure balance doesn't go below zero
|
||||
|
||||
// 3. Prepare the update payload
|
||||
const updatePayload = {
|
||||
$set: {
|
||||
tokenCredits: newCredits,
|
||||
...(setValues || {}), // Merge other values to set
|
||||
},
|
||||
};
|
||||
|
||||
// 4. Attempt the conditional update or upsert
|
||||
let updatedBalance = null;
|
||||
if (currentBalanceDoc) {
|
||||
// --- Document Exists: Perform Conditional Update ---
|
||||
// Try to update only if the tokenCredits match the value we read (currentCredits)
|
||||
updatedBalance = await Balance.findOneAndUpdate(
|
||||
{
|
||||
user: user,
|
||||
tokenCredits: currentCredits, // Optimistic lock: condition based on the read value
|
||||
},
|
||||
updatePayload,
|
||||
{
|
||||
new: true, // Return the modified document
|
||||
// lean: true, // .lean() is applied after query execution in Mongoose >= 6
|
||||
},
|
||||
).lean(); // Use lean() for plain JS object
|
||||
|
||||
if (updatedBalance) {
|
||||
// Success! The update was applied based on the expected current state.
|
||||
return updatedBalance;
|
||||
}
|
||||
// If updatedBalance is null, it means tokenCredits changed between read and write (conflict).
|
||||
lastError = new Error(`Concurrency conflict for user ${user} on attempt ${attempt}.`);
|
||||
// Proceed to retry logic below.
|
||||
} else {
|
||||
// --- Document Does Not Exist: Perform Conditional Upsert ---
|
||||
// Try to insert the document, but only if it still doesn't exist.
|
||||
// Using tokenCredits: {$exists: false} helps prevent race conditions where
|
||||
// another process creates the doc between our findOne and findOneAndUpdate.
|
||||
try {
|
||||
updatedBalance = await Balance.findOneAndUpdate(
|
||||
{
|
||||
user: user,
|
||||
// Attempt to match only if the document doesn't exist OR was just created
|
||||
// without tokenCredits (less likely but possible). A simple { user } filter
|
||||
// might also work, relying on the retry for conflicts.
|
||||
// Let's use a simpler filter and rely on retry for races.
|
||||
// tokenCredits: { $exists: false } // This condition might be too strict if doc exists with 0 credits
|
||||
},
|
||||
updatePayload,
|
||||
{
|
||||
upsert: true, // Create if doesn't exist
|
||||
new: true, // Return the created/updated document
|
||||
// setDefaultsOnInsert: true, // Ensure schema defaults are applied on insert
|
||||
// lean: true,
|
||||
},
|
||||
).lean();
|
||||
|
||||
if (updatedBalance) {
|
||||
// Upsert succeeded (likely created the document)
|
||||
return updatedBalance;
|
||||
}
|
||||
// If null, potentially a rare race condition during upsert. Retry should handle it.
|
||||
lastError = new Error(
|
||||
`Upsert race condition suspected for user ${user} on attempt ${attempt}.`,
|
||||
);
|
||||
} catch (error) {
|
||||
if (error.code === 11000) {
|
||||
// E11000 duplicate key error on index
|
||||
// This means another process created the document *just* before our upsert.
|
||||
// It's a concurrency conflict during creation. We should retry.
|
||||
lastError = error; // Store the error
|
||||
// Proceed to retry logic below.
|
||||
} else {
|
||||
// Different error, rethrow
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
} // End if/else (document exists?)
|
||||
} catch (error) {
|
||||
// Catch errors from findOne or unexpected findOneAndUpdate errors
|
||||
logger.error(`[updateBalance] Error during attempt ${attempt} for user ${user}:`, error);
|
||||
lastError = error; // Store the error
|
||||
// Consider stopping retries for non-transient errors, but for now, we retry.
|
||||
}
|
||||
|
||||
// If we reached here, it means the update failed (conflict or error), wait and retry
|
||||
if (attempt < maxRetries) {
|
||||
const jitter = Math.random() * delay * 0.5; // Add jitter to delay
|
||||
await new Promise((resolve) => setTimeout(resolve, delay + jitter));
|
||||
delay = Math.min(delay * 2, 2000); // Exponential backoff with cap
|
||||
}
|
||||
} // End for loop (retries)
|
||||
|
||||
// If loop finishes without success, throw the last encountered error or a generic one
|
||||
logger.error(
|
||||
`[updateBalance] Failed to update balance for user ${user} after ${maxRetries} attempts.`,
|
||||
);
|
||||
throw (
|
||||
lastError ||
|
||||
new Error(
|
||||
`Failed to update balance for user ${user} after maximum retries due to persistent conflicts.`,
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
/** Method to calculate and set the tokenValue for a transaction */
|
||||
|
|
|
|||
|
|
@ -5,6 +5,10 @@ const { getMultiplier } = require('./tx');
|
|||
const { logger } = require('~/config');
|
||||
const Balance = require('./Balance');
|
||||
|
||||
function isInvalidDate(date) {
|
||||
return isNaN(date);
|
||||
}
|
||||
|
||||
/**
|
||||
* Simple check method that calculates token cost and returns balance info.
|
||||
* The auto-refill logic has been moved to balanceMethods.js to prevent circular dependencies.
|
||||
|
|
@ -48,13 +52,12 @@ const checkBalanceRecord = async function ({
|
|||
// Only perform auto-refill if spending would bring the balance to 0 or below
|
||||
if (balance - tokenCost <= 0 && record.autoRefillEnabled && record.refillAmount > 0) {
|
||||
const lastRefillDate = new Date(record.lastRefill);
|
||||
const nextRefillDate = addIntervalToDate(
|
||||
lastRefillDate,
|
||||
record.refillIntervalValue,
|
||||
record.refillIntervalUnit,
|
||||
);
|
||||
const now = new Date();
|
||||
if (now >= nextRefillDate) {
|
||||
if (
|
||||
isInvalidDate(lastRefillDate) ||
|
||||
now >=
|
||||
addIntervalToDate(lastRefillDate, record.refillIntervalValue, record.refillIntervalUnit)
|
||||
) {
|
||||
try {
|
||||
/** @type {{ rate: number, user: string, balance: number, transaction: import('@librechat/data-schemas').ITransaction}} */
|
||||
const result = await Transaction.createAutoRefillTransaction({
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
const _ = require('lodash');
|
||||
const mongoose = require('mongoose');
|
||||
const { MeiliSearch } = require('meilisearch');
|
||||
const { parseTextParts, ContentTypes } = require('librechat-data-provider');
|
||||
const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc');
|
||||
const logger = require('~/config/meiliLogger');
|
||||
|
||||
|
|
@ -238,10 +239,7 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) {
|
|||
}
|
||||
|
||||
if (object.content && Array.isArray(object.content)) {
|
||||
object.text = object.content
|
||||
.filter((item) => item.type === 'text' && item.text && item.text.value)
|
||||
.map((item) => item.text.value)
|
||||
.join(' ');
|
||||
object.text = parseTextParts(object.content);
|
||||
delete object.content;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -459,7 +459,7 @@ describe('spendTokens', () => {
|
|||
|
||||
it('should handle multiple concurrent transactions correctly with a high balance', async () => {
|
||||
// Create a balance with a high amount
|
||||
const initialBalance = 1000000;
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({
|
||||
user: userId,
|
||||
tokenCredits: initialBalance,
|
||||
|
|
@ -470,8 +470,9 @@ describe('spendTokens', () => {
|
|||
const context = 'message';
|
||||
const model = 'gpt-4';
|
||||
|
||||
// Create 10 usage records to simulate multiple transactions
|
||||
const collectedUsage = Array.from({ length: 10 }, (_, i) => ({
|
||||
const amount = 50;
|
||||
// Create `amount` of usage records to simulate multiple transactions
|
||||
const collectedUsage = Array.from({ length: amount }, (_, i) => ({
|
||||
model,
|
||||
input_tokens: 100 + i * 10, // Increasing input tokens
|
||||
output_tokens: 50 + i * 5, // Increasing output tokens
|
||||
|
|
@ -591,6 +592,80 @@ describe('spendTokens', () => {
|
|||
expect(Math.abs(totalTokenValue)).toBeCloseTo(actualSpend, -3); // Allow for larger differences
|
||||
});
|
||||
|
||||
// Add this new test case
|
||||
it('should handle multiple concurrent balance increases correctly', async () => {
|
||||
// Start with zero balance
|
||||
const initialBalance = 0;
|
||||
await Balance.create({
|
||||
user: userId,
|
||||
tokenCredits: initialBalance,
|
||||
});
|
||||
|
||||
const numberOfRefills = 25;
|
||||
const refillAmount = 1000;
|
||||
|
||||
const promises = [];
|
||||
for (let i = 0; i < numberOfRefills; i++) {
|
||||
promises.push(
|
||||
Transaction.createAutoRefillTransaction({
|
||||
user: userId,
|
||||
tokenType: 'credits',
|
||||
context: 'concurrent-refill-test',
|
||||
rawAmount: refillAmount,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
// Wait for all refill transactions to complete
|
||||
const results = await Promise.all(promises);
|
||||
|
||||
// Verify final balance
|
||||
const finalBalance = await Balance.findOne({ user: userId });
|
||||
expect(finalBalance).toBeDefined();
|
||||
|
||||
// The final balance should be the initial balance plus the sum of all refills
|
||||
const expectedFinalBalance = initialBalance + numberOfRefills * refillAmount;
|
||||
|
||||
console.log('Initial balance (Increase Test):', initialBalance);
|
||||
console.log(`Performed ${numberOfRefills} refills of ${refillAmount} each.`);
|
||||
console.log('Expected final balance (Increase Test):', expectedFinalBalance);
|
||||
console.log('Actual final balance (Increase Test):', finalBalance.tokenCredits);
|
||||
|
||||
// Use toBeCloseTo for safety, though toBe should work for integer math
|
||||
expect(finalBalance.tokenCredits).toBeCloseTo(expectedFinalBalance, 0);
|
||||
|
||||
// Verify all transactions were created
|
||||
const transactions = await Transaction.find({
|
||||
user: userId,
|
||||
context: 'concurrent-refill-test',
|
||||
});
|
||||
|
||||
// We should have one transaction for each refill attempt
|
||||
expect(transactions.length).toBe(numberOfRefills);
|
||||
|
||||
// Optional: Verify the sum of increments from the results matches the balance change
|
||||
const totalIncrementReported = results.reduce((sum, result) => {
|
||||
// Assuming createAutoRefillTransaction returns an object with the increment amount
|
||||
// Adjust this based on the actual return structure.
|
||||
// Let's assume it returns { balance: newBalance, transaction: { rawAmount: ... } }
|
||||
// Or perhaps we check the transaction.rawAmount directly
|
||||
return sum + (result?.transaction?.rawAmount || 0);
|
||||
}, 0);
|
||||
console.log('Total increment reported by results:', totalIncrementReported);
|
||||
expect(totalIncrementReported).toBe(expectedFinalBalance - initialBalance);
|
||||
|
||||
// Optional: Check the sum of tokenValue from saved transactions
|
||||
let totalTokenValueFromDb = 0;
|
||||
transactions.forEach((tx) => {
|
||||
// For refills, rawAmount is positive, and tokenValue might be calculated based on it
|
||||
// Let's assume tokenValue directly reflects the increment for simplicity here
|
||||
// If calculation is involved, adjust accordingly
|
||||
totalTokenValueFromDb += tx.rawAmount; // Or tx.tokenValue if that holds the increment
|
||||
});
|
||||
console.log('Total rawAmount from DB transactions:', totalTokenValueFromDb);
|
||||
expect(totalTokenValueFromDb).toBeCloseTo(expectedFinalBalance - initialBalance, 0);
|
||||
});
|
||||
|
||||
it('should create structured transactions for both prompt and completion tokens', async () => {
|
||||
// Create a balance for the user
|
||||
await Balance.create({
|
||||
|
|
|
|||
|
|
@ -107,8 +107,10 @@ const tokenValues = Object.assign(
|
|||
so this was from https://artificialanalysis.ai/models/command-light/providers */
|
||||
command: { prompt: 0.38, completion: 0.38 },
|
||||
'gemini-2.0-flash-lite': { prompt: 0.075, completion: 0.3 },
|
||||
'gemini-2.0-flash': { prompt: 0.1, completion: 0.7 },
|
||||
'gemini-2.0-flash': { prompt: 0.1, completion: 0.4 },
|
||||
'gemini-2.0': { prompt: 0, completion: 0 }, // https://ai.google.dev/pricing
|
||||
'gemini-2.5-pro-preview-03-25': { prompt: 1.25, completion: 10 },
|
||||
'gemini-2.5': { prompt: 0, completion: 0 }, // Free for a period of time
|
||||
'gemini-1.5-flash-8b': { prompt: 0.075, completion: 0.3 },
|
||||
'gemini-1.5-flash': { prompt: 0.15, completion: 0.6 },
|
||||
'gemini-1.5': { prompt: 2.5, completion: 10 },
|
||||
|
|
@ -122,6 +124,12 @@ const tokenValues = Object.assign(
|
|||
'grok-2-latest': { prompt: 2.0, completion: 10.0 },
|
||||
'grok-2': { prompt: 2.0, completion: 10.0 },
|
||||
'grok-beta': { prompt: 5.0, completion: 15.0 },
|
||||
'mistral-large': { prompt: 2.0, completion: 6.0 },
|
||||
'pixtral-large': { prompt: 2.0, completion: 6.0 },
|
||||
'mistral-saba': { prompt: 0.2, completion: 0.6 },
|
||||
codestral: { prompt: 0.3, completion: 0.9 },
|
||||
'ministral-8b': { prompt: 0.1, completion: 0.1 },
|
||||
'ministral-3b': { prompt: 0.04, completion: 0.04 },
|
||||
},
|
||||
bedrockValues,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@
|
|||
"@langchain/google-genai": "^0.1.11",
|
||||
"@langchain/google-vertexai": "^0.2.2",
|
||||
"@langchain/textsplitters": "^0.1.0",
|
||||
"@librechat/agents": "^2.3.94",
|
||||
"@librechat/agents": "^2.3.95",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@waylaidwanderer/fetch-event-source": "^3.0.1",
|
||||
"axios": "^1.8.2",
|
||||
|
|
@ -104,7 +104,7 @@
|
|||
"passport-ldapauth": "^3.0.1",
|
||||
"passport-local": "^1.0.0",
|
||||
"rate-limit-redis": "^4.2.0",
|
||||
"sharp": "^0.32.6",
|
||||
"sharp": "^0.33.5",
|
||||
"tiktoken": "^1.0.15",
|
||||
"traverse": "^0.6.7",
|
||||
"ua-parser-js": "^1.0.36",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
const { FileSources } = require('librechat-data-provider');
|
||||
const {
|
||||
Balance,
|
||||
getFiles,
|
||||
updateUser,
|
||||
deleteFiles,
|
||||
deleteConvos,
|
||||
deletePresets,
|
||||
|
|
@ -12,6 +14,7 @@ const User = require('~/models/User');
|
|||
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
|
||||
const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
|
||||
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
|
||||
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
|
||||
const { processDeleteRequest } = require('~/server/services/Files/process');
|
||||
const { deleteAllSharedLinks } = require('~/models/Share');
|
||||
const { deleteToolCalls } = require('~/models/ToolCall');
|
||||
|
|
@ -19,8 +22,23 @@ const { Transaction } = require('~/models/Transaction');
|
|||
const { logger } = require('~/config');
|
||||
|
||||
const getUserController = async (req, res) => {
|
||||
/** @type {MongoUser} */
|
||||
const userData = req.user.toObject != null ? req.user.toObject() : { ...req.user };
|
||||
delete userData.totpSecret;
|
||||
if (req.app.locals.fileStrategy === FileSources.s3 && userData.avatar) {
|
||||
const avatarNeedsRefresh = needsRefresh(userData.avatar, 3600);
|
||||
if (!avatarNeedsRefresh) {
|
||||
return res.status(200).send(userData);
|
||||
}
|
||||
const originalAvatar = userData.avatar;
|
||||
try {
|
||||
userData.avatar = await getNewS3URL(userData.avatar);
|
||||
await updateUser(userData.id, { avatar: userData.avatar });
|
||||
} catch (error) {
|
||||
userData.avatar = originalAvatar;
|
||||
logger.error('Error getting new S3 URL for avatar:', error);
|
||||
}
|
||||
}
|
||||
res.status(200).send(userData);
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -932,7 +932,14 @@ class AgentClient extends BaseClient {
|
|||
};
|
||||
let endpointConfig = this.options.req.app.locals[this.options.agent.endpoint];
|
||||
if (!endpointConfig) {
|
||||
endpointConfig = await getCustomEndpointConfig(this.options.agent.endpoint);
|
||||
try {
|
||||
endpointConfig = await getCustomEndpointConfig(this.options.agent.endpoint);
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config',
|
||||
err,
|
||||
);
|
||||
}
|
||||
}
|
||||
if (
|
||||
endpointConfig &&
|
||||
|
|
|
|||
|
|
@ -43,6 +43,12 @@ async function createRun({
|
|||
agent.model_parameters,
|
||||
);
|
||||
|
||||
/** Resolves Mistral type strictness due to new OpenAI usage field */
|
||||
if (agent.endpoint?.toLowerCase().includes(KnownEndpoints.mistral)) {
|
||||
llmConfig.streamUsage = false;
|
||||
llmConfig.usage = true;
|
||||
}
|
||||
|
||||
/** @type {'reasoning_content' | 'reasoning'} */
|
||||
let reasoningKey;
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ const {
|
|||
Tools,
|
||||
Constants,
|
||||
FileContext,
|
||||
FileSources,
|
||||
SystemRoles,
|
||||
EToolResources,
|
||||
actionDelimiter,
|
||||
|
|
@ -17,9 +18,10 @@ const {
|
|||
} = require('~/models/Agent');
|
||||
const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { refreshS3Url } = require('~/server/services/Files/S3/crud');
|
||||
const { updateAction, getActions } = require('~/models/Action');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { updateAgentProjects } = require('~/models/Agent');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { deleteFileByFilter } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
|
|
@ -102,6 +104,14 @@ const getAgentHandler = async (req, res) => {
|
|||
return res.status(404).json({ error: 'Agent not found' });
|
||||
}
|
||||
|
||||
if (agent.avatar && agent.avatar?.source === FileSources.s3) {
|
||||
const originalUrl = agent.avatar.filepath;
|
||||
agent.avatar.filepath = await refreshS3Url(agent.avatar);
|
||||
if (originalUrl !== agent.avatar.filepath) {
|
||||
await updateAgent({ id }, { avatar: agent.avatar });
|
||||
}
|
||||
}
|
||||
|
||||
agent.author = agent.author.toString();
|
||||
agent.isCollaborative = !!agent.isCollaborative;
|
||||
|
||||
|
|
|
|||
|
|
@ -148,6 +148,13 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
|
|||
return { abortController, onStart };
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {ServerResponse} res
|
||||
* @param {ServerRequest} req
|
||||
* @param {Error | unknown} error
|
||||
* @param {Partial<TMessage> & { partialText?: string }} data
|
||||
* @returns { Promise<void> }
|
||||
*/
|
||||
const handleAbortError = async (res, req, error, data) => {
|
||||
if (error?.message?.includes('base64')) {
|
||||
logger.error('[handleAbortError] Error in base64 encoding', {
|
||||
|
|
@ -178,17 +185,30 @@ const handleAbortError = async (res, req, error, data) => {
|
|||
errorText = `{"type":"${ErrorTypes.NO_SYSTEM_MESSAGES}"}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {string} partialText
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const respondWithError = async (partialText) => {
|
||||
const endpointOption = req.body?.endpointOption;
|
||||
let options = {
|
||||
sender,
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
text: errorText,
|
||||
shouldSaveMessage: true,
|
||||
user: req.user.id,
|
||||
shouldSaveMessage: true,
|
||||
spec: endpointOption?.spec,
|
||||
iconURL: endpointOption?.iconURL,
|
||||
modelLabel: endpointOption?.modelLabel,
|
||||
model: endpointOption?.modelOptions?.model || req.body?.model,
|
||||
};
|
||||
|
||||
if (req.body?.agent_id) {
|
||||
options.agent_id = req.body.agent_id;
|
||||
}
|
||||
|
||||
if (partialText) {
|
||||
options = {
|
||||
...options,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ const concurrentLimiter = require('./concurrentLimiter');
|
|||
const validateEndpoint = require('./validateEndpoint');
|
||||
const requireLocalAuth = require('./requireLocalAuth');
|
||||
const canDeleteAccount = require('./canDeleteAccount');
|
||||
const setBalanceConfig = require('./setBalanceConfig');
|
||||
const requireLdapAuth = require('./requireLdapAuth');
|
||||
const abortMiddleware = require('./abortMiddleware');
|
||||
const checkInviteUser = require('./checkInviteUser');
|
||||
|
|
@ -41,6 +42,7 @@ module.exports = {
|
|||
requireLocalAuth,
|
||||
canDeleteAccount,
|
||||
validateEndpoint,
|
||||
setBalanceConfig,
|
||||
concurrentLimiter,
|
||||
checkDomainAllowed,
|
||||
validateMessageReq,
|
||||
|
|
|
|||
|
|
@ -1,39 +1,41 @@
|
|||
const axios = require('axios');
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const denyRequest = require('./denyRequest');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
async function moderateText(req, res, next) {
|
||||
if (process.env.OPENAI_MODERATION === 'true') {
|
||||
try {
|
||||
const { text } = req.body;
|
||||
if (!isEnabled(process.env.OPENAI_MODERATION)) {
|
||||
return next();
|
||||
}
|
||||
try {
|
||||
const { text } = req.body;
|
||||
|
||||
const response = await axios.post(
|
||||
process.env.OPENAI_MODERATION_REVERSE_PROXY || 'https://api.openai.com/v1/moderations',
|
||||
{
|
||||
input: text,
|
||||
const response = await axios.post(
|
||||
process.env.OPENAI_MODERATION_REVERSE_PROXY || 'https://api.openai.com/v1/moderations',
|
||||
{
|
||||
input: text,
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${process.env.OPENAI_MODERATION_API_KEY}`,
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${process.env.OPENAI_MODERATION_API_KEY}`,
|
||||
},
|
||||
},
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
const results = response.data.results;
|
||||
const flagged = results.some((result) => result.flagged);
|
||||
const results = response.data.results;
|
||||
const flagged = results.some((result) => result.flagged);
|
||||
|
||||
if (flagged) {
|
||||
const type = ErrorTypes.MODERATION;
|
||||
const errorMessage = { type };
|
||||
return await denyRequest(req, res, errorMessage);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error in moderateText:', error);
|
||||
const errorMessage = 'error in moderation check';
|
||||
if (flagged) {
|
||||
const type = ErrorTypes.MODERATION;
|
||||
const errorMessage = { type };
|
||||
return await denyRequest(req, res, errorMessage);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error in moderateText:', error);
|
||||
const errorMessage = 'error in moderation check';
|
||||
return await denyRequest(req, res, errorMessage);
|
||||
}
|
||||
next();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,9 +17,9 @@ const checkAccess = async (user, permissionType, permissions, bodyProps = {}, ch
|
|||
}
|
||||
|
||||
const role = await getRoleByName(user.role);
|
||||
if (role && role[permissionType]) {
|
||||
if (role && role.permissions && role.permissions[permissionType]) {
|
||||
const hasAnyPermission = permissions.some((permission) => {
|
||||
if (role[permissionType][permission]) {
|
||||
if (role.permissions[permissionType][permission]) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
91
api/server/middleware/setBalanceConfig.js
Normal file
91
api/server/middleware/setBalanceConfig.js
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
const Balance = require('~/models/Balance');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Middleware to synchronize user balance settings with current balance configuration.
|
||||
* @function
|
||||
* @param {Object} req - Express request object containing user information.
|
||||
* @param {Object} res - Express response object.
|
||||
* @param {import('express').NextFunction} next - Next middleware function.
|
||||
*/
|
||||
const setBalanceConfig = async (req, res, next) => {
|
||||
try {
|
||||
const balanceConfig = await getBalanceConfig();
|
||||
if (!balanceConfig?.enabled) {
|
||||
return next();
|
||||
}
|
||||
if (balanceConfig.startBalance == null) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const userId = req.user._id;
|
||||
const userBalanceRecord = await Balance.findOne({ user: userId }).lean();
|
||||
const updateFields = buildUpdateFields(balanceConfig, userBalanceRecord);
|
||||
|
||||
if (Object.keys(updateFields).length === 0) {
|
||||
return next();
|
||||
}
|
||||
|
||||
await Balance.findOneAndUpdate(
|
||||
{ user: userId },
|
||||
{ $set: updateFields },
|
||||
{ upsert: true, new: true },
|
||||
);
|
||||
|
||||
next();
|
||||
} catch (error) {
|
||||
logger.error('Error setting user balance:', error);
|
||||
next(error);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Build an object containing fields that need updating
|
||||
* @param {Object} config - The balance configuration
|
||||
* @param {Object|null} userRecord - The user's current balance record, if any
|
||||
* @returns {Object} Fields that need updating
|
||||
*/
|
||||
function buildUpdateFields(config, userRecord) {
|
||||
const updateFields = {};
|
||||
|
||||
// Ensure user record has the required fields
|
||||
if (!userRecord) {
|
||||
updateFields.user = userRecord?.user;
|
||||
updateFields.tokenCredits = config.startBalance;
|
||||
}
|
||||
|
||||
if (userRecord?.tokenCredits == null && config.startBalance != null) {
|
||||
updateFields.tokenCredits = config.startBalance;
|
||||
}
|
||||
|
||||
const isAutoRefillConfigValid =
|
||||
config.autoRefillEnabled &&
|
||||
config.refillIntervalValue != null &&
|
||||
config.refillIntervalUnit != null &&
|
||||
config.refillAmount != null;
|
||||
|
||||
if (!isAutoRefillConfigValid) {
|
||||
return updateFields;
|
||||
}
|
||||
|
||||
if (userRecord?.autoRefillEnabled !== config.autoRefillEnabled) {
|
||||
updateFields.autoRefillEnabled = config.autoRefillEnabled;
|
||||
}
|
||||
|
||||
if (userRecord?.refillIntervalValue !== config.refillIntervalValue) {
|
||||
updateFields.refillIntervalValue = config.refillIntervalValue;
|
||||
}
|
||||
|
||||
if (userRecord?.refillIntervalUnit !== config.refillIntervalUnit) {
|
||||
updateFields.refillIntervalUnit = config.refillIntervalUnit;
|
||||
}
|
||||
|
||||
if (userRecord?.refillAmount !== config.refillAmount) {
|
||||
updateFields.refillAmount = config.refillAmount;
|
||||
}
|
||||
|
||||
return updateFields;
|
||||
}
|
||||
|
||||
module.exports = setBalanceConfig;
|
||||
|
|
@ -3,6 +3,7 @@ const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
|||
const {
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
moderateText,
|
||||
// validateModel,
|
||||
generateCheckAccess,
|
||||
validateConvoAccess,
|
||||
|
|
@ -14,6 +15,7 @@ const addTitle = require('~/server/services/Endpoints/agents/title');
|
|||
|
||||
const router = express.Router();
|
||||
|
||||
router.use(moderateText);
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
|
||||
|
|
|
|||
|
|
@ -1,21 +1,40 @@
|
|||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const {
|
||||
uaParser,
|
||||
checkBan,
|
||||
requireJwtAuth,
|
||||
// concurrentLimiter,
|
||||
// messageIpLimiter,
|
||||
// messageUserLimiter,
|
||||
messageIpLimiter,
|
||||
concurrentLimiter,
|
||||
messageUserLimiter,
|
||||
} = require('~/server/middleware');
|
||||
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { v1 } = require('./v1');
|
||||
const chat = require('./chat');
|
||||
|
||||
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.use(requireJwtAuth);
|
||||
router.use(checkBan);
|
||||
router.use(uaParser);
|
||||
|
||||
router.use('/', v1);
|
||||
router.use('/chat', chat);
|
||||
|
||||
const chatRouter = express.Router();
|
||||
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
|
||||
chatRouter.use(concurrentLimiter);
|
||||
}
|
||||
|
||||
if (isEnabled(LIMIT_MESSAGE_IP)) {
|
||||
chatRouter.use(messageIpLimiter);
|
||||
}
|
||||
|
||||
if (isEnabled(LIMIT_MESSAGE_USER)) {
|
||||
chatRouter.use(messageUserLimiter);
|
||||
}
|
||||
|
||||
chatRouter.use('/', chat);
|
||||
router.use('/chat', chatRouter);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -1,10 +1,4 @@
|
|||
const express = require('express');
|
||||
const openAI = require('./openAI');
|
||||
const custom = require('./custom');
|
||||
const google = require('./google');
|
||||
const anthropic = require('./anthropic');
|
||||
const gptPlugins = require('./gptPlugins');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
uaParser,
|
||||
|
|
@ -15,6 +9,12 @@ const {
|
|||
messageUserLimiter,
|
||||
validateConvoAccess,
|
||||
} = require('~/server/middleware');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const gptPlugins = require('./gptPlugins');
|
||||
const anthropic = require('./anthropic');
|
||||
const custom = require('./custom');
|
||||
const google = require('./google');
|
||||
const openAI = require('./openAI');
|
||||
|
||||
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ const {
|
|||
checkInviteUser,
|
||||
registerLimiter,
|
||||
requireLdapAuth,
|
||||
setBalanceConfig,
|
||||
requireLocalAuth,
|
||||
resetPasswordLimiter,
|
||||
validateRegistration,
|
||||
|
|
@ -40,6 +41,7 @@ router.post(
|
|||
loginLimiter,
|
||||
checkBan,
|
||||
ldapAuth ? requireLdapAuth : requireLocalAuth,
|
||||
setBalanceConfig,
|
||||
loginController,
|
||||
);
|
||||
router.post('/refresh', refreshController);
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ const router = express.Router();
|
|||
const {
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
moderateText,
|
||||
// validateModel,
|
||||
// validateEndpoint,
|
||||
buildEndpointOption,
|
||||
|
|
@ -12,6 +13,7 @@ const { initializeClient } = require('~/server/services/Endpoints/bedrock');
|
|||
const AgentController = require('~/server/controllers/agents/request');
|
||||
const addTitle = require('~/server/services/Endpoints/agents/title');
|
||||
|
||||
router.use(moderateText);
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -1,19 +1,35 @@
|
|||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const {
|
||||
uaParser,
|
||||
checkBan,
|
||||
requireJwtAuth,
|
||||
// concurrentLimiter,
|
||||
// messageIpLimiter,
|
||||
// messageUserLimiter,
|
||||
messageIpLimiter,
|
||||
concurrentLimiter,
|
||||
messageUserLimiter,
|
||||
} = require('~/server/middleware');
|
||||
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const chat = require('./chat');
|
||||
|
||||
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.use(requireJwtAuth);
|
||||
router.use(checkBan);
|
||||
router.use(uaParser);
|
||||
|
||||
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
|
||||
router.use(concurrentLimiter);
|
||||
}
|
||||
|
||||
if (isEnabled(LIMIT_MESSAGE_IP)) {
|
||||
router.use(messageIpLimiter);
|
||||
}
|
||||
|
||||
if (isEnabled(LIMIT_MESSAGE_USER)) {
|
||||
router.use(messageUserLimiter);
|
||||
}
|
||||
|
||||
router.use('/chat', chat);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@ const fs = require('fs').promises;
|
|||
const express = require('express');
|
||||
const { EnvVar } = require('@librechat/agents');
|
||||
const {
|
||||
Time,
|
||||
isUUID,
|
||||
CacheKeys,
|
||||
FileSources,
|
||||
EModelEndpoint,
|
||||
isAgentsEndpoint,
|
||||
|
|
@ -17,8 +19,10 @@ const {
|
|||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { refreshS3FileUrls } = require('~/server/services/Files/S3/crud');
|
||||
const { getFiles, batchUpdateFiles } = require('~/models/File');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
|
|
@ -26,6 +30,18 @@ const router = express.Router();
|
|||
router.get('/', async (req, res) => {
|
||||
try {
|
||||
const files = await getFiles({ user: req.user.id });
|
||||
if (req.app.locals.fileStrategy === FileSources.s3) {
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.S3_EXPIRY_INTERVAL);
|
||||
const alreadyChecked = await cache.get(req.user.id);
|
||||
if (!alreadyChecked) {
|
||||
await refreshS3FileUrls(files, batchUpdateFiles);
|
||||
await cache.set(req.user.id, true, Time.THIRTY_MINUTES);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('[/files] Error refreshing S3 file URLs:', error);
|
||||
}
|
||||
}
|
||||
res.status(200).send(files);
|
||||
} catch (error) {
|
||||
logger.error('[/files] Error getting files:', error);
|
||||
|
|
|
|||
|
|
@ -1,7 +1,13 @@
|
|||
// file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware
|
||||
const express = require('express');
|
||||
const passport = require('passport');
|
||||
const { loginLimiter, logHeaders, checkBan, checkDomainAllowed } = require('~/server/middleware');
|
||||
const {
|
||||
checkBan,
|
||||
logHeaders,
|
||||
loginLimiter,
|
||||
setBalanceConfig,
|
||||
checkDomainAllowed,
|
||||
} = require('~/server/middleware');
|
||||
const { setAuthTokens } = require('~/server/services/AuthService');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
|
|
@ -56,6 +62,7 @@ router.get(
|
|||
session: false,
|
||||
scope: ['openid', 'profile', 'email'],
|
||||
}),
|
||||
setBalanceConfig,
|
||||
oauthHandler,
|
||||
);
|
||||
|
||||
|
|
@ -80,6 +87,7 @@ router.get(
|
|||
scope: ['public_profile'],
|
||||
profileFields: ['id', 'email', 'name'],
|
||||
}),
|
||||
setBalanceConfig,
|
||||
oauthHandler,
|
||||
);
|
||||
|
||||
|
|
@ -100,6 +108,7 @@ router.get(
|
|||
failureMessage: true,
|
||||
session: false,
|
||||
}),
|
||||
setBalanceConfig,
|
||||
oauthHandler,
|
||||
);
|
||||
|
||||
|
|
@ -122,6 +131,7 @@ router.get(
|
|||
session: false,
|
||||
scope: ['user:email', 'read:user'],
|
||||
}),
|
||||
setBalanceConfig,
|
||||
oauthHandler,
|
||||
);
|
||||
|
||||
|
|
@ -144,6 +154,7 @@ router.get(
|
|||
session: false,
|
||||
scope: ['identify', 'email'],
|
||||
}),
|
||||
setBalanceConfig,
|
||||
oauthHandler,
|
||||
);
|
||||
|
||||
|
|
@ -164,6 +175,7 @@ router.post(
|
|||
failureMessage: true,
|
||||
session: false,
|
||||
}),
|
||||
setBalanceConfig,
|
||||
oauthHandler,
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ const {
|
|||
actionDomainSeparator,
|
||||
} = require('librechat-data-provider');
|
||||
const { refreshAccessToken } = require('~/server/services/TokenService');
|
||||
const { isActionDomainAllowed } = require('~/server/services/domains');
|
||||
const { logger, getFlowStateManager, sendEvent } = require('~/config');
|
||||
const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
|
||||
const { getActions, deleteActions } = require('~/models/Action');
|
||||
|
|
@ -130,6 +129,7 @@ async function loadActionSets(searchParams) {
|
|||
* @param {string | undefined} [params.name] - The name of the tool.
|
||||
* @param {string | undefined} [params.description] - The description for the tool.
|
||||
* @param {import('zod').ZodTypeAny | undefined} [params.zodSchema] - The Zod schema for tool input validation/definition
|
||||
* @param {{ oauth_client_id?: string; oauth_client_secret?: string; }} params.encrypted - The encrypted values for the action.
|
||||
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createActionTool({
|
||||
|
|
@ -140,17 +140,8 @@ async function createActionTool({
|
|||
zodSchema,
|
||||
name,
|
||||
description,
|
||||
encrypted,
|
||||
}) {
|
||||
const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
|
||||
if (!isDomainAllowed) {
|
||||
return null;
|
||||
}
|
||||
const encrypted = {
|
||||
oauth_client_id: action.metadata.oauth_client_id,
|
||||
oauth_client_secret: action.metadata.oauth_client_secret,
|
||||
};
|
||||
action.metadata = await decryptMetadata(action.metadata);
|
||||
|
||||
/** @type {(toolInput: Object | string, config: GraphRunnableConfig) => Promise<unknown>} */
|
||||
const _call = async (toolInput, config) => {
|
||||
try {
|
||||
|
|
@ -308,9 +299,8 @@ async function createActionTool({
|
|||
}
|
||||
return response.data;
|
||||
} catch (error) {
|
||||
const logMessage = `API call to ${action.metadata.domain} failed`;
|
||||
logAxiosError({ message: logMessage, error });
|
||||
throw error;
|
||||
const message = `API call to ${action.metadata.domain} failed:`;
|
||||
return logAxiosError({ message, error });
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -327,6 +317,27 @@ async function createActionTool({
|
|||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Encrypts a sensitive value.
|
||||
* @param {string} value
|
||||
* @returns {Promise<string>}
|
||||
*/
|
||||
async function encryptSensitiveValue(value) {
|
||||
// Encode API key to handle special characters like ":"
|
||||
const encodedValue = encodeURIComponent(value);
|
||||
return await encryptV2(encodedValue);
|
||||
}
|
||||
|
||||
/**
|
||||
* Decrypts a sensitive value.
|
||||
* @param {string} value
|
||||
* @returns {Promise<string>}
|
||||
*/
|
||||
async function decryptSensitiveValue(value) {
|
||||
const decryptedValue = await decryptV2(value);
|
||||
return decodeURIComponent(decryptedValue);
|
||||
}
|
||||
|
||||
/**
|
||||
* Encrypts sensitive metadata values for an action.
|
||||
*
|
||||
|
|
@ -339,17 +350,19 @@ async function encryptMetadata(metadata) {
|
|||
// ServiceHttp
|
||||
if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) {
|
||||
if (metadata.api_key) {
|
||||
encryptedMetadata.api_key = await encryptV2(metadata.api_key);
|
||||
encryptedMetadata.api_key = await encryptSensitiveValue(metadata.api_key);
|
||||
}
|
||||
}
|
||||
|
||||
// OAuth
|
||||
else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) {
|
||||
if (metadata.oauth_client_id) {
|
||||
encryptedMetadata.oauth_client_id = await encryptV2(metadata.oauth_client_id);
|
||||
encryptedMetadata.oauth_client_id = await encryptSensitiveValue(metadata.oauth_client_id);
|
||||
}
|
||||
if (metadata.oauth_client_secret) {
|
||||
encryptedMetadata.oauth_client_secret = await encryptV2(metadata.oauth_client_secret);
|
||||
encryptedMetadata.oauth_client_secret = await encryptSensitiveValue(
|
||||
metadata.oauth_client_secret,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -368,17 +381,19 @@ async function decryptMetadata(metadata) {
|
|||
// ServiceHttp
|
||||
if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) {
|
||||
if (metadata.api_key) {
|
||||
decryptedMetadata.api_key = await decryptV2(metadata.api_key);
|
||||
decryptedMetadata.api_key = await decryptSensitiveValue(metadata.api_key);
|
||||
}
|
||||
}
|
||||
|
||||
// OAuth
|
||||
else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) {
|
||||
if (metadata.oauth_client_id) {
|
||||
decryptedMetadata.oauth_client_id = await decryptV2(metadata.oauth_client_id);
|
||||
decryptedMetadata.oauth_client_id = await decryptSensitiveValue(metadata.oauth_client_id);
|
||||
}
|
||||
if (metadata.oauth_client_secret) {
|
||||
decryptedMetadata.oauth_client_secret = await decryptV2(metadata.oauth_client_secret);
|
||||
decryptedMetadata.oauth_client_secret = await decryptSensitiveValue(
|
||||
metadata.oauth_client_secret,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ const AppService = async (app) => {
|
|||
|
||||
if (fileStrategy === FileSources.firebase) {
|
||||
initializeFirebase();
|
||||
} else if (fileStrategy === FileSources.azure) {
|
||||
} else if (fileStrategy === FileSources.azure_blob) {
|
||||
initializeAzureBlobService();
|
||||
} else if (fileStrategy === FileSources.s3) {
|
||||
initializeS3();
|
||||
|
|
@ -146,7 +146,7 @@ const AppService = async (app) => {
|
|||
...defaultLocals,
|
||||
fileConfig: config?.fileConfig,
|
||||
secureImageLinks: config?.secureImageLinks,
|
||||
modelSpecs: processModelSpecs(endpoints, config.modelSpecs),
|
||||
modelSpecs: processModelSpecs(endpoints, config.modelSpecs, interfaceConfig),
|
||||
...endpointLocals,
|
||||
};
|
||||
};
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ const sendVerificationEmail = async (user) => {
|
|||
subject: 'Verify your email',
|
||||
payload: {
|
||||
appName: process.env.APP_TITLE || 'LibreChat',
|
||||
name: user.name,
|
||||
name: user.name || user.username || user.email,
|
||||
verificationLink: verificationLink,
|
||||
year: new Date().getFullYear(),
|
||||
},
|
||||
|
|
@ -278,7 +278,7 @@ const requestPasswordReset = async (req) => {
|
|||
subject: 'Password Reset Request',
|
||||
payload: {
|
||||
appName: process.env.APP_TITLE || 'LibreChat',
|
||||
name: user.name,
|
||||
name: user.name || user.username || user.email,
|
||||
link: link,
|
||||
year: new Date().getFullYear(),
|
||||
},
|
||||
|
|
@ -331,7 +331,7 @@ const resetPassword = async (userId, token, password) => {
|
|||
subject: 'Password Reset Successfully',
|
||||
payload: {
|
||||
appName: process.env.APP_TITLE || 'LibreChat',
|
||||
name: user.name,
|
||||
name: user.name || user.username || user.email,
|
||||
year: new Date().getFullYear(),
|
||||
},
|
||||
template: 'passwordReset.handlebars',
|
||||
|
|
@ -414,7 +414,7 @@ const resendVerificationEmail = async (req) => {
|
|||
subject: 'Verify your email',
|
||||
payload: {
|
||||
appName: process.env.APP_TITLE || 'LibreChat',
|
||||
name: user.name,
|
||||
name: user.name || user.username || user.email,
|
||||
verificationLink: verificationLink,
|
||||
year: new Date().getFullYear(),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -31,19 +31,16 @@ async function getCustomConfig() {
|
|||
async function getBalanceConfig() {
|
||||
const isLegacyEnabled = isEnabled(process.env.CHECK_BALANCE);
|
||||
const startBalance = process.env.START_BALANCE;
|
||||
if (isLegacyEnabled || (startBalance != null && startBalance)) {
|
||||
/** @type {TCustomConfig['balance']} */
|
||||
const config = {
|
||||
enabled: isLegacyEnabled,
|
||||
startBalance: startBalance ? parseInt(startBalance, 10) : undefined,
|
||||
};
|
||||
return config;
|
||||
}
|
||||
/** @type {TCustomConfig['balance']} */
|
||||
const config = {
|
||||
enabled: isLegacyEnabled,
|
||||
startBalance: startBalance != null && startBalance ? parseInt(startBalance, 10) : undefined,
|
||||
};
|
||||
const customConfig = await getCustomConfig();
|
||||
if (!customConfig) {
|
||||
return null;
|
||||
return config;
|
||||
}
|
||||
return customConfig?.['balance'] ?? null;
|
||||
return { ...config, ...(customConfig?.['balance'] ?? {}) };
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -33,10 +33,12 @@ async function getEndpointsConfig(req) {
|
|||
};
|
||||
}
|
||||
if (mergedConfig[EModelEndpoint.agents] && req.app.locals?.[EModelEndpoint.agents]) {
|
||||
const { disableBuilder, capabilities, ..._rest } = req.app.locals[EModelEndpoint.agents];
|
||||
const { disableBuilder, capabilities, allowedProviders, ..._rest } =
|
||||
req.app.locals[EModelEndpoint.agents];
|
||||
|
||||
mergedConfig[EModelEndpoint.agents] = {
|
||||
...mergedConfig[EModelEndpoint.agents],
|
||||
allowedProviders,
|
||||
disableBuilder,
|
||||
capabilities,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
const { createContentAggregator, Providers } = require('@librechat/agents');
|
||||
const {
|
||||
ErrorTypes,
|
||||
EModelEndpoint,
|
||||
getResponseSender,
|
||||
AgentCapabilities,
|
||||
|
|
@ -117,6 +118,7 @@ function optionalChainWithEmptyCheck(...values) {
|
|||
* @param {ServerRequest} params.req
|
||||
* @param {ServerResponse} params.res
|
||||
* @param {Agent} params.agent
|
||||
* @param {Set<string>} [params.allowedProviders]
|
||||
* @param {object} [params.endpointOption]
|
||||
* @param {boolean} [params.isInitialAgent]
|
||||
* @returns {Promise<Agent>}
|
||||
|
|
@ -126,8 +128,14 @@ const initializeAgentOptions = async ({
|
|||
res,
|
||||
agent,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
isInitialAgent = false,
|
||||
}) => {
|
||||
if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) {
|
||||
throw new Error(
|
||||
`{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`,
|
||||
);
|
||||
}
|
||||
let currentFiles;
|
||||
/** @type {Array<MongoFile>} */
|
||||
const requestFiles = req.body.files ?? [];
|
||||
|
|
@ -263,6 +271,8 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
}
|
||||
|
||||
const agentConfigs = new Map();
|
||||
/** @type {Set<string>} */
|
||||
const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders);
|
||||
|
||||
// Handle primary agent
|
||||
const primaryConfig = await initializeAgentOptions({
|
||||
|
|
@ -270,6 +280,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
res,
|
||||
agent: primaryAgent,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
isInitialAgent: true,
|
||||
});
|
||||
|
||||
|
|
@ -285,6 +296,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
res,
|
||||
agent,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
});
|
||||
agentConfigs.set(agentId, config);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,10 @@
|
|||
const { CacheKeys, findLastSeparatorIndex, SEPARATORS, Time } = require('librechat-data-provider');
|
||||
const {
|
||||
Time,
|
||||
CacheKeys,
|
||||
SEPARATORS,
|
||||
parseTextParts,
|
||||
findLastSeparatorIndex,
|
||||
} = require('librechat-data-provider');
|
||||
const { getMessage } = require('~/models/Message');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
|
|
@ -84,10 +90,11 @@ function createChunkProcessor(user, messageId) {
|
|||
notFoundCount++;
|
||||
return [];
|
||||
} else {
|
||||
const text = message.content?.length > 0 ? parseTextParts(message.content) : message.text;
|
||||
messageCache.set(
|
||||
messageId,
|
||||
{
|
||||
text: message.text,
|
||||
text,
|
||||
complete: true,
|
||||
},
|
||||
Time.FIVE_MINUTES,
|
||||
|
|
@ -95,7 +102,7 @@ function createChunkProcessor(user, messageId) {
|
|||
}
|
||||
|
||||
const text = typeof message === 'string' ? message : message.text;
|
||||
const complete = typeof message === 'string' ? false : message.complete ?? true;
|
||||
const complete = typeof message === 'string' ? false : (message.complete ?? true);
|
||||
|
||||
if (text === processedText) {
|
||||
noChangeCount++;
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const mime = require('mime');
|
||||
const axios = require('axios');
|
||||
const fetch = require('node-fetch');
|
||||
const { logger } = require('~/config');
|
||||
const { getAzureContainerClient } = require('./initialize');
|
||||
|
||||
const defaultBasePath = 'images';
|
||||
const { AZURE_STORAGE_PUBLIC_ACCESS = 'true', AZURE_CONTAINER_NAME = 'files' } = process.env;
|
||||
|
||||
/**
|
||||
* Uploads a buffer to Azure Blob Storage.
|
||||
|
|
@ -29,10 +31,9 @@ async function saveBufferToAzure({
|
|||
}) {
|
||||
try {
|
||||
const containerClient = getAzureContainerClient(containerName);
|
||||
const access = AZURE_STORAGE_PUBLIC_ACCESS?.toLowerCase() === 'true' ? 'blob' : undefined;
|
||||
// Create the container if it doesn't exist. This is done per operation.
|
||||
await containerClient.createIfNotExists({
|
||||
access: process.env.AZURE_STORAGE_PUBLIC_ACCESS ? 'blob' : undefined,
|
||||
});
|
||||
await containerClient.createIfNotExists({ access });
|
||||
const blobPath = `${basePath}/${userId}/${fileName}`;
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
|
||||
await blockBlobClient.uploadData(buffer);
|
||||
|
|
@ -97,25 +98,21 @@ async function getAzureURL({ fileName, basePath = defaultBasePath, userId, conta
|
|||
* Deletes a blob from Azure Blob Storage.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.fileName - The name of the file.
|
||||
* @param {string} [params.basePath='images'] - The base folder where the file is stored.
|
||||
* @param {string} params.userId - The user's id.
|
||||
* @param {string} [params.containerName] - The Azure Blob container name.
|
||||
* @param {ServerRequest} params.req - The Express request object.
|
||||
* @param {MongoFile} params.file - The file object.
|
||||
*/
|
||||
async function deleteFileFromAzure({
|
||||
fileName,
|
||||
basePath = defaultBasePath,
|
||||
userId,
|
||||
containerName,
|
||||
}) {
|
||||
async function deleteFileFromAzure(req, file) {
|
||||
try {
|
||||
const containerClient = getAzureContainerClient(containerName);
|
||||
const blobPath = `${basePath}/${userId}/${fileName}`;
|
||||
const containerClient = getAzureContainerClient(AZURE_CONTAINER_NAME);
|
||||
const blobPath = file.filepath.split(`${AZURE_CONTAINER_NAME}/`)[1];
|
||||
if (!blobPath.includes(req.user.id)) {
|
||||
throw new Error('User ID not found in blob path');
|
||||
}
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
|
||||
await blockBlobClient.delete();
|
||||
logger.debug('[deleteFileFromAzure] Blob deleted successfully from Azure Blob Storage');
|
||||
} catch (error) {
|
||||
logger.error('[deleteFileFromAzure] Error deleting blob:', error.message);
|
||||
logger.error('[deleteFileFromAzure] Error deleting blob:', error);
|
||||
if (error.statusCode === 404) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -123,6 +120,65 @@ async function deleteFileFromAzure({
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Streams a file from disk directly to Azure Blob Storage without loading
|
||||
* the entire file into memory.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId - The user's id.
|
||||
* @param {string} params.filePath - The local file path to upload.
|
||||
* @param {string} params.fileName - The name of the file in Azure.
|
||||
* @param {string} [params.basePath='images'] - The base folder within the container.
|
||||
* @param {string} [params.containerName] - The Azure Blob container name.
|
||||
* @returns {Promise<string>} The URL of the uploaded blob.
|
||||
*/
|
||||
async function streamFileToAzure({
|
||||
userId,
|
||||
filePath,
|
||||
fileName,
|
||||
basePath = defaultBasePath,
|
||||
containerName,
|
||||
}) {
|
||||
try {
|
||||
const containerClient = getAzureContainerClient(containerName);
|
||||
const access = AZURE_STORAGE_PUBLIC_ACCESS?.toLowerCase() === 'true' ? 'blob' : undefined;
|
||||
|
||||
// Create the container if it doesn't exist
|
||||
await containerClient.createIfNotExists({ access });
|
||||
|
||||
const blobPath = `${basePath}/${userId}/${fileName}`;
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
|
||||
|
||||
// Get file size for proper content length
|
||||
const stats = await fs.promises.stat(filePath);
|
||||
|
||||
// Create read stream from the file
|
||||
const fileStream = fs.createReadStream(filePath);
|
||||
|
||||
const blobContentType = mime.getType(fileName);
|
||||
await blockBlobClient.uploadStream(
|
||||
fileStream,
|
||||
undefined, // Use default concurrency (5)
|
||||
undefined, // Use default buffer size (8MB)
|
||||
{
|
||||
blobHTTPHeaders: {
|
||||
blobContentType,
|
||||
},
|
||||
onProgress: (progress) => {
|
||||
logger.debug(
|
||||
`[streamFileToAzure] Upload progress: ${progress.loadedBytes} bytes of ${stats.size}`,
|
||||
);
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
return blockBlobClient.url;
|
||||
} catch (error) {
|
||||
logger.error('[streamFileToAzure] Error streaming file:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Uploads a file from the local file system to Azure Blob Storage.
|
||||
*
|
||||
|
|
@ -146,18 +202,19 @@ async function uploadFileToAzure({
|
|||
}) {
|
||||
try {
|
||||
const inputFilePath = file.path;
|
||||
const inputBuffer = await fs.promises.readFile(inputFilePath);
|
||||
const bytes = Buffer.byteLength(inputBuffer);
|
||||
const stats = await fs.promises.stat(inputFilePath);
|
||||
const bytes = stats.size;
|
||||
const userId = req.user.id;
|
||||
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
|
||||
const fileURL = await saveBufferToAzure({
|
||||
|
||||
const fileURL = await streamFileToAzure({
|
||||
userId,
|
||||
buffer: inputBuffer,
|
||||
filePath: inputFilePath,
|
||||
fileName,
|
||||
basePath,
|
||||
containerName,
|
||||
});
|
||||
await fs.promises.unlink(inputFilePath);
|
||||
|
||||
return { filepath: fileURL, bytes };
|
||||
} catch (error) {
|
||||
logger.error('[uploadFileToAzure] Error uploading file:', error);
|
||||
|
|
|
|||
|
|
@ -32,11 +32,12 @@ async function getCodeOutputDownloadStream(fileIdentifier, apiKey) {
|
|||
const response = await axios(options);
|
||||
return response;
|
||||
} catch (error) {
|
||||
logAxiosError({
|
||||
message: `Error downloading code environment file stream: ${error.message}`,
|
||||
error,
|
||||
});
|
||||
throw new Error(`Error downloading file: ${error.message}`);
|
||||
throw new Error(
|
||||
logAxiosError({
|
||||
message: `Error downloading code environment file stream: ${error.message}`,
|
||||
error,
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -89,11 +90,12 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = ''
|
|||
|
||||
return `${fileIdentifier}?entity_id=${entity_id}`;
|
||||
} catch (error) {
|
||||
logAxiosError({
|
||||
message: `Error uploading code environment file: ${error.message}`,
|
||||
error,
|
||||
});
|
||||
throw new Error(`Error uploading code environment file: ${error.message}`);
|
||||
throw new Error(
|
||||
logAxiosError({
|
||||
message: `Error uploading code environment file: ${error.message}`,
|
||||
error,
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ const FormData = require('form-data');
|
|||
const { FileSources, envVarRegex, extractEnvVariable } = require('librechat-data-provider');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { logger, createAxiosInstance } = require('~/config');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { logAxiosError } = require('~/utils/axios');
|
||||
|
||||
const axios = createAxiosInstance();
|
||||
|
||||
|
|
@ -194,8 +194,7 @@ const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => {
|
|||
};
|
||||
} catch (error) {
|
||||
const message = 'Error uploading document to Mistral OCR API';
|
||||
logAxiosError({ error, message });
|
||||
throw new Error(message);
|
||||
throw new Error(logAxiosError({ error, message }));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -29,9 +29,6 @@ const mockAxios = {
|
|||
|
||||
jest.mock('axios', () => mockAxios);
|
||||
jest.mock('fs');
|
||||
jest.mock('~/utils', () => ({
|
||||
logAxiosError: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
|
|
@ -494,9 +491,6 @@ describe('MistralOCR Service', () => {
|
|||
}),
|
||||
).rejects.toThrow('Error uploading document to Mistral OCR API');
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
const { logAxiosError } = require('~/utils');
|
||||
expect(logAxiosError).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle single page documents without page numbering', async () => {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,13 @@
|
|||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const fetch = require('node-fetch');
|
||||
const { PutObjectCommand, GetObjectCommand, DeleteObjectCommand } = require('@aws-sdk/client-s3');
|
||||
const { FileSources } = require('librechat-data-provider');
|
||||
const {
|
||||
PutObjectCommand,
|
||||
GetObjectCommand,
|
||||
HeadObjectCommand,
|
||||
DeleteObjectCommand,
|
||||
} = require('@aws-sdk/client-s3');
|
||||
const { getSignedUrl } = require('@aws-sdk/s3-request-presigner');
|
||||
const { initializeS3 } = require('./initialize');
|
||||
const { logger } = require('~/config');
|
||||
|
|
@ -9,6 +15,34 @@ const { logger } = require('~/config');
|
|||
const bucketName = process.env.AWS_BUCKET_NAME;
|
||||
const defaultBasePath = 'images';
|
||||
|
||||
let s3UrlExpirySeconds = 7 * 24 * 60 * 60;
|
||||
let s3RefreshExpiryMs = null;
|
||||
|
||||
if (process.env.S3_URL_EXPIRY_SECONDS !== undefined) {
|
||||
const parsed = parseInt(process.env.S3_URL_EXPIRY_SECONDS, 10);
|
||||
|
||||
if (!isNaN(parsed) && parsed > 0) {
|
||||
s3UrlExpirySeconds = Math.min(parsed, 7 * 24 * 60 * 60);
|
||||
} else {
|
||||
logger.warn(
|
||||
`[S3] Invalid S3_URL_EXPIRY_SECONDS value: "${process.env.S3_URL_EXPIRY_SECONDS}". Using 7-day expiry.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (process.env.S3_REFRESH_EXPIRY_MS !== null && process.env.S3_REFRESH_EXPIRY_MS) {
|
||||
const parsed = parseInt(process.env.S3_REFRESH_EXPIRY_MS, 10);
|
||||
|
||||
if (!isNaN(parsed) && parsed > 0) {
|
||||
s3RefreshExpiryMs = parsed;
|
||||
logger.info(`[S3] Using custom refresh expiry time: ${s3RefreshExpiryMs}ms`);
|
||||
} else {
|
||||
logger.warn(
|
||||
`[S3] Invalid S3_REFRESH_EXPIRY_MS value: "${process.env.S3_REFRESH_EXPIRY_MS}". Using default refresh logic.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs the S3 key based on the base path, user ID, and file name.
|
||||
*/
|
||||
|
|
@ -39,13 +73,14 @@ async function saveBufferToS3({ userId, buffer, fileName, basePath = defaultBase
|
|||
}
|
||||
|
||||
/**
|
||||
* Retrieves a signed URL for a file stored in S3.
|
||||
* Retrieves a URL for a file stored in S3.
|
||||
* Returns a signed URL with expiration time or a proxy URL based on config
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId - The user's unique identifier.
|
||||
* @param {string} params.fileName - The file name in S3.
|
||||
* @param {string} [params.basePath='images'] - The base path in the bucket.
|
||||
* @returns {Promise<string>} A signed URL valid for 24 hours.
|
||||
* @returns {Promise<string>} A URL to access the S3 object
|
||||
*/
|
||||
async function getS3URL({ userId, fileName, basePath = defaultBasePath }) {
|
||||
const key = getS3Key(basePath, userId, fileName);
|
||||
|
|
@ -53,7 +88,7 @@ async function getS3URL({ userId, fileName, basePath = defaultBasePath }) {
|
|||
|
||||
try {
|
||||
const s3 = initializeS3();
|
||||
return await getSignedUrl(s3, new GetObjectCommand(params), { expiresIn: 86400 });
|
||||
return await getSignedUrl(s3, new GetObjectCommand(params), { expiresIn: s3UrlExpirySeconds });
|
||||
} catch (error) {
|
||||
logger.error('[getS3URL] Error getting signed URL from S3:', error.message);
|
||||
throw error;
|
||||
|
|
@ -86,21 +121,51 @@ async function saveURLToS3({ userId, URL, fileName, basePath = defaultBasePath }
|
|||
* Deletes a file from S3.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId - The user's unique identifier.
|
||||
* @param {string} params.fileName - The file name in S3.
|
||||
* @param {string} [params.basePath='images'] - The base path in the bucket.
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {MongoFile} params.file - The file object to delete.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function deleteFileFromS3({ userId, fileName, basePath = defaultBasePath }) {
|
||||
const key = getS3Key(basePath, userId, fileName);
|
||||
async function deleteFileFromS3(req, file) {
|
||||
const key = extractKeyFromS3Url(file.filepath);
|
||||
const params = { Bucket: bucketName, Key: key };
|
||||
if (!key.includes(req.user.id)) {
|
||||
const message = `[deleteFileFromS3] User ID mismatch: ${req.user.id} vs ${key}`;
|
||||
logger.error(message);
|
||||
throw new Error(message);
|
||||
}
|
||||
|
||||
try {
|
||||
const s3 = initializeS3();
|
||||
await s3.send(new DeleteObjectCommand(params));
|
||||
logger.debug('[deleteFileFromS3] File deleted successfully from S3');
|
||||
|
||||
try {
|
||||
const headCommand = new HeadObjectCommand(params);
|
||||
await s3.send(headCommand);
|
||||
logger.debug('[deleteFileFromS3] File exists, proceeding with deletion');
|
||||
} catch (headErr) {
|
||||
if (headErr.name === 'NotFound') {
|
||||
logger.warn(`[deleteFileFromS3] File does not exist: ${key}`);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const deleteResult = await s3.send(new DeleteObjectCommand(params));
|
||||
logger.debug('[deleteFileFromS3] Delete command response:', JSON.stringify(deleteResult));
|
||||
try {
|
||||
await s3.send(new HeadObjectCommand(params));
|
||||
logger.error('[deleteFileFromS3] File still exists after deletion!');
|
||||
} catch (verifyErr) {
|
||||
if (verifyErr.name === 'NotFound') {
|
||||
logger.debug(`[deleteFileFromS3] Verified file is deleted: ${key}`);
|
||||
} else {
|
||||
logger.error('[deleteFileFromS3] Error verifying deletion:', verifyErr);
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug('[deleteFileFromS3] S3 File deletion completed');
|
||||
} catch (error) {
|
||||
logger.error('[deleteFileFromS3] Error deleting file from S3:', error.message);
|
||||
logger.error(`[deleteFileFromS3] Error deleting file from S3: ${error.message}`);
|
||||
logger.error(error.stack);
|
||||
|
||||
// If the file is not found, we can safely return.
|
||||
if (error.code === 'NoSuchKey') {
|
||||
return;
|
||||
|
|
@ -110,7 +175,7 @@ async function deleteFileFromS3({ userId, fileName, basePath = defaultBasePath }
|
|||
}
|
||||
|
||||
/**
|
||||
* Uploads a local file to S3.
|
||||
* Uploads a local file to S3 by streaming it directly without loading into memory.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {import('express').Request} params.req - The Express request (must include user).
|
||||
|
|
@ -122,37 +187,272 @@ async function deleteFileFromS3({ userId, fileName, basePath = defaultBasePath }
|
|||
async function uploadFileToS3({ req, file, file_id, basePath = defaultBasePath }) {
|
||||
try {
|
||||
const inputFilePath = file.path;
|
||||
const inputBuffer = await fs.promises.readFile(inputFilePath);
|
||||
const bytes = Buffer.byteLength(inputBuffer);
|
||||
const userId = req.user.id;
|
||||
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
|
||||
const fileURL = await saveBufferToS3({ userId, buffer: inputBuffer, fileName, basePath });
|
||||
await fs.promises.unlink(inputFilePath);
|
||||
const key = getS3Key(basePath, userId, fileName);
|
||||
|
||||
const stats = await fs.promises.stat(inputFilePath);
|
||||
const bytes = stats.size;
|
||||
const fileStream = fs.createReadStream(inputFilePath);
|
||||
|
||||
const s3 = initializeS3();
|
||||
const uploadParams = {
|
||||
Bucket: bucketName,
|
||||
Key: key,
|
||||
Body: fileStream,
|
||||
};
|
||||
|
||||
await s3.send(new PutObjectCommand(uploadParams));
|
||||
const fileURL = await getS3URL({ userId, fileName, basePath });
|
||||
return { filepath: fileURL, bytes };
|
||||
} catch (error) {
|
||||
logger.error('[uploadFileToS3] Error uploading file to S3:', error.message);
|
||||
logger.error('[uploadFileToS3] Error streaming file to S3:', error);
|
||||
try {
|
||||
if (file && file.path) {
|
||||
await fs.promises.unlink(file.path);
|
||||
}
|
||||
} catch (unlinkError) {
|
||||
logger.error(
|
||||
'[uploadFileToS3] Error deleting temporary file, likely already deleted:',
|
||||
unlinkError.message,
|
||||
);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts the S3 key from a URL or returns the key if already properly formatted
|
||||
*
|
||||
* @param {string} fileUrlOrKey - The file URL or key
|
||||
* @returns {string} The S3 key
|
||||
*/
|
||||
function extractKeyFromS3Url(fileUrlOrKey) {
|
||||
if (!fileUrlOrKey) {
|
||||
throw new Error('Invalid input: URL or key is empty');
|
||||
}
|
||||
|
||||
try {
|
||||
const url = new URL(fileUrlOrKey);
|
||||
return url.pathname.substring(1);
|
||||
} catch (error) {
|
||||
const parts = fileUrlOrKey.split('/');
|
||||
|
||||
if (parts.length >= 3 && !fileUrlOrKey.startsWith('http') && !fileUrlOrKey.startsWith('/')) {
|
||||
return fileUrlOrKey;
|
||||
}
|
||||
|
||||
return fileUrlOrKey.startsWith('/') ? fileUrlOrKey.substring(1) : fileUrlOrKey;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a readable stream for a file stored in S3.
|
||||
*
|
||||
* @param {ServerRequest} req - Server request object.
|
||||
* @param {string} filePath - The S3 key of the file.
|
||||
* @returns {Promise<NodeJS.ReadableStream>}
|
||||
*/
|
||||
async function getS3FileStream(filePath) {
|
||||
const params = { Bucket: bucketName, Key: filePath };
|
||||
async function getS3FileStream(_req, filePath) {
|
||||
try {
|
||||
const Key = extractKeyFromS3Url(filePath);
|
||||
const params = { Bucket: bucketName, Key };
|
||||
const s3 = initializeS3();
|
||||
const data = await s3.send(new GetObjectCommand(params));
|
||||
return data.Body; // Returns a Node.js ReadableStream.
|
||||
} catch (error) {
|
||||
logger.error('[getS3FileStream] Error retrieving S3 file stream:', error.message);
|
||||
logger.error('[getS3FileStream] Error retrieving S3 file stream:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines if a signed S3 URL is close to expiration
|
||||
*
|
||||
* @param {string} signedUrl - The signed S3 URL
|
||||
* @param {number} bufferSeconds - Buffer time in seconds
|
||||
* @returns {boolean} True if the URL needs refreshing
|
||||
*/
|
||||
function needsRefresh(signedUrl, bufferSeconds) {
|
||||
try {
|
||||
// Parse the URL
|
||||
const url = new URL(signedUrl);
|
||||
|
||||
// Check if it has the signature parameters that indicate it's a signed URL
|
||||
// X-Amz-Signature is the most reliable indicator for AWS signed URLs
|
||||
if (!url.searchParams.has('X-Amz-Signature')) {
|
||||
// Not a signed URL, so no expiration to check (or it's already a proxy URL)
|
||||
return false;
|
||||
}
|
||||
|
||||
// Extract the expiration time from the URL
|
||||
const expiresParam = url.searchParams.get('X-Amz-Expires');
|
||||
const dateParam = url.searchParams.get('X-Amz-Date');
|
||||
|
||||
if (!expiresParam || !dateParam) {
|
||||
// Missing expiration information, assume it needs refresh to be safe
|
||||
return true;
|
||||
}
|
||||
|
||||
// Parse the AWS date format (YYYYMMDDTHHMMSSZ)
|
||||
const year = dateParam.substring(0, 4);
|
||||
const month = dateParam.substring(4, 6);
|
||||
const day = dateParam.substring(6, 8);
|
||||
const hour = dateParam.substring(9, 11);
|
||||
const minute = dateParam.substring(11, 13);
|
||||
const second = dateParam.substring(13, 15);
|
||||
|
||||
const dateObj = new Date(`${year}-${month}-${day}T${hour}:${minute}:${second}Z`);
|
||||
const expiresAtDate = new Date(dateObj.getTime() + parseInt(expiresParam) * 1000);
|
||||
|
||||
// Check if it's close to expiration
|
||||
const now = new Date();
|
||||
|
||||
// If S3_REFRESH_EXPIRY_MS is set, use it to determine if URL is expired
|
||||
if (s3RefreshExpiryMs !== null) {
|
||||
const urlCreationTime = dateObj.getTime();
|
||||
const urlAge = now.getTime() - urlCreationTime;
|
||||
return urlAge >= s3RefreshExpiryMs;
|
||||
}
|
||||
|
||||
// Otherwise use the default buffer-based logic
|
||||
const bufferTime = new Date(now.getTime() + bufferSeconds * 1000);
|
||||
return expiresAtDate <= bufferTime;
|
||||
} catch (error) {
|
||||
logger.error('Error checking URL expiration:', error);
|
||||
// If we can't determine, assume it needs refresh to be safe
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a new URL for an expired S3 URL
|
||||
* @param {string} currentURL - The current file URL
|
||||
* @returns {Promise<string | undefined>}
|
||||
*/
|
||||
async function getNewS3URL(currentURL) {
|
||||
try {
|
||||
const s3Key = extractKeyFromS3Url(currentURL);
|
||||
if (!s3Key) {
|
||||
return;
|
||||
}
|
||||
const keyParts = s3Key.split('/');
|
||||
if (keyParts.length < 3) {
|
||||
return;
|
||||
}
|
||||
|
||||
const basePath = keyParts[0];
|
||||
const userId = keyParts[1];
|
||||
const fileName = keyParts.slice(2).join('/');
|
||||
|
||||
return await getS3URL({
|
||||
userId,
|
||||
fileName,
|
||||
basePath,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error getting new S3 URL:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes S3 URLs for an array of files if they're expired or close to expiring
|
||||
*
|
||||
* @param {IMongoFile[]} files - Array of file documents
|
||||
* @param {(files: MongoFile[]) => Promise<void>} batchUpdateFiles - Function to update files in the database
|
||||
* @param {number} [bufferSeconds=3600] - Buffer time in seconds to check for expiration
|
||||
* @returns {Promise<IMongoFile[]>} The files with refreshed URLs if needed
|
||||
*/
|
||||
async function refreshS3FileUrls(files, batchUpdateFiles, bufferSeconds = 3600) {
|
||||
if (!files || !Array.isArray(files) || files.length === 0) {
|
||||
return files;
|
||||
}
|
||||
|
||||
const filesToUpdate = [];
|
||||
|
||||
for (let i = 0; i < files.length; i++) {
|
||||
const file = files[i];
|
||||
if (!file?.file_id) {
|
||||
continue;
|
||||
}
|
||||
if (file.source !== FileSources.s3) {
|
||||
continue;
|
||||
}
|
||||
if (!file.filepath) {
|
||||
continue;
|
||||
}
|
||||
if (!needsRefresh(file.filepath, bufferSeconds)) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
const newURL = await getNewS3URL(file.filepath);
|
||||
if (!newURL) {
|
||||
continue;
|
||||
}
|
||||
filesToUpdate.push({
|
||||
file_id: file.file_id,
|
||||
filepath: newURL,
|
||||
});
|
||||
files[i].filepath = newURL;
|
||||
} catch (error) {
|
||||
logger.error(`Error refreshing S3 URL for file ${file.file_id}:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
if (filesToUpdate.length > 0) {
|
||||
await batchUpdateFiles(filesToUpdate);
|
||||
}
|
||||
|
||||
return files;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes a single S3 URL if it's expired or close to expiring
|
||||
*
|
||||
* @param {{ filepath: string, source: string }} fileObj - Simple file object containing filepath and source
|
||||
* @param {number} [bufferSeconds=3600] - Buffer time in seconds to check for expiration
|
||||
* @returns {Promise<string>} The refreshed URL or the original URL if no refresh needed
|
||||
*/
|
||||
async function refreshS3Url(fileObj, bufferSeconds = 3600) {
|
||||
if (!fileObj || fileObj.source !== FileSources.s3 || !fileObj.filepath) {
|
||||
return fileObj?.filepath || '';
|
||||
}
|
||||
|
||||
if (!needsRefresh(fileObj.filepath, bufferSeconds)) {
|
||||
return fileObj.filepath;
|
||||
}
|
||||
|
||||
try {
|
||||
const s3Key = extractKeyFromS3Url(fileObj.filepath);
|
||||
if (!s3Key) {
|
||||
logger.warn(`Unable to extract S3 key from URL: ${fileObj.filepath}`);
|
||||
return fileObj.filepath;
|
||||
}
|
||||
|
||||
const keyParts = s3Key.split('/');
|
||||
if (keyParts.length < 3) {
|
||||
logger.warn(`Invalid S3 key format: ${s3Key}`);
|
||||
return fileObj.filepath;
|
||||
}
|
||||
|
||||
const basePath = keyParts[0];
|
||||
const userId = keyParts[1];
|
||||
const fileName = keyParts.slice(2).join('/');
|
||||
|
||||
const newUrl = await getS3URL({
|
||||
userId,
|
||||
fileName,
|
||||
basePath,
|
||||
});
|
||||
|
||||
logger.debug(`Refreshed S3 URL for key: ${s3Key}`);
|
||||
return newUrl;
|
||||
} catch (error) {
|
||||
logger.error(`Error refreshing S3 URL: ${error.message}`);
|
||||
return fileObj.filepath;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
saveBufferToS3,
|
||||
saveURLToS3,
|
||||
|
|
@ -160,4 +460,8 @@ module.exports = {
|
|||
deleteFileFromS3,
|
||||
uploadFileToS3,
|
||||
getS3FileStream,
|
||||
refreshS3FileUrls,
|
||||
refreshS3Url,
|
||||
needsRefresh,
|
||||
getNewS3URL,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ const {
|
|||
EModelEndpoint,
|
||||
} = require('librechat-data-provider');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
|
|
@ -24,8 +25,8 @@ async function fetchImageToBase64(url) {
|
|||
});
|
||||
return Buffer.from(response.data).toString('base64');
|
||||
} catch (error) {
|
||||
logger.error('Error fetching image to convert to base64', error);
|
||||
throw error;
|
||||
const message = 'Error fetching image to convert to base64';
|
||||
throw new Error(logAxiosError({ message, error }));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -37,17 +38,21 @@ const base64Only = new Set([
|
|||
EModelEndpoint.bedrock,
|
||||
]);
|
||||
|
||||
const blobStorageSources = new Set([FileSources.azure_blob, FileSources.s3]);
|
||||
|
||||
/**
|
||||
* Encodes and formats the given files.
|
||||
* @param {Express.Request} req - The request object.
|
||||
* @param {Array<MongoFile>} files - The array of files to encode and format.
|
||||
* @param {EModelEndpoint} [endpoint] - Optional: The endpoint for the image.
|
||||
* @param {string} [mode] - Optional: The endpoint mode for the image.
|
||||
* @returns {Promise<Object>} - A promise that resolves to the result object containing the encoded images and file details.
|
||||
* @returns {Promise<{ text: string; files: MongoFile[]; image_urls: MessageContentImageUrl[] }>} - A promise that resolves to the result object containing the encoded images and file details.
|
||||
*/
|
||||
async function encodeAndFormat(req, files, endpoint, mode) {
|
||||
const promises = [];
|
||||
/** @type {Record<FileSources, Pick<ReturnType<typeof getStrategyFunctions>, 'prepareImagePayload' | 'getDownloadStream'>>} */
|
||||
const encodingMethods = {};
|
||||
/** @type {{ text: string; files: MongoFile[]; image_urls: MessageContentImageUrl[] }} */
|
||||
const result = {
|
||||
text: '',
|
||||
files: [],
|
||||
|
|
@ -59,6 +64,7 @@ async function encodeAndFormat(req, files, endpoint, mode) {
|
|||
}
|
||||
|
||||
for (let file of files) {
|
||||
/** @type {FileSources} */
|
||||
const source = file.source ?? FileSources.local;
|
||||
if (source === FileSources.text && file.text) {
|
||||
result.text += `${!result.text ? 'Attached document(s):\n```md' : '\n\n---\n\n'}# "${file.filename}"\n${file.text}\n`;
|
||||
|
|
@ -70,18 +76,52 @@ async function encodeAndFormat(req, files, endpoint, mode) {
|
|||
}
|
||||
|
||||
if (!encodingMethods[source]) {
|
||||
const { prepareImagePayload } = getStrategyFunctions(source);
|
||||
const { prepareImagePayload, getDownloadStream } = getStrategyFunctions(source);
|
||||
if (!prepareImagePayload) {
|
||||
throw new Error(`Encoding function not implemented for ${source}`);
|
||||
}
|
||||
|
||||
encodingMethods[source] = prepareImagePayload;
|
||||
encodingMethods[source] = { prepareImagePayload, getDownloadStream };
|
||||
}
|
||||
|
||||
const preparePayload = encodingMethods[source];
|
||||
const preparePayload = encodingMethods[source].prepareImagePayload;
|
||||
/* We need to fetch the image and convert it to base64 if we are using S3/Azure Blob storage. */
|
||||
if (blobStorageSources.has(source)) {
|
||||
try {
|
||||
const downloadStream = encodingMethods[source].getDownloadStream;
|
||||
const stream = await downloadStream(req, file.filepath);
|
||||
const streamPromise = new Promise((resolve, reject) => {
|
||||
/** @type {Uint8Array[]} */
|
||||
const chunks = [];
|
||||
stream.on('readable', () => {
|
||||
let chunk;
|
||||
while (null !== (chunk = stream.read())) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
});
|
||||
|
||||
/* Google & Anthropic don't support passing URLs to payload */
|
||||
if (source !== FileSources.local && base64Only.has(endpoint)) {
|
||||
stream.on('end', () => {
|
||||
const buffer = Buffer.concat(chunks);
|
||||
const base64Data = buffer.toString('base64');
|
||||
resolve(base64Data);
|
||||
});
|
||||
stream.on('error', (error) => {
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
const base64Data = await streamPromise;
|
||||
promises.push([file, base64Data]);
|
||||
continue;
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error processing blob storage file stream for ${file.name} base64 payload:`,
|
||||
error,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
/* Google & Anthropic don't support passing URLs to payload */
|
||||
} else if (source !== FileSources.local && base64Only.has(endpoint)) {
|
||||
const [_file, imageURL] = await preparePayload(req, file);
|
||||
promises.push([_file, await fetchImageToBase64(imageURL)]);
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -492,7 +492,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
|||
|
||||
let fileInfoMetadata;
|
||||
const entity_id = messageAttachment === true ? undefined : agent_id;
|
||||
|
||||
const basePath = mime.getType(file.originalname)?.startsWith('image') ? 'images' : 'uploads';
|
||||
if (tool_resource === EToolResources.execute_code) {
|
||||
const isCodeEnabled = await checkCapability(req, AgentCapabilities.execute_code);
|
||||
if (!isCodeEnabled) {
|
||||
|
|
@ -532,7 +532,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
|||
images,
|
||||
filename,
|
||||
filepath: ocrFileURL,
|
||||
} = await handleFileUpload({ req, file, file_id, entity_id: agent_id });
|
||||
} = await handleFileUpload({ req, file, file_id, entity_id: agent_id, basePath });
|
||||
|
||||
const fileInfo = removeNullishValues({
|
||||
text,
|
||||
|
|
@ -582,6 +582,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
|||
file,
|
||||
file_id,
|
||||
entity_id,
|
||||
basePath,
|
||||
});
|
||||
|
||||
let filepath = _filepath;
|
||||
|
|
|
|||
|
|
@ -211,6 +211,8 @@ const getStrategyFunctions = (fileSource) => {
|
|||
} else if (fileSource === FileSources.openai) {
|
||||
return openAIStrategy();
|
||||
} else if (fileSource === FileSources.azure) {
|
||||
return openAIStrategy();
|
||||
} else if (fileSource === FileSources.azure_blob) {
|
||||
return azureStrategy();
|
||||
} else if (fileSource === FileSources.vectordb) {
|
||||
return vectorStrategy();
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ const { logger, getMCPManager } = require('~/config');
|
|||
* Creates a general tool for an entire action set.
|
||||
*
|
||||
* @param {Object} params - The parameters for loading action sets.
|
||||
* @param {ServerRequest} params.req - The name of the tool.
|
||||
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
|
||||
* @param {string} params.toolKey - The toolKey for the tool.
|
||||
* @param {import('@librechat/agents').Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||
* @param {string} params.model - The model for the tool.
|
||||
|
|
@ -37,6 +37,15 @@ async function createMCPTool({ req, toolKey, provider }) {
|
|||
}
|
||||
|
||||
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
|
||||
const userId = req.user?.id;
|
||||
|
||||
if (!userId) {
|
||||
logger.error(
|
||||
`[MCP][${serverName}][${toolName}] User ID not found on request. Cannot create tool.`,
|
||||
);
|
||||
throw new Error(`User ID not found on request. Cannot create tool for ${toolKey}.`);
|
||||
}
|
||||
|
||||
/** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise<unknown>} */
|
||||
const _call = async (toolArguments, config) => {
|
||||
try {
|
||||
|
|
@ -47,9 +56,11 @@ async function createMCPTool({ req, toolKey, provider }) {
|
|||
provider,
|
||||
toolArguments,
|
||||
options: {
|
||||
userId,
|
||||
signal: config?.signal,
|
||||
},
|
||||
});
|
||||
|
||||
if (isAssistantsEndpoint(provider) && Array.isArray(result)) {
|
||||
return result[0];
|
||||
}
|
||||
|
|
@ -58,8 +69,13 @@ async function createMCPTool({ req, toolKey, provider }) {
|
|||
}
|
||||
return result;
|
||||
} catch (error) {
|
||||
logger.error(`${toolName} MCP server tool call failed`, error);
|
||||
return `${toolName} MCP server tool call failed.`;
|
||||
logger.error(
|
||||
`[MCP][User: ${userId}][${serverName}] Error calling "${toolName}" MCP tool:`,
|
||||
error,
|
||||
);
|
||||
throw new Error(
|
||||
`"${toolKey}" tool call failed${error?.message ? `: ${error?.message}` : '.'}`,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -55,8 +55,7 @@ async function retrieveRun({ thread_id, run_id, timeout, openai }) {
|
|||
return response.data;
|
||||
} catch (error) {
|
||||
const message = '[retrieveRun] Failed to retrieve run data:';
|
||||
logAxiosError({ message, error });
|
||||
throw error;
|
||||
throw new Error(logAxiosError({ message, error }));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -93,11 +93,12 @@ const refreshAccessToken = async ({
|
|||
return response.data;
|
||||
} catch (error) {
|
||||
const message = 'Error refreshing OAuth tokens';
|
||||
logAxiosError({
|
||||
message,
|
||||
error,
|
||||
});
|
||||
throw new Error(message);
|
||||
throw new Error(
|
||||
logAxiosError({
|
||||
message,
|
||||
error,
|
||||
}),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -156,11 +157,12 @@ const getAccessToken = async ({
|
|||
return response.data;
|
||||
} catch (error) {
|
||||
const message = 'Error exchanging OAuth code';
|
||||
logAxiosError({
|
||||
message,
|
||||
error,
|
||||
});
|
||||
throw new Error(message);
|
||||
throw new Error(
|
||||
logAxiosError({
|
||||
message,
|
||||
error,
|
||||
}),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -15,9 +15,15 @@ const {
|
|||
AgentCapabilities,
|
||||
validateAndParseOpenAPISpec,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
loadActionSets,
|
||||
createActionTool,
|
||||
decryptMetadata,
|
||||
domainParser,
|
||||
} = require('./ActionService');
|
||||
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
|
||||
const { createYouTubeTools, manifestToolMap, toolkits } = require('~/app/clients/tools');
|
||||
const { loadActionSets, createActionTool, domainParser } = require('./ActionService');
|
||||
const { isActionDomainAllowed } = require('~/server/services/domains');
|
||||
const { getEndpointsConfig } = require('~/server/services/Config');
|
||||
const { recordUsage } = require('~/server/services/Threads');
|
||||
const { loadTools } = require('~/app/clients/tools/util');
|
||||
|
|
@ -315,58 +321,95 @@ async function processRequiredActions(client, requiredActions) {
|
|||
if (!tool) {
|
||||
// throw new Error(`Tool ${currentAction.tool} not found.`);
|
||||
|
||||
// Load all action sets once if not already loaded
|
||||
if (!actionSets.length) {
|
||||
actionSets =
|
||||
(await loadActionSets({
|
||||
assistant_id: client.req.body.assistant_id,
|
||||
})) ?? [];
|
||||
|
||||
// Process all action sets once
|
||||
// Map domains to their processed action sets
|
||||
const processedDomains = new Map();
|
||||
const domainMap = new Map();
|
||||
|
||||
for (const action of actionSets) {
|
||||
const domain = await domainParser(client.req, action.metadata.domain, true);
|
||||
domainMap.set(domain, action);
|
||||
|
||||
// Check if domain is allowed
|
||||
const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
|
||||
if (!isDomainAllowed) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Validate and parse OpenAPI spec
|
||||
const validationResult = validateAndParseOpenAPISpec(action.metadata.raw_spec);
|
||||
if (!validationResult.spec) {
|
||||
throw new Error(
|
||||
`Invalid spec: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Process the OpenAPI spec
|
||||
const { requestBuilders } = openapiToFunction(validationResult.spec);
|
||||
|
||||
// Store encrypted values for OAuth flow
|
||||
const encrypted = {
|
||||
oauth_client_id: action.metadata.oauth_client_id,
|
||||
oauth_client_secret: action.metadata.oauth_client_secret,
|
||||
};
|
||||
|
||||
// Decrypt metadata
|
||||
const decryptedAction = { ...action };
|
||||
decryptedAction.metadata = await decryptMetadata(action.metadata);
|
||||
|
||||
processedDomains.set(domain, {
|
||||
action: decryptedAction,
|
||||
requestBuilders,
|
||||
encrypted,
|
||||
});
|
||||
|
||||
// Store builders for reuse
|
||||
ActionBuildersMap[action.metadata.domain] = requestBuilders;
|
||||
}
|
||||
|
||||
// Update actionSets reference to use the domain map
|
||||
actionSets = { domainMap, processedDomains };
|
||||
}
|
||||
|
||||
let actionSet = null;
|
||||
// Find the matching domain for this tool
|
||||
let currentDomain = '';
|
||||
for (let action of actionSets) {
|
||||
const domain = await domainParser(client.req, action.metadata.domain, true);
|
||||
for (const domain of actionSets.domainMap.keys()) {
|
||||
if (currentAction.tool.includes(domain)) {
|
||||
currentDomain = domain;
|
||||
actionSet = action;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!actionSet) {
|
||||
if (!currentDomain || !actionSets.processedDomains.has(currentDomain)) {
|
||||
// TODO: try `function` if no action set is found
|
||||
// throw new Error(`Tool ${currentAction.tool} not found.`);
|
||||
continue;
|
||||
}
|
||||
|
||||
let builders = ActionBuildersMap[actionSet.metadata.domain];
|
||||
|
||||
if (!builders) {
|
||||
const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec);
|
||||
if (!validationResult.spec) {
|
||||
throw new Error(
|
||||
`Invalid spec: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`,
|
||||
);
|
||||
}
|
||||
const { requestBuilders } = openapiToFunction(validationResult.spec);
|
||||
ActionToolMap[actionSet.metadata.domain] = requestBuilders;
|
||||
builders = requestBuilders;
|
||||
}
|
||||
|
||||
const { action, requestBuilders, encrypted } = actionSets.processedDomains.get(currentDomain);
|
||||
const functionName = currentAction.tool.replace(`${actionDelimiter}${currentDomain}`, '');
|
||||
|
||||
const requestBuilder = builders[functionName];
|
||||
const requestBuilder = requestBuilders[functionName];
|
||||
|
||||
if (!requestBuilder) {
|
||||
// throw new Error(`Tool ${currentAction.tool} not found.`);
|
||||
continue;
|
||||
}
|
||||
|
||||
// We've already decrypted the metadata, so we can pass it directly
|
||||
tool = await createActionTool({
|
||||
req: client.req,
|
||||
res: client.res,
|
||||
action: actionSet,
|
||||
action,
|
||||
requestBuilder,
|
||||
// Note: intentionally not passing zodSchema, name, and description for assistants API
|
||||
encrypted, // Pass the encrypted values for OAuth flow
|
||||
});
|
||||
if (!tool) {
|
||||
logger.warn(
|
||||
|
|
@ -430,10 +473,10 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
|
|||
const areToolsEnabled = checkCapability(AgentCapabilities.tools);
|
||||
|
||||
const _agentTools = agent.tools?.filter((tool) => {
|
||||
if (tool === Tools.file_search && !checkCapability(AgentCapabilities.file_search)) {
|
||||
return false;
|
||||
} else if (tool === Tools.execute_code && !checkCapability(AgentCapabilities.execute_code)) {
|
||||
return false;
|
||||
if (tool === Tools.file_search) {
|
||||
return checkCapability(AgentCapabilities.file_search);
|
||||
} else if (tool === Tools.execute_code) {
|
||||
return checkCapability(AgentCapabilities.execute_code);
|
||||
} else if (!areToolsEnabled && !tool.includes(actionDelimiter)) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -511,7 +554,62 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
|
|||
};
|
||||
}
|
||||
|
||||
let actionSets = [];
|
||||
const actionSets = (await loadActionSets({ agent_id: agent.id })) ?? [];
|
||||
if (actionSets.length === 0) {
|
||||
if (_agentTools.length > 0 && agentTools.length === 0) {
|
||||
logger.warn(`No tools found for the specified tool calls: ${_agentTools.join(', ')}`);
|
||||
}
|
||||
return {
|
||||
tools: agentTools,
|
||||
toolContextMap,
|
||||
};
|
||||
}
|
||||
|
||||
// Process each action set once (validate spec, decrypt metadata)
|
||||
const processedActionSets = new Map();
|
||||
const domainMap = new Map();
|
||||
|
||||
for (const action of actionSets) {
|
||||
const domain = await domainParser(req, action.metadata.domain, true);
|
||||
domainMap.set(domain, action);
|
||||
|
||||
// Check if domain is allowed (do this once per action set)
|
||||
const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
|
||||
if (!isDomainAllowed) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Validate and parse OpenAPI spec once per action set
|
||||
const validationResult = validateAndParseOpenAPISpec(action.metadata.raw_spec);
|
||||
if (!validationResult.spec) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const encrypted = {
|
||||
oauth_client_id: action.metadata.oauth_client_id,
|
||||
oauth_client_secret: action.metadata.oauth_client_secret,
|
||||
};
|
||||
|
||||
// Decrypt metadata once per action set
|
||||
const decryptedAction = { ...action };
|
||||
decryptedAction.metadata = await decryptMetadata(action.metadata);
|
||||
|
||||
// Process the OpenAPI spec once per action set
|
||||
const { requestBuilders, functionSignatures, zodSchemas } = openapiToFunction(
|
||||
validationResult.spec,
|
||||
true,
|
||||
);
|
||||
|
||||
processedActionSets.set(domain, {
|
||||
action: decryptedAction,
|
||||
requestBuilders,
|
||||
functionSignatures,
|
||||
zodSchemas,
|
||||
encrypted,
|
||||
});
|
||||
}
|
||||
|
||||
// Now map tools to the processed action sets
|
||||
const ActionToolMap = {};
|
||||
|
||||
for (const toolName of _agentTools) {
|
||||
|
|
@ -519,55 +617,47 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
|
|||
continue;
|
||||
}
|
||||
|
||||
if (!actionSets.length) {
|
||||
actionSets = (await loadActionSets({ agent_id: agent.id })) ?? [];
|
||||
}
|
||||
|
||||
let actionSet = null;
|
||||
// Find the matching domain for this tool
|
||||
let currentDomain = '';
|
||||
for (let action of actionSets) {
|
||||
const domain = await domainParser(req, action.metadata.domain, true);
|
||||
for (const domain of domainMap.keys()) {
|
||||
if (toolName.includes(domain)) {
|
||||
currentDomain = domain;
|
||||
actionSet = action;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!actionSet) {
|
||||
if (!currentDomain || !processedActionSets.has(currentDomain)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec);
|
||||
if (validationResult.spec) {
|
||||
const { requestBuilders, functionSignatures, zodSchemas } = openapiToFunction(
|
||||
validationResult.spec,
|
||||
true,
|
||||
);
|
||||
const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, '');
|
||||
const functionSig = functionSignatures.find((sig) => sig.name === functionName);
|
||||
const requestBuilder = requestBuilders[functionName];
|
||||
const zodSchema = zodSchemas[functionName];
|
||||
const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } =
|
||||
processedActionSets.get(currentDomain);
|
||||
const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, '');
|
||||
const functionSig = functionSignatures.find((sig) => sig.name === functionName);
|
||||
const requestBuilder = requestBuilders[functionName];
|
||||
const zodSchema = zodSchemas[functionName];
|
||||
|
||||
if (requestBuilder) {
|
||||
const tool = await createActionTool({
|
||||
req,
|
||||
res,
|
||||
action: actionSet,
|
||||
requestBuilder,
|
||||
zodSchema,
|
||||
name: toolName,
|
||||
description: functionSig.description,
|
||||
});
|
||||
if (!tool) {
|
||||
logger.warn(
|
||||
`Invalid action: user: ${req.user.id} | agent_id: ${agent.id} | toolName: ${toolName}`,
|
||||
);
|
||||
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
|
||||
}
|
||||
agentTools.push(tool);
|
||||
ActionToolMap[toolName] = tool;
|
||||
if (requestBuilder) {
|
||||
const tool = await createActionTool({
|
||||
req,
|
||||
res,
|
||||
action,
|
||||
requestBuilder,
|
||||
zodSchema,
|
||||
encrypted,
|
||||
name: toolName,
|
||||
description: functionSig.description,
|
||||
});
|
||||
|
||||
if (!tool) {
|
||||
logger.warn(
|
||||
`Invalid action: user: ${req.user.id} | agent_id: ${agent.id} | toolName: ${toolName}`,
|
||||
);
|
||||
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
|
||||
}
|
||||
|
||||
agentTools.push(tool);
|
||||
ActionToolMap[toolName] = tool;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,24 @@ const secretDefaults = {
|
|||
JWT_REFRESH_SECRET: 'eaa5191f2914e30b9387fd84e254e4ba6fc51b4654968a9b0803b456a54b8418',
|
||||
};
|
||||
|
||||
const deprecatedVariables = [
|
||||
{
|
||||
key: 'CHECK_BALANCE',
|
||||
description:
|
||||
'Please use the `balance` field in the `librechat.yaml` config file instead.\nMore info: https://librechat.ai/docs/configuration/librechat_yaml/object_structure/balance#overview',
|
||||
},
|
||||
{
|
||||
key: 'START_BALANCE',
|
||||
description:
|
||||
'Please use the `balance` field in the `librechat.yaml` config file instead.\nMore info: https://librechat.ai/docs/configuration/librechat_yaml/object_structure/balance#overview',
|
||||
},
|
||||
{
|
||||
key: 'GOOGLE_API_KEY',
|
||||
description:
|
||||
'Please use the `GOOGLE_SEARCH_API_KEY` environment variable for the Google Search Tool instead.',
|
||||
},
|
||||
];
|
||||
|
||||
/**
|
||||
* Checks environment variables for default secrets and deprecated variables.
|
||||
* Logs warnings for any default secret values being used and for usage of deprecated `GOOGLE_API_KEY`.
|
||||
|
|
@ -37,19 +55,11 @@ function checkVariables() {
|
|||
\u200B`);
|
||||
}
|
||||
|
||||
if (process.env.GOOGLE_API_KEY) {
|
||||
logger.warn(
|
||||
'The `GOOGLE_API_KEY` environment variable is deprecated.\nPlease use the `GOOGLE_SEARCH_API_KEY` environment variable instead.',
|
||||
);
|
||||
}
|
||||
|
||||
if (process.env.OPENROUTER_API_KEY) {
|
||||
logger.warn(
|
||||
`The \`OPENROUTER_API_KEY\` environment variable is deprecated and its functionality will be removed soon.
|
||||
Use of this environment variable is highly discouraged as it can lead to unexpected errors when using custom endpoints.
|
||||
Please use the config (\`librechat.yaml\`) file for setting up OpenRouter, and use \`OPENROUTER_KEY\` or another environment variable instead.`,
|
||||
);
|
||||
}
|
||||
deprecatedVariables.forEach(({ key, description }) => {
|
||||
if (process.env[key]) {
|
||||
logger.warn(`The \`${key}\` environment variable is deprecated. ${description}`);
|
||||
}
|
||||
});
|
||||
|
||||
checkPasswordReset();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,12 +18,15 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol
|
|||
const { interface: interfaceConfig } = config ?? {};
|
||||
const { interface: defaults } = configDefaults;
|
||||
const hasModelSpecs = config?.modelSpecs?.list?.length > 0;
|
||||
const includesAddedEndpoints = config?.modelSpecs?.addedEndpoints?.length > 0;
|
||||
|
||||
/** @type {TCustomConfig['interface']} */
|
||||
const loadedInterface = removeNullishValues({
|
||||
endpointsMenu:
|
||||
interfaceConfig?.endpointsMenu ?? (hasModelSpecs ? false : defaults.endpointsMenu),
|
||||
modelSelect: interfaceConfig?.modelSelect ?? (hasModelSpecs ? false : defaults.modelSelect),
|
||||
modelSelect:
|
||||
interfaceConfig?.modelSelect ??
|
||||
(hasModelSpecs ? includesAddedEndpoints : defaults.modelSelect),
|
||||
parameters: interfaceConfig?.parameters ?? (hasModelSpecs ? false : defaults.parameters),
|
||||
presets: interfaceConfig?.presets ?? (hasModelSpecs ? false : defaults.presets),
|
||||
sidePanel: interfaceConfig?.sidePanel ?? defaults.sidePanel,
|
||||
|
|
|
|||
|
|
@ -6,9 +6,10 @@ const { logger } = require('~/config');
|
|||
* Sets up Model Specs from the config (`librechat.yaml`) file.
|
||||
* @param {TCustomConfig['endpoints']} [endpoints] - The loaded custom configuration for endpoints.
|
||||
* @param {TCustomConfig['modelSpecs'] | undefined} [modelSpecs] - The loaded custom configuration for model specs.
|
||||
* @param {TCustomConfig['interface'] | undefined} [interfaceConfig] - The loaded interface configuration.
|
||||
* @returns {TCustomConfig['modelSpecs'] | undefined} The processed model specs, if any.
|
||||
*/
|
||||
function processModelSpecs(endpoints, _modelSpecs) {
|
||||
function processModelSpecs(endpoints, _modelSpecs, interfaceConfig) {
|
||||
if (!_modelSpecs) {
|
||||
return undefined;
|
||||
}
|
||||
|
|
@ -20,6 +21,19 @@ function processModelSpecs(endpoints, _modelSpecs) {
|
|||
|
||||
const customEndpoints = endpoints?.[EModelEndpoint.custom] ?? [];
|
||||
|
||||
if (interfaceConfig.modelSelect !== true && (_modelSpecs.addedEndpoints?.length ?? 0) > 0) {
|
||||
logger.warn(
|
||||
`To utilize \`addedEndpoints\`, which allows provider/model selections alongside model specs, set \`modelSelect: true\` in the interface configuration.
|
||||
|
||||
Example:
|
||||
\`\`\`yaml
|
||||
interface:
|
||||
modelSelect: true
|
||||
\`\`\`
|
||||
`,
|
||||
);
|
||||
}
|
||||
|
||||
for (const spec of list) {
|
||||
if (EModelEndpoint[spec.preset.endpoint] && spec.preset.endpoint !== EModelEndpoint.custom) {
|
||||
modelSpecs.push(spec);
|
||||
|
|
|
|||
|
|
@ -403,6 +403,12 @@
|
|||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports MessageContentImageUrl
|
||||
* @typedef {import('librechat-data-provider').Agents.MessageContentImageUrl} MessageContentImageUrl
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/** Prompts */
|
||||
/**
|
||||
* @exports TPrompt
|
||||
|
|
@ -760,6 +766,23 @@
|
|||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports MongoFile
|
||||
* @typedef {import('@librechat/data-schemas').IMongoFile} MongoFile
|
||||
* @memberof typedefs
|
||||
*/
|
||||
/**
|
||||
* @exports IBalance
|
||||
* @typedef {import('@librechat/data-schemas').IBalance} IBalance
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports MongoUser
|
||||
* @typedef {import('@librechat/data-schemas').IUser} MongoUser
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports ObjectId
|
||||
* @typedef {import('mongoose').Types.ObjectId} ObjectId
|
||||
|
|
@ -805,6 +828,12 @@
|
|||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TEndpointOption
|
||||
* @typedef {import('librechat-data-provider').TEndpointOption} TEndpointOption
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TAttachment
|
||||
* @typedef {import('librechat-data-provider').TAttachment} TAttachment
|
||||
|
|
|
|||
|
|
@ -6,32 +6,41 @@ const { logger } = require('~/config');
|
|||
* @param {Object} options - The options object.
|
||||
* @param {string} options.message - The custom message to be logged.
|
||||
* @param {import('axios').AxiosError} options.error - The Axios error object.
|
||||
* @returns {string} The log message.
|
||||
*/
|
||||
const logAxiosError = ({ message, error }) => {
|
||||
let logMessage = message;
|
||||
try {
|
||||
const stack = error.stack || 'No stack trace available';
|
||||
|
||||
if (error.response?.status) {
|
||||
const { status, headers, data } = error.response;
|
||||
logger.error(`${message} The server responded with status ${status}: ${error.message}`, {
|
||||
logMessage = `${message} The server responded with status ${status}: ${error.message}`;
|
||||
logger.error(logMessage, {
|
||||
status,
|
||||
headers,
|
||||
data,
|
||||
stack,
|
||||
});
|
||||
} else if (error.request) {
|
||||
const { method, url } = error.config || {};
|
||||
logger.error(
|
||||
`${message} No response received for ${method ? method.toUpperCase() : ''} ${url || ''}: ${error.message}`,
|
||||
{ requestInfo: { method, url } },
|
||||
);
|
||||
logMessage = `${message} No response received for ${method ? method.toUpperCase() : ''} ${url || ''}: ${error.message}`;
|
||||
logger.error(logMessage, {
|
||||
requestInfo: { method, url },
|
||||
stack,
|
||||
});
|
||||
} else if (error?.message?.includes('Cannot read properties of undefined (reading \'status\')')) {
|
||||
logger.error(
|
||||
`${message} It appears the request timed out or was unsuccessful: ${error.message}`,
|
||||
);
|
||||
logMessage = `${message} It appears the request timed out or was unsuccessful: ${error.message}`;
|
||||
logger.error(logMessage, { stack });
|
||||
} else {
|
||||
logger.error(`${message} An error occurred while setting up the request: ${error.message}`);
|
||||
logMessage = `${message} An error occurred while setting up the request: ${error.message}`;
|
||||
logger.error(logMessage, { stack });
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(`Error in logAxiosError: ${err.message}`);
|
||||
logMessage = `Error in logAxiosError: ${err.message}`;
|
||||
logger.error(logMessage, { stack: err.stack || 'No stack trace available' });
|
||||
}
|
||||
return logMessage;
|
||||
};
|
||||
|
||||
module.exports = { logAxiosError };
|
||||
|
|
|
|||
|
|
@ -34,8 +34,14 @@ const mistralModels = {
|
|||
'mistral-7b': 31990, // -10 from max
|
||||
'mistral-small': 31990, // -10 from max
|
||||
'mixtral-8x7b': 31990, // -10 from max
|
||||
'mistral-large': 131000,
|
||||
'mistral-large-2402': 127500,
|
||||
'mistral-large-2407': 127500,
|
||||
'pixtral-large': 131000,
|
||||
'mistral-saba': 32000,
|
||||
codestral: 256000,
|
||||
'ministral-8b': 131000,
|
||||
'ministral-3b': 131000,
|
||||
};
|
||||
|
||||
const cohereModels = {
|
||||
|
|
@ -52,6 +58,7 @@ const googleModels = {
|
|||
gemini: 30720, // -2048 from max
|
||||
'gemini-pro-vision': 12288,
|
||||
'gemini-exp': 2000000,
|
||||
'gemini-2.5': 1000000, // 1M input tokens, 64k output tokens
|
||||
'gemini-2.0': 2000000,
|
||||
'gemini-2.0-flash': 1000000,
|
||||
'gemini-2.0-flash-lite': 1000000,
|
||||
|
|
|
|||
BIN
bun.lockb
BIN
bun.lockb
Binary file not shown.
|
|
@ -31,8 +31,8 @@
|
|||
"@ariakit/react": "^0.4.15",
|
||||
"@ariakit/react-core": "^0.4.15",
|
||||
"@codesandbox/sandpack-react": "^2.19.10",
|
||||
"@dicebear/collection": "^7.0.4",
|
||||
"@dicebear/core": "^7.0.4",
|
||||
"@dicebear/collection": "^9.2.2",
|
||||
"@dicebear/core": "^9.2.2",
|
||||
"@headlessui/react": "^2.1.2",
|
||||
"@radix-ui/react-accordion": "^1.1.2",
|
||||
"@radix-ui/react-alert-dialog": "^1.0.2",
|
||||
|
|
@ -52,6 +52,7 @@
|
|||
"@radix-ui/react-switch": "^1.0.3",
|
||||
"@radix-ui/react-tabs": "^1.0.3",
|
||||
"@radix-ui/react-toast": "^1.1.5",
|
||||
"@react-spring/web": "^9.7.5",
|
||||
"@tanstack/react-query": "^4.28.0",
|
||||
"@tanstack/react-table": "^8.11.7",
|
||||
"class-variance-authority": "^0.6.0",
|
||||
|
|
@ -72,7 +73,7 @@
|
|||
"lodash": "^4.17.21",
|
||||
"lucide-react": "^0.394.0",
|
||||
"match-sorter": "^6.3.4",
|
||||
"msedge-tts": "^1.3.4",
|
||||
"msedge-tts": "^2.0.0",
|
||||
"qrcode.react": "^4.2.0",
|
||||
"rc-input-number": "^7.4.2",
|
||||
"react": "^18.2.0",
|
||||
|
|
@ -121,7 +122,7 @@
|
|||
"@types/node": "^20.3.0",
|
||||
"@types/react": "^18.2.11",
|
||||
"@types/react-dom": "^18.2.4",
|
||||
"@vitejs/plugin-react": "^4.2.1",
|
||||
"@vitejs/plugin-react": "^4.3.4",
|
||||
"autoprefixer": "^10.4.20",
|
||||
"babel-plugin-replace-ts-export-assignment": "^0.0.2",
|
||||
"babel-plugin-root-import": "^6.6.0",
|
||||
|
|
@ -140,9 +141,9 @@
|
|||
"tailwindcss": "^3.4.1",
|
||||
"ts-jest": "^29.2.5",
|
||||
"typescript": "^5.3.3",
|
||||
"vite": "^6.1.0",
|
||||
"vite-plugin-node-polyfills": "^0.17.0",
|
||||
"vite-plugin-compression": "^0.5.1",
|
||||
"vite-plugin-pwa": "^0.21.1"
|
||||
"vite": "^6.2.3",
|
||||
"vite-plugin-compression2": "^1.3.3",
|
||||
"vite-plugin-node-polyfills": "^0.23.0",
|
||||
"vite-plugin-pwa": "^0.21.2"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,5 +3,6 @@ export * from './artifacts';
|
|||
export * from './types';
|
||||
export * from './menus';
|
||||
export * from './tools';
|
||||
export * from './selector';
|
||||
export * from './assistants-types';
|
||||
export * from './agents-types';
|
||||
|
|
|
|||
23
client/src/common/selector.ts
Normal file
23
client/src/common/selector.ts
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import React from 'react';
|
||||
import { TModelSpec, TStartupConfig } from 'librechat-data-provider';
|
||||
|
||||
export interface Endpoint {
|
||||
value: string;
|
||||
label: string;
|
||||
hasModels: boolean;
|
||||
models?: Array<{ name: string; isGlobal?: boolean }>;
|
||||
icon: React.ReactNode;
|
||||
agentNames?: Record<string, string>;
|
||||
assistantNames?: Record<string, string>;
|
||||
modelIcons?: Record<string, string | undefined>;
|
||||
}
|
||||
|
||||
export interface SelectedValues {
|
||||
endpoint: string | null;
|
||||
model: string | null;
|
||||
modelSpec: string | null;
|
||||
}
|
||||
|
||||
export interface ModelSelectorProps {
|
||||
startupConfig: TStartupConfig | undefined;
|
||||
}
|
||||
|
|
@ -1,10 +1,10 @@
|
|||
import { RefObject } from 'react';
|
||||
import { FileSources } from 'librechat-data-provider';
|
||||
import type * as InputNumberPrimitive from 'rc-input-number';
|
||||
import type { ColumnDef } from '@tanstack/react-table';
|
||||
import type { SetterOrUpdater } from 'recoil';
|
||||
import type * as t from 'librechat-data-provider';
|
||||
import { FileSources, EModelEndpoint } from 'librechat-data-provider';
|
||||
import type { UseMutationResult } from '@tanstack/react-query';
|
||||
import type * as InputNumberPrimitive from 'rc-input-number';
|
||||
import type { SetterOrUpdater, RecoilState } from 'recoil';
|
||||
import type { ColumnDef } from '@tanstack/react-table';
|
||||
import type * as t from 'librechat-data-provider';
|
||||
import type { LucideIcon } from 'lucide-react';
|
||||
import type { TranslationKeys } from '~/hooks';
|
||||
|
||||
|
|
@ -48,6 +48,14 @@ export type AudioChunk = {
|
|||
};
|
||||
};
|
||||
|
||||
export type BadgeItem = {
|
||||
id: string;
|
||||
icon: React.ComponentType<any>;
|
||||
label: string;
|
||||
atom: RecoilState<boolean>;
|
||||
isAvailable: boolean;
|
||||
};
|
||||
|
||||
export type AssistantListItem = {
|
||||
id: string;
|
||||
name: string;
|
||||
|
|
@ -488,6 +496,16 @@ export interface ExtendedFile {
|
|||
metadata?: t.TFile['metadata'];
|
||||
}
|
||||
|
||||
export interface ModelItemProps {
|
||||
modelName: string;
|
||||
endpoint: EModelEndpoint;
|
||||
isSelected: boolean;
|
||||
onSelect: () => void;
|
||||
onNavigateBack: () => void;
|
||||
icon?: JSX.Element;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export type ContextType = { navVisible: boolean; setNavVisible: (visible: boolean) => void };
|
||||
|
||||
export interface SwitcherProps {
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ function AddMultiConvo() {
|
|||
const localize = useLocalize();
|
||||
|
||||
const clickHandler = () => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
|
||||
const { title: _t, ...convo } = conversation ?? ({} as TConversation);
|
||||
setAddedConvo({
|
||||
...convo,
|
||||
|
|
@ -42,7 +42,7 @@ function AddMultiConvo() {
|
|||
role="button"
|
||||
onClick={clickHandler}
|
||||
data-testid="parameters-button"
|
||||
className="inline-flex size-10 items-center justify-center rounded-lg border border-border-light bg-transparent text-text-primary transition-all ease-in-out hover:bg-surface-tertiary disabled:pointer-events-none disabled:opacity-50 radix-state-open:bg-surface-tertiary"
|
||||
className="inline-flex size-10 flex-shrink-0 items-center justify-center rounded-lg border border-border-light bg-transparent text-text-primary transition-all ease-in-out hover:bg-surface-tertiary disabled:pointer-events-none disabled:opacity-50 radix-state-open:bg-surface-tertiary"
|
||||
>
|
||||
<PlusCircle size={16} aria-label="Plus Icon" />
|
||||
</TooltipAnchor>
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import type { TMessage } from 'librechat-data-provider';
|
|||
import type { ChatFormValues } from '~/common';
|
||||
import { ChatContext, AddedChatContext, useFileMapContext, ChatFormProvider } from '~/Providers';
|
||||
import { useChatHelpers, useAddedResponse, useSSE } from '~/hooks';
|
||||
import ConversationStarters from './Input/ConversationStarters';
|
||||
import MessagesView from './Messages/MessagesView';
|
||||
import { Spinner } from '~/components/svg';
|
||||
import Presentation from './Presentation';
|
||||
|
|
@ -21,6 +22,7 @@ function ChatView({ index = 0 }: { index?: number }) {
|
|||
const { conversationId } = useParams();
|
||||
const rootSubmission = useRecoilValue(store.submissionByIndex(index));
|
||||
const addedSubmission = useRecoilValue(store.submissionByIndex(index + 1));
|
||||
const centerFormOnLanding = useRecoilValue(store.centerFormOnLanding);
|
||||
|
||||
const fileMap = useFileMapContext();
|
||||
|
||||
|
|
@ -46,16 +48,20 @@ function ChatView({ index = 0 }: { index?: number }) {
|
|||
});
|
||||
|
||||
let content: JSX.Element | null | undefined;
|
||||
const isLandingPage = !messagesTree || messagesTree.length === 0;
|
||||
|
||||
if (isLoading && conversationId !== 'new') {
|
||||
content = (
|
||||
<div className="flex h-screen items-center justify-center">
|
||||
<Spinner className="opacity-0" />
|
||||
<div className="relative flex-1 overflow-hidden overflow-y-auto">
|
||||
<div className="relative flex h-full items-center justify-center">
|
||||
<Spinner className="text-text-primary" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
} else if (messagesTree && messagesTree.length !== 0) {
|
||||
content = <MessagesView messagesTree={messagesTree} Header={<Header />} />;
|
||||
} else if (!isLandingPage) {
|
||||
content = <MessagesView messagesTree={messagesTree} />;
|
||||
} else {
|
||||
content = <Landing Header={<Header />} />;
|
||||
content = <Landing centerFormOnLanding={centerFormOnLanding} />;
|
||||
}
|
||||
|
||||
return (
|
||||
|
|
@ -63,10 +69,29 @@ function ChatView({ index = 0 }: { index?: number }) {
|
|||
<ChatContext.Provider value={chatHelpers}>
|
||||
<AddedChatContext.Provider value={addedChatHelpers}>
|
||||
<Presentation>
|
||||
{content}
|
||||
<div className="w-full border-t-0 pl-0 pt-2 dark:border-white/20 md:w-[calc(100%-.5rem)] md:border-t-0 md:border-transparent md:pl-0 md:pt-0 md:dark:border-transparent">
|
||||
<ChatForm index={index} />
|
||||
<Footer />
|
||||
<div className="flex h-full w-full flex-col">
|
||||
{!isLoading && <Header />}
|
||||
|
||||
{isLandingPage ? (
|
||||
<>
|
||||
<div className="flex flex-1 flex-col items-center justify-end sm:justify-center">
|
||||
{content}
|
||||
<div className="w-full max-w-3xl transition-all duration-200 xl:max-w-4xl">
|
||||
<ChatForm index={index} />
|
||||
<ConversationStarters />
|
||||
</div>
|
||||
</div>
|
||||
<Footer />
|
||||
</>
|
||||
) : (
|
||||
<div className="flex h-full flex-col overflow-y-auto">
|
||||
{content}
|
||||
<div className="w-full">
|
||||
<ChatForm index={index} />
|
||||
<Footer />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</Presentation>
|
||||
</AddedChatContext.Provider>
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
interface ConvoStarterProps {
|
||||
text: string;
|
||||
onClick: () => void;
|
||||
}
|
||||
|
||||
export default function ConvoStarter({ text, onClick }: ConvoStarterProps) {
|
||||
return (
|
||||
<button
|
||||
onClick={onClick}
|
||||
className="relative flex w-40 cursor-pointer flex-col gap-2 rounded-2xl border border-border-medium px-3 pb-4 pt-3 text-start align-top text-[15px] shadow-[0_0_2px_0_rgba(0,0,0,0.05),0_4px_6px_0_rgba(0,0,0,0.02)] transition-colors duration-300 ease-in-out fade-in hover:bg-surface-tertiary"
|
||||
>
|
||||
<p className="break-word line-clamp-3 overflow-hidden text-balance break-all text-text-secondary">
|
||||
{text}
|
||||
</p>
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
|
@ -79,7 +79,7 @@ export default function ExportAndShareMenu({
|
|||
<Ariakit.MenuButton
|
||||
id="export-menu-button"
|
||||
aria-label="Export options"
|
||||
className="inline-flex size-10 items-center justify-center rounded-lg border border-border-light bg-transparent text-text-primary transition-all ease-in-out hover:bg-surface-tertiary disabled:pointer-events-none disabled:opacity-50 radix-state-open:bg-surface-tertiary"
|
||||
className="inline-flex size-10 flex-shrink-0 items-center justify-center rounded-lg border border-border-light bg-transparent text-text-primary transition-all ease-in-out hover:bg-surface-tertiary disabled:pointer-events-none disabled:opacity-50 radix-state-open:bg-surface-tertiary"
|
||||
>
|
||||
<Upload
|
||||
className="icon-md text-text-secondary"
|
||||
|
|
@ -103,7 +103,6 @@ export default function ExportAndShareMenu({
|
|||
<ShareButton
|
||||
triggerRef={shareButtonRef}
|
||||
conversationId={conversation.conversationId ?? ''}
|
||||
title={conversation.title ?? ''}
|
||||
open={showShareDialog}
|
||||
onOpenChange={setShowShareDialog}
|
||||
/>
|
||||
|
|
|
|||
|
|
@ -56,7 +56,6 @@ export default function Footer({ className }: { className?: string }) {
|
|||
<React.Fragment key={`main-content-part-${index}`}>
|
||||
<ReactMarkdown
|
||||
components={{
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
a: ({ node: _n, href, children, ...otherProps }) => {
|
||||
return (
|
||||
<a
|
||||
|
|
@ -70,7 +69,7 @@ export default function Footer({ className }: { className?: string }) {
|
|||
</a>
|
||||
);
|
||||
},
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
|
||||
p: ({ node: _n, ...props }) => <span {...props} />,
|
||||
}}
|
||||
>
|
||||
|
|
@ -84,24 +83,29 @@ export default function Footer({ className }: { className?: string }) {
|
|||
);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={
|
||||
className ??
|
||||
'relative flex items-center justify-center gap-2 px-2 py-2 text-center text-xs text-text-primary md:px-[60px]'
|
||||
}
|
||||
role="contentinfo"
|
||||
>
|
||||
{footerElements.map((contentRender, index) => {
|
||||
const isLastElement = index === footerElements.length - 1;
|
||||
return (
|
||||
<React.Fragment key={`footer-element-${index}`}>
|
||||
{contentRender}
|
||||
{!isLastElement && (
|
||||
<div key={`separator-${index}`} className="h-2 border-r-[1px] border-border-medium" />
|
||||
)}
|
||||
</React.Fragment>
|
||||
);
|
||||
})}
|
||||
<div className="relative w-full">
|
||||
<div
|
||||
className={
|
||||
className ??
|
||||
'absolute bottom-0 left-0 right-0 hidden items-center justify-center gap-2 px-2 py-2 text-center text-xs text-text-primary sm:flex md:px-[60px]'
|
||||
}
|
||||
role="contentinfo"
|
||||
>
|
||||
{footerElements.map((contentRender, index) => {
|
||||
const isLastElement = index === footerElements.length - 1;
|
||||
return (
|
||||
<React.Fragment key={`footer-element-${index}`}>
|
||||
{contentRender}
|
||||
{!isLastElement && (
|
||||
<div
|
||||
key={`separator-${index}`}
|
||||
className="h-2 border-r-[1px] border-border-medium"
|
||||
/>
|
||||
)}
|
||||
</React.Fragment>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,12 +2,13 @@ import { useMemo } from 'react';
|
|||
import { useOutletContext } from 'react-router-dom';
|
||||
import { getConfigDefaults, PermissionTypes, Permissions } from 'librechat-data-provider';
|
||||
import type { ContextType } from '~/common';
|
||||
import { EndpointsMenu, ModelSpecsMenu, PresetsMenu, HeaderNewChat } from './Menus';
|
||||
import ModelSelector from './Menus/Endpoints/ModelSelector';
|
||||
import { PresetsMenu, HeaderNewChat } from './Menus';
|
||||
import { useGetStartupConfig } from '~/data-provider';
|
||||
import ExportAndShareMenu from './ExportAndShareMenu';
|
||||
import { useMediaQuery, useHasAccess } from '~/hooks';
|
||||
import HeaderOptions from './Input/HeaderOptions';
|
||||
import BookmarkMenu from './Menus/BookmarkMenu';
|
||||
import { TemporaryChat } from './TemporaryChat';
|
||||
import AddMultiConvo from './AddMultiConvo';
|
||||
|
||||
const defaultInterface = getConfigDefaults().interface;
|
||||
|
|
@ -15,7 +16,6 @@ const defaultInterface = getConfigDefaults().interface;
|
|||
export default function Header() {
|
||||
const { data: startupConfig } = useGetStartupConfig();
|
||||
const { navVisible } = useOutletContext<ContextType>();
|
||||
const modelSpecs = useMemo(() => startupConfig?.modelSpecs?.list ?? [], [startupConfig]);
|
||||
const interfaceConfig = useMemo(
|
||||
() => startupConfig?.interface ?? defaultInterface,
|
||||
[startupConfig],
|
||||
|
|
@ -34,24 +34,30 @@ export default function Header() {
|
|||
const isSmallScreen = useMediaQuery('(max-width: 768px)');
|
||||
|
||||
return (
|
||||
<div className="sticky top-0 z-10 flex h-14 w-full items-center justify-between bg-white p-2 font-semibold dark:bg-gray-800 dark:text-white">
|
||||
<div className="sticky top-0 z-10 flex h-14 w-full items-center justify-between bg-white p-2 font-semibold text-text-primary dark:bg-gray-800">
|
||||
<div className="hide-scrollbar flex w-full items-center justify-between gap-2 overflow-x-auto">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="mx-2 flex items-center gap-2">
|
||||
{!navVisible && <HeaderNewChat />}
|
||||
{interfaceConfig.endpointsMenu === true && <EndpointsMenu />}
|
||||
{modelSpecs.length > 0 && <ModelSpecsMenu modelSpecs={modelSpecs} />}
|
||||
{<HeaderOptions interfaceConfig={interfaceConfig} />}
|
||||
{interfaceConfig.presets === true && <PresetsMenu />}
|
||||
{<ModelSelector startupConfig={startupConfig} />}
|
||||
{interfaceConfig.presets === true && interfaceConfig.modelSelect && <PresetsMenu />}
|
||||
{hasAccessToBookmarks === true && <BookmarkMenu />}
|
||||
{hasAccessToMultiConvo === true && <AddMultiConvo />}
|
||||
{isSmallScreen && (
|
||||
<ExportAndShareMenu
|
||||
isSharedButtonEnabled={startupConfig?.sharedLinksEnabled ?? false}
|
||||
/>
|
||||
<>
|
||||
<ExportAndShareMenu
|
||||
isSharedButtonEnabled={startupConfig?.sharedLinksEnabled ?? false}
|
||||
/>
|
||||
<TemporaryChat />
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
{!isSmallScreen && (
|
||||
<ExportAndShareMenu isSharedButtonEnabled={startupConfig?.sharedLinksEnabled ?? false} />
|
||||
<div className="flex items-center gap-2">
|
||||
<ExportAndShareMenu
|
||||
isSharedButtonEnabled={startupConfig?.sharedLinksEnabled ?? false}
|
||||
/>
|
||||
<TemporaryChat />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{/* Empty div for spacing */}
|
||||
|
|
|
|||
|
|
@ -7,14 +7,12 @@ import { globalAudioId } from '~/common';
|
|||
import { cn } from '~/utils';
|
||||
|
||||
export default function AudioRecorder({
|
||||
isRTL,
|
||||
disabled,
|
||||
ask,
|
||||
methods,
|
||||
textAreaRef,
|
||||
isSubmitting,
|
||||
}: {
|
||||
isRTL: boolean;
|
||||
disabled: boolean;
|
||||
ask: (data: { text: string }) => void;
|
||||
methods: ReturnType<typeof useChatFormContext>;
|
||||
|
|
@ -90,9 +88,7 @@ export default function AudioRecorder({
|
|||
onClick={isListening === true ? handleStopRecording : handleStartRecording}
|
||||
disabled={disabled}
|
||||
className={cn(
|
||||
'absolute flex size-[35px] items-center justify-center rounded-full p-1 transition-colors hover:bg-surface-hover',
|
||||
isRTL ? 'bottom-2 left-2' : 'bottom-2 right-2',
|
||||
disabled ? 'cursor-not-allowed opacity-50' : 'cursor-pointer',
|
||||
'flex size-9 items-center justify-center rounded-full p-1 transition-colors hover:bg-surface-hover',
|
||||
)}
|
||||
title={localize('com_ui_use_micrphone')}
|
||||
aria-pressed={isListening}
|
||||
|
|
|
|||
369
client/src/components/Chat/Input/BadgeRow.tsx
Normal file
369
client/src/components/Chat/Input/BadgeRow.tsx
Normal file
|
|
@ -0,0 +1,369 @@
|
|||
import React, {
|
||||
useState,
|
||||
useRef,
|
||||
useEffect,
|
||||
useCallback,
|
||||
useMemo,
|
||||
forwardRef,
|
||||
useReducer,
|
||||
} from 'react';
|
||||
import { useRecoilValue, useRecoilCallback } from 'recoil';
|
||||
import type { LucideIcon } from 'lucide-react';
|
||||
import type { BadgeItem } from '~/common';
|
||||
import { useChatBadges } from '~/hooks';
|
||||
import { Badge } from '~/components/ui';
|
||||
import store from '~/store';
|
||||
|
||||
interface BadgeRowProps {
|
||||
onChange: (badges: Pick<BadgeItem, 'id'>[]) => void;
|
||||
onToggle?: (badgeId: string, currentActive: boolean) => void;
|
||||
isInChat: boolean;
|
||||
}
|
||||
|
||||
interface BadgeWrapperProps {
|
||||
badge: BadgeItem;
|
||||
isEditing: boolean;
|
||||
isInChat: boolean;
|
||||
onToggle: (badge: BadgeItem) => void;
|
||||
onDelete: (id: string) => void;
|
||||
onMouseDown: (e: React.MouseEvent, badge: BadgeItem, isActive: boolean) => void;
|
||||
badgeRefs: React.MutableRefObject<Record<string, HTMLDivElement>>;
|
||||
}
|
||||
|
||||
const BadgeWrapper = React.memo(
|
||||
forwardRef<HTMLDivElement, BadgeWrapperProps>(
|
||||
({ badge, isEditing, isInChat, onToggle, onDelete, onMouseDown, badgeRefs }, ref) => {
|
||||
const isActive = badge.atom ? useRecoilValue(badge.atom) : false;
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={(el) => {
|
||||
if (el) {
|
||||
badgeRefs.current[badge.id] = el;
|
||||
}
|
||||
if (typeof ref === 'function') {
|
||||
ref(el);
|
||||
} else if (ref) {
|
||||
ref.current = el;
|
||||
}
|
||||
}}
|
||||
onMouseDown={(e) => onMouseDown(e, badge, isActive)}
|
||||
className={isEditing ? 'ios-wiggle badge-icon h-full' : 'badge-icon h-full'}
|
||||
>
|
||||
<Badge
|
||||
id={badge.id}
|
||||
icon={badge.icon as LucideIcon}
|
||||
label={badge.label}
|
||||
isActive={isActive}
|
||||
isEditing={isEditing}
|
||||
isAvailable={badge.isAvailable}
|
||||
isInChat={isInChat}
|
||||
onToggle={() => onToggle(badge)}
|
||||
onBadgeAction={() => onDelete(badge.id)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
),
|
||||
(prevProps, nextProps) =>
|
||||
prevProps.badge.id === nextProps.badge.id &&
|
||||
prevProps.isEditing === nextProps.isEditing &&
|
||||
prevProps.isInChat === nextProps.isInChat &&
|
||||
prevProps.onToggle === nextProps.onToggle &&
|
||||
prevProps.onDelete === nextProps.onDelete &&
|
||||
prevProps.onMouseDown === nextProps.onMouseDown &&
|
||||
prevProps.badgeRefs === nextProps.badgeRefs,
|
||||
);
|
||||
|
||||
BadgeWrapper.displayName = 'BadgeWrapper';
|
||||
|
||||
interface DragState {
|
||||
draggedBadge: BadgeItem | null;
|
||||
mouseX: number;
|
||||
offsetX: number;
|
||||
insertIndex: number | null;
|
||||
draggedBadgeActive: boolean;
|
||||
}
|
||||
|
||||
type DragAction =
|
||||
| {
|
||||
type: 'START_DRAG';
|
||||
badge: BadgeItem;
|
||||
mouseX: number;
|
||||
offsetX: number;
|
||||
insertIndex: number;
|
||||
isActive: boolean;
|
||||
}
|
||||
| { type: 'UPDATE_POSITION'; mouseX: number; insertIndex: number }
|
||||
| { type: 'END_DRAG' };
|
||||
|
||||
const dragReducer = (state: DragState, action: DragAction): DragState => {
|
||||
switch (action.type) {
|
||||
case 'START_DRAG':
|
||||
return {
|
||||
draggedBadge: action.badge,
|
||||
mouseX: action.mouseX,
|
||||
offsetX: action.offsetX,
|
||||
insertIndex: action.insertIndex,
|
||||
draggedBadgeActive: action.isActive,
|
||||
};
|
||||
case 'UPDATE_POSITION':
|
||||
return {
|
||||
...state,
|
||||
mouseX: action.mouseX,
|
||||
insertIndex: action.insertIndex,
|
||||
};
|
||||
case 'END_DRAG':
|
||||
return {
|
||||
draggedBadge: null,
|
||||
mouseX: 0,
|
||||
offsetX: 0,
|
||||
insertIndex: null,
|
||||
draggedBadgeActive: false,
|
||||
};
|
||||
default:
|
||||
return state;
|
||||
}
|
||||
};
|
||||
|
||||
export function BadgeRow({ onChange, onToggle, isInChat }: BadgeRowProps) {
|
||||
const [orderedBadges, setOrderedBadges] = useState<BadgeItem[]>([]);
|
||||
const [dragState, dispatch] = useReducer(dragReducer, {
|
||||
draggedBadge: null,
|
||||
mouseX: 0,
|
||||
offsetX: 0,
|
||||
insertIndex: null,
|
||||
draggedBadgeActive: false,
|
||||
});
|
||||
|
||||
const badgeRefs = useRef<Record<string, HTMLDivElement>>({});
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const animationFrame = useRef<number | null>(null);
|
||||
const containerRectRef = useRef<DOMRect | null>(null);
|
||||
|
||||
const allBadges = useChatBadges() || [];
|
||||
const isEditing = useRecoilValue(store.isEditingBadges);
|
||||
|
||||
const badges = useMemo(
|
||||
() => allBadges.filter((badge) => badge.isAvailable !== false),
|
||||
[allBadges],
|
||||
);
|
||||
|
||||
const toggleBadge = useRecoilCallback(
|
||||
({ snapshot, set }) =>
|
||||
async (badgeAtom: any) => {
|
||||
const current = await snapshot.getPromise(badgeAtom);
|
||||
set(badgeAtom, !current);
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
setOrderedBadges((prev) => {
|
||||
const currentIds = new Set(prev.map((b) => b.id));
|
||||
const newBadges = badges.filter((b) => !currentIds.has(b.id));
|
||||
return newBadges.length > 0 ? [...prev, ...newBadges] : prev;
|
||||
});
|
||||
}, [badges]);
|
||||
|
||||
const tempBadges = dragState.draggedBadge
|
||||
? orderedBadges.filter((b) => b.id !== dragState.draggedBadge?.id)
|
||||
: orderedBadges;
|
||||
const ghostBadge = dragState.draggedBadge || null;
|
||||
|
||||
const calculateInsertIndex = useCallback(
|
||||
(currentMouseX: number): number => {
|
||||
if (!dragState.draggedBadge || !containerRef.current || !containerRectRef.current) {
|
||||
return 0;
|
||||
}
|
||||
const relativeMouseX = currentMouseX - containerRectRef.current.left;
|
||||
const refs = tempBadges.map((b) => badgeRefs.current[b.id]).filter(Boolean);
|
||||
if (refs.length === 0) {
|
||||
return 0;
|
||||
}
|
||||
let idx = 0;
|
||||
for (let i = 0; i < refs.length; i++) {
|
||||
const rect = refs[i].getBoundingClientRect();
|
||||
const relativeLeft = rect.left - containerRectRef.current.left;
|
||||
const relativeCenter = relativeLeft + rect.width / 2;
|
||||
if (relativeMouseX < relativeCenter) {
|
||||
break;
|
||||
}
|
||||
idx = i + 1;
|
||||
}
|
||||
return idx;
|
||||
},
|
||||
[dragState.draggedBadge, tempBadges],
|
||||
);
|
||||
|
||||
const handleMouseDown = useCallback(
|
||||
(e: React.MouseEvent, badge: BadgeItem, isActive: boolean) => {
|
||||
if (!isEditing || !containerRef.current) {
|
||||
return;
|
||||
}
|
||||
const el = badgeRefs.current[badge.id];
|
||||
if (!el) {
|
||||
return;
|
||||
}
|
||||
const rect = el.getBoundingClientRect();
|
||||
const offsetX = e.clientX - rect.left;
|
||||
const mouseX = e.clientX;
|
||||
const initialIndex = orderedBadges.findIndex((b) => b.id === badge.id);
|
||||
containerRectRef.current = containerRef.current.getBoundingClientRect();
|
||||
dispatch({
|
||||
type: 'START_DRAG',
|
||||
badge,
|
||||
mouseX,
|
||||
offsetX,
|
||||
insertIndex: initialIndex,
|
||||
isActive,
|
||||
});
|
||||
},
|
||||
[isEditing, orderedBadges],
|
||||
);
|
||||
|
||||
const handleMouseMove = useCallback(
|
||||
(e: MouseEvent) => {
|
||||
if (!dragState.draggedBadge) {
|
||||
return;
|
||||
}
|
||||
if (animationFrame.current) {
|
||||
cancelAnimationFrame(animationFrame.current);
|
||||
}
|
||||
animationFrame.current = requestAnimationFrame(() => {
|
||||
const newMouseX = e.clientX;
|
||||
const newInsertIndex = calculateInsertIndex(newMouseX);
|
||||
if (newInsertIndex !== dragState.insertIndex) {
|
||||
dispatch({ type: 'UPDATE_POSITION', mouseX: newMouseX, insertIndex: newInsertIndex });
|
||||
} else {
|
||||
dispatch({
|
||||
type: 'UPDATE_POSITION',
|
||||
mouseX: newMouseX,
|
||||
insertIndex: dragState.insertIndex,
|
||||
});
|
||||
}
|
||||
});
|
||||
},
|
||||
[dragState.draggedBadge, dragState.insertIndex, calculateInsertIndex],
|
||||
);
|
||||
|
||||
const handleMouseUp = useCallback(() => {
|
||||
if (dragState.draggedBadge && dragState.insertIndex !== null) {
|
||||
const otherBadges = orderedBadges.filter((b) => b.id !== dragState.draggedBadge?.id);
|
||||
const newBadges = [
|
||||
...otherBadges.slice(0, dragState.insertIndex),
|
||||
dragState.draggedBadge,
|
||||
...otherBadges.slice(dragState.insertIndex),
|
||||
];
|
||||
setOrderedBadges(newBadges);
|
||||
onChange(newBadges.map((badge) => ({ id: badge.id })));
|
||||
}
|
||||
dispatch({ type: 'END_DRAG' });
|
||||
containerRectRef.current = null;
|
||||
}, [dragState.draggedBadge, dragState.insertIndex, orderedBadges, onChange]);
|
||||
|
||||
const handleDelete = useCallback(
|
||||
(badgeId: string) => {
|
||||
const newBadges = orderedBadges.filter((b) => b.id !== badgeId);
|
||||
setOrderedBadges(newBadges);
|
||||
onChange(newBadges.map((badge) => ({ id: badge.id })));
|
||||
},
|
||||
[orderedBadges, onChange],
|
||||
);
|
||||
|
||||
const handleBadgeToggle = useCallback(
|
||||
(badge: BadgeItem) => {
|
||||
if (badge.atom) {
|
||||
toggleBadge(badge.atom);
|
||||
}
|
||||
if (onToggle) {
|
||||
onToggle(badge.id, !!badge.atom);
|
||||
}
|
||||
},
|
||||
[toggleBadge, onToggle],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (!dragState.draggedBadge) {
|
||||
return;
|
||||
}
|
||||
document.addEventListener('mousemove', handleMouseMove);
|
||||
document.addEventListener('mouseup', handleMouseUp);
|
||||
return () => {
|
||||
document.removeEventListener('mousemove', handleMouseMove);
|
||||
document.removeEventListener('mouseup', handleMouseUp);
|
||||
if (animationFrame.current) {
|
||||
cancelAnimationFrame(animationFrame.current);
|
||||
animationFrame.current = null;
|
||||
}
|
||||
};
|
||||
}, [dragState.draggedBadge, handleMouseMove, handleMouseUp]);
|
||||
|
||||
return (
|
||||
<div ref={containerRef} className="relative flex flex-wrap items-center gap-2">
|
||||
{tempBadges.map((badge, index) => (
|
||||
<React.Fragment key={badge.id}>
|
||||
{dragState.draggedBadge && dragState.insertIndex === index && ghostBadge && (
|
||||
<div className="badge-icon h-full">
|
||||
<Badge
|
||||
id={ghostBadge.id}
|
||||
icon={ghostBadge.icon as LucideIcon}
|
||||
label={ghostBadge.label}
|
||||
isActive={dragState.draggedBadgeActive}
|
||||
isEditing={isEditing}
|
||||
isAvailable={ghostBadge.isAvailable}
|
||||
isInChat={isInChat}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<BadgeWrapper
|
||||
badge={badge}
|
||||
isEditing={isEditing}
|
||||
isInChat={isInChat}
|
||||
onToggle={handleBadgeToggle}
|
||||
onDelete={handleDelete}
|
||||
onMouseDown={handleMouseDown}
|
||||
badgeRefs={badgeRefs}
|
||||
/>
|
||||
</React.Fragment>
|
||||
))}
|
||||
{dragState.draggedBadge && dragState.insertIndex === tempBadges.length && ghostBadge && (
|
||||
<div className="badge-icon h-full">
|
||||
<Badge
|
||||
id={ghostBadge.id}
|
||||
icon={ghostBadge.icon as LucideIcon}
|
||||
label={ghostBadge.label}
|
||||
isActive={dragState.draggedBadgeActive}
|
||||
isEditing={isEditing}
|
||||
isAvailable={ghostBadge.isAvailable}
|
||||
isInChat={isInChat}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{ghostBadge && (
|
||||
<div
|
||||
className="ghost-badge h-full"
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 0,
|
||||
transform: `translateX(${dragState.mouseX - dragState.offsetX - (containerRectRef.current?.left || 0)}px)`,
|
||||
zIndex: 10,
|
||||
pointerEvents: 'none',
|
||||
}}
|
||||
>
|
||||
<Badge
|
||||
id={ghostBadge.id}
|
||||
icon={ghostBadge.icon as LucideIcon}
|
||||
label={ghostBadge.label}
|
||||
isActive={dragState.draggedBadgeActive}
|
||||
isAvailable={ghostBadge.isAvailable}
|
||||
isInChat={isInChat}
|
||||
isEditing
|
||||
isDragging
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,11 +1,7 @@
|
|||
import { memo, useRef, useMemo, useEffect, useState } from 'react';
|
||||
import { memo, useRef, useMemo, useEffect, useState, useCallback } from 'react';
|
||||
import { useWatch } from 'react-hook-form';
|
||||
import { useRecoilState, useRecoilValue } from 'recoil';
|
||||
import {
|
||||
supportsFiles,
|
||||
mergeFileConfig,
|
||||
isAssistantsEndpoint,
|
||||
fileConfig as defaultFileConfig,
|
||||
} from 'librechat-data-provider';
|
||||
import { Constants, isAssistantsEndpoint } from 'librechat-data-provider';
|
||||
import {
|
||||
useChatContext,
|
||||
useChatFormContext,
|
||||
|
|
@ -20,47 +16,107 @@ import {
|
|||
useQueryParams,
|
||||
useSubmitMessage,
|
||||
} from '~/hooks';
|
||||
import { cn, removeFocusRings, checkIfScrollable } from '~/utils';
|
||||
import FileFormWrapper from './Files/FileFormWrapper';
|
||||
import { TextareaAutosize } from '~/components/ui';
|
||||
import { useGetFileConfig } from '~/data-provider';
|
||||
import { TemporaryChat } from './TemporaryChat';
|
||||
import { mainTextareaId, BadgeItem } from '~/common';
|
||||
import AttachFileChat from './Files/AttachFileChat';
|
||||
import FileFormChat from './Files/FileFormChat';
|
||||
import { TextareaAutosize } from '~/components';
|
||||
import { cn, removeFocusRings } from '~/utils';
|
||||
import TextareaHeader from './TextareaHeader';
|
||||
import PromptsCommand from './PromptsCommand';
|
||||
import AudioRecorder from './AudioRecorder';
|
||||
import { mainTextareaId } from '~/common';
|
||||
import CollapseChat from './CollapseChat';
|
||||
import StreamAudio from './StreamAudio';
|
||||
import StopButton from './StopButton';
|
||||
import SendButton from './SendButton';
|
||||
import { BadgeRow } from './BadgeRow';
|
||||
import EditBadges from './EditBadges';
|
||||
import Mention from './Mention';
|
||||
import store from '~/store';
|
||||
|
||||
const ChatForm = ({ index = 0 }) => {
|
||||
const ChatForm = memo(({ index = 0 }: { index?: number }) => {
|
||||
const submitButtonRef = useRef<HTMLButtonElement>(null);
|
||||
const textAreaRef = useRef<HTMLTextAreaElement | null>(null);
|
||||
useQueryParams({ textAreaRef });
|
||||
const textAreaRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
const [isCollapsed, setIsCollapsed] = useState(false);
|
||||
const [isScrollable, setIsScrollable] = useState(false);
|
||||
|
||||
const SpeechToText = useRecoilValue(store.speechToText);
|
||||
const TextToSpeech = useRecoilValue(store.textToSpeech);
|
||||
const automaticPlayback = useRecoilValue(store.automaticPlayback);
|
||||
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
|
||||
const [isTemporaryChat, setIsTemporaryChat] = useRecoilState<boolean>(store.isTemporary);
|
||||
const [, setIsScrollable] = useState(false);
|
||||
const [visualRowCount, setVisualRowCount] = useState(1);
|
||||
const [isTextAreaFocused, setIsTextAreaFocused] = useState(false);
|
||||
const [backupBadges, setBackupBadges] = useState<Pick<BadgeItem, 'id'>[]>([]);
|
||||
|
||||
const isSearching = useRecoilValue(store.isSearching);
|
||||
const SpeechToText = useRecoilValue(store.speechToText);
|
||||
const TextToSpeech = useRecoilValue(store.textToSpeech);
|
||||
const chatDirection = useRecoilValue(store.chatDirection);
|
||||
const automaticPlayback = useRecoilValue(store.automaticPlayback);
|
||||
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
|
||||
const centerFormOnLanding = useRecoilValue(store.centerFormOnLanding);
|
||||
const isTemporary = useRecoilValue(store.isTemporary);
|
||||
|
||||
const [badges, setBadges] = useRecoilState(store.chatBadges);
|
||||
const [isEditingBadges, setIsEditingBadges] = useRecoilState(store.isEditingBadges);
|
||||
const [showStopButton, setShowStopButton] = useRecoilState(store.showStopButtonByIndex(index));
|
||||
const [showPlusPopover, setShowPlusPopover] = useRecoilState(store.showPlusPopoverFamily(index));
|
||||
const [showMentionPopover, setShowMentionPopover] = useRecoilState(
|
||||
store.showMentionPopoverFamily(index),
|
||||
);
|
||||
|
||||
const chatDirection = useRecoilValue(store.chatDirection).toLowerCase();
|
||||
const isRTL = chatDirection === 'rtl';
|
||||
|
||||
const { requiresKey } = useRequiresKey();
|
||||
const methods = useChatFormContext();
|
||||
const {
|
||||
files,
|
||||
setFiles,
|
||||
conversation,
|
||||
isSubmitting,
|
||||
filesLoading,
|
||||
newConversation,
|
||||
handleStopGenerating,
|
||||
} = useChatContext();
|
||||
const {
|
||||
addedIndex,
|
||||
generateConversation,
|
||||
conversation: addedConvo,
|
||||
setConversation: setAddedConvo,
|
||||
isSubmitting: isSubmittingAdded,
|
||||
} = useAddedChatContext();
|
||||
const assistantMap = useAssistantsMapContext();
|
||||
const showStopAdded = useRecoilValue(store.showStopButtonByIndex(addedIndex));
|
||||
|
||||
const endpoint = useMemo(
|
||||
() => conversation?.endpointType ?? conversation?.endpoint,
|
||||
[conversation?.endpointType, conversation?.endpoint],
|
||||
);
|
||||
|
||||
const isRTL = useMemo(() => chatDirection === 'rtl', [chatDirection.toLowerCase()]);
|
||||
const invalidAssistant = useMemo(
|
||||
() =>
|
||||
isAssistantsEndpoint(endpoint) &&
|
||||
(!(conversation?.assistant_id ?? '') ||
|
||||
!assistantMap?.[endpoint ?? '']?.[conversation?.assistant_id ?? '']),
|
||||
[conversation?.assistant_id, endpoint, assistantMap],
|
||||
);
|
||||
const disableInputs = useMemo(
|
||||
() => requiresKey || invalidAssistant,
|
||||
[requiresKey, invalidAssistant],
|
||||
);
|
||||
|
||||
const handleContainerClick = useCallback(() => {
|
||||
textAreaRef.current?.focus();
|
||||
}, []);
|
||||
|
||||
const handleFocusOrClick = useCallback(() => {
|
||||
if (isCollapsed) {
|
||||
setIsCollapsed(false);
|
||||
}
|
||||
}, [isCollapsed]);
|
||||
|
||||
useAutoSave({
|
||||
conversationId: conversation?.conversationId,
|
||||
textAreaRef,
|
||||
files,
|
||||
setFiles,
|
||||
});
|
||||
|
||||
const { submitMessage, submitPrompt } = useSubmitMessage();
|
||||
const handleKeyUp = useHandleKeyUp({
|
||||
index,
|
||||
textAreaRef,
|
||||
|
|
@ -71,65 +127,22 @@ const ChatForm = ({ index = 0 }) => {
|
|||
textAreaRef,
|
||||
submitButtonRef,
|
||||
setIsScrollable,
|
||||
disabled: !!(requiresKey ?? false),
|
||||
disabled: disableInputs,
|
||||
});
|
||||
|
||||
const {
|
||||
files,
|
||||
setFiles,
|
||||
conversation,
|
||||
isSubmitting,
|
||||
filesLoading,
|
||||
newConversation,
|
||||
handleStopGenerating,
|
||||
} = useChatContext();
|
||||
const methods = useChatFormContext();
|
||||
const {
|
||||
addedIndex,
|
||||
generateConversation,
|
||||
conversation: addedConvo,
|
||||
setConversation: setAddedConvo,
|
||||
isSubmitting: isSubmittingAdded,
|
||||
} = useAddedChatContext();
|
||||
const showStopAdded = useRecoilValue(store.showStopButtonByIndex(addedIndex));
|
||||
|
||||
useAutoSave({
|
||||
conversationId: useMemo(() => conversation?.conversationId, [conversation]),
|
||||
textAreaRef,
|
||||
files,
|
||||
setFiles,
|
||||
});
|
||||
|
||||
const assistantMap = useAssistantsMapContext();
|
||||
const { submitMessage, submitPrompt } = useSubmitMessage();
|
||||
|
||||
const { endpoint: _endpoint, endpointType } = conversation ?? { endpoint: null };
|
||||
const endpoint = endpointType ?? _endpoint;
|
||||
|
||||
const { data: fileConfig = defaultFileConfig } = useGetFileConfig({
|
||||
select: (data) => mergeFileConfig(data),
|
||||
});
|
||||
|
||||
const endpointFileConfig = fileConfig.endpoints[endpoint ?? ''];
|
||||
const invalidAssistant = useMemo(
|
||||
() =>
|
||||
isAssistantsEndpoint(conversation?.endpoint) &&
|
||||
(!(conversation?.assistant_id ?? '') ||
|
||||
!assistantMap?.[conversation?.endpoint ?? ''][conversation?.assistant_id ?? '']),
|
||||
[conversation?.assistant_id, conversation?.endpoint, assistantMap],
|
||||
);
|
||||
const disableInputs = useMemo(
|
||||
() => !!((requiresKey ?? false) || invalidAssistant),
|
||||
[requiresKey, invalidAssistant],
|
||||
);
|
||||
useQueryParams({ textAreaRef });
|
||||
|
||||
const { ref, ...registerProps } = methods.register('text', {
|
||||
required: true,
|
||||
onChange: (e) => {
|
||||
methods.setValue('text', e.target.value, { shouldValidate: true });
|
||||
},
|
||||
onChange: useCallback(
|
||||
(e: React.ChangeEvent<HTMLTextAreaElement>) =>
|
||||
methods.setValue('text', e.target.value, { shouldValidate: true }),
|
||||
[methods],
|
||||
),
|
||||
});
|
||||
|
||||
const textValue = useWatch({ control: methods.control, name: 'text' });
|
||||
|
||||
useEffect(() => {
|
||||
if (!isSearching && textAreaRef.current && !disableInputs) {
|
||||
textAreaRef.current.focus();
|
||||
|
|
@ -138,33 +151,58 @@ const ChatForm = ({ index = 0 }) => {
|
|||
|
||||
useEffect(() => {
|
||||
if (textAreaRef.current) {
|
||||
checkIfScrollable(textAreaRef.current);
|
||||
const style = window.getComputedStyle(textAreaRef.current);
|
||||
const lineHeight = parseFloat(style.lineHeight);
|
||||
setVisualRowCount(Math.floor(textAreaRef.current.scrollHeight / lineHeight));
|
||||
}
|
||||
}, [textValue]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isEditingBadges && backupBadges.length === 0) {
|
||||
setBackupBadges([...badges]);
|
||||
}
|
||||
}, [isEditingBadges, badges, backupBadges.length]);
|
||||
|
||||
const handleSaveBadges = useCallback(() => {
|
||||
setIsEditingBadges(false);
|
||||
setBackupBadges([]);
|
||||
}, []);
|
||||
|
||||
const endpointSupportsFiles: boolean = supportsFiles[endpointType ?? endpoint ?? ''] ?? false;
|
||||
const isUploadDisabled: boolean = endpointFileConfig?.disabled ?? false;
|
||||
const handleCancelBadges = useCallback(() => {
|
||||
if (backupBadges.length > 0) {
|
||||
setBadges([...backupBadges]);
|
||||
}
|
||||
setIsEditingBadges(false);
|
||||
setBackupBadges([]);
|
||||
}, [backupBadges, setBadges]);
|
||||
|
||||
const baseClasses = cn(
|
||||
'md:py-3.5 m-0 w-full resize-none py-[13px] bg-surface-tertiary placeholder-black/50 dark:placeholder-white/50 [&:has(textarea:focus)]:shadow-[0_2px_6px_rgba(0,0,0,.05)]',
|
||||
isCollapsed ? 'max-h-[52px]' : 'max-h-[65vh] md:max-h-[75vh]',
|
||||
const isMoreThanThreeRows = visualRowCount > 3;
|
||||
|
||||
const baseClasses = useMemo(
|
||||
() =>
|
||||
cn(
|
||||
'md:py-3.5 m-0 w-full resize-none py-[13px] placeholder-black/50 bg-transparent dark:placeholder-white/50 [&:has(textarea:focus)]:shadow-[0_2px_6px_rgba(0,0,0,.05)]',
|
||||
isCollapsed ? 'max-h-[52px]' : 'max-h-[45vh] md:max-h-[55vh]',
|
||||
isMoreThanThreeRows ? 'pl-5' : 'px-5',
|
||||
),
|
||||
[isCollapsed, isMoreThanThreeRows],
|
||||
);
|
||||
|
||||
const uploadActive = endpointSupportsFiles && !isUploadDisabled;
|
||||
const speechClass = isRTL
|
||||
? `pr-${uploadActive ? '12' : '4'} pl-12`
|
||||
: `pl-${uploadActive ? '12' : '4'} pr-12`;
|
||||
|
||||
return (
|
||||
<form
|
||||
onSubmit={methods.handleSubmit((data) => submitMessage(data))}
|
||||
onSubmit={methods.handleSubmit(submitMessage)}
|
||||
className={cn(
|
||||
'mx-auto flex flex-row gap-3 pl-2 transition-all duration-200 last:mb-2',
|
||||
maximizeChatSpace ? 'w-full max-w-full' : 'md:max-w-2xl xl:max-w-3xl',
|
||||
'mx-auto flex flex-row gap-3 sm:px-2',
|
||||
maximizeChatSpace ? 'w-full max-w-full' : 'md:max-w-3xl xl:max-w-4xl',
|
||||
centerFormOnLanding &&
|
||||
(!conversation?.conversationId || conversation?.conversationId === Constants.NEW_CONVO) &&
|
||||
!isSubmitting
|
||||
? 'transition-all duration-200 sm:mb-28'
|
||||
: 'sm:mb-10',
|
||||
)}
|
||||
>
|
||||
<div className="relative flex h-full flex-1 items-stretch md:flex-col">
|
||||
<div className="flex w-full items-center">
|
||||
<div className={cn('flex w-full items-center', isRTL && 'flex-row-reverse')}>
|
||||
{showPlusPopover && !isAssistantsEndpoint(endpoint) && (
|
||||
<Mention
|
||||
setShowMentionPopover={setShowPlusPopover}
|
||||
|
|
@ -183,90 +221,109 @@ const ChatForm = ({ index = 0 }) => {
|
|||
/>
|
||||
)}
|
||||
<PromptsCommand index={index} textAreaRef={textAreaRef} submitPrompt={submitPrompt} />
|
||||
<div className="transitional-all relative flex w-full flex-grow flex-col overflow-hidden rounded-3xl bg-surface-tertiary text-text-primary duration-200">
|
||||
<TemporaryChat
|
||||
isTemporaryChat={isTemporaryChat}
|
||||
setIsTemporaryChat={setIsTemporaryChat}
|
||||
/>
|
||||
<div
|
||||
onClick={handleContainerClick}
|
||||
className={cn(
|
||||
'relative flex w-full flex-grow flex-col overflow-hidden rounded-t-3xl border pb-4 text-text-primary transition-all duration-200 sm:rounded-3xl sm:pb-0',
|
||||
isTextAreaFocused ? 'shadow-lg' : 'shadow-md',
|
||||
isTemporary
|
||||
? 'border-violet-800/60 bg-violet-950/10'
|
||||
: 'border-border-light bg-surface-chat',
|
||||
)}
|
||||
>
|
||||
<TextareaHeader addedConvo={addedConvo} setAddedConvo={setAddedConvo} />
|
||||
<FileFormWrapper disableInputs={disableInputs}>
|
||||
{endpoint && (
|
||||
<>
|
||||
<EditBadges
|
||||
isEditingChatBadges={isEditingBadges}
|
||||
handleCancelBadges={handleCancelBadges}
|
||||
handleSaveBadges={handleSaveBadges}
|
||||
setBadges={setBadges}
|
||||
/>
|
||||
<FileFormChat disableInputs={disableInputs} />
|
||||
{endpoint && (
|
||||
<div className={cn('flex', isRTL ? 'flex-row-reverse' : 'flex-row')}>
|
||||
<TextareaAutosize
|
||||
{...registerProps}
|
||||
ref={(e) => {
|
||||
ref(e);
|
||||
(textAreaRef as React.MutableRefObject<HTMLTextAreaElement | null>).current = e;
|
||||
}}
|
||||
disabled={disableInputs}
|
||||
onPaste={handlePaste}
|
||||
onKeyDown={handleKeyDown}
|
||||
onKeyUp={handleKeyUp}
|
||||
onCompositionStart={handleCompositionStart}
|
||||
onCompositionEnd={handleCompositionEnd}
|
||||
id={mainTextareaId}
|
||||
tabIndex={0}
|
||||
data-testid="text-input"
|
||||
rows={1}
|
||||
onFocus={() => {
|
||||
handleFocusOrClick();
|
||||
setIsTextAreaFocused(true);
|
||||
}}
|
||||
onBlur={setIsTextAreaFocused.bind(null, false)}
|
||||
onClick={handleFocusOrClick}
|
||||
style={{ height: 44, overflowY: 'auto' }}
|
||||
className={cn(
|
||||
baseClasses,
|
||||
removeFocusRings,
|
||||
'transition-[max-height] duration-200',
|
||||
)}
|
||||
/>
|
||||
<div className="flex flex-col items-start justify-start pt-1.5">
|
||||
<CollapseChat
|
||||
isCollapsed={isCollapsed}
|
||||
isScrollable={isScrollable}
|
||||
isScrollable={isMoreThanThreeRows}
|
||||
setIsCollapsed={setIsCollapsed}
|
||||
/>
|
||||
<TextareaAutosize
|
||||
{...registerProps}
|
||||
ref={(e) => {
|
||||
ref(e);
|
||||
textAreaRef.current = e;
|
||||
}}
|
||||
disabled={disableInputs}
|
||||
onPaste={handlePaste}
|
||||
onKeyDown={handleKeyDown}
|
||||
onKeyUp={handleKeyUp}
|
||||
onHeightChange={() => {
|
||||
if (textAreaRef.current) {
|
||||
const scrollable = checkIfScrollable(textAreaRef.current);
|
||||
setIsScrollable(scrollable);
|
||||
}
|
||||
}}
|
||||
onCompositionStart={handleCompositionStart}
|
||||
onCompositionEnd={handleCompositionEnd}
|
||||
id={mainTextareaId}
|
||||
tabIndex={0}
|
||||
data-testid="text-input"
|
||||
rows={1}
|
||||
onFocus={() => isCollapsed && setIsCollapsed(false)}
|
||||
onClick={() => isCollapsed && setIsCollapsed(false)}
|
||||
style={{ height: 44, overflowY: 'auto' }}
|
||||
className={cn(
|
||||
baseClasses,
|
||||
speechClass,
|
||||
removeFocusRings,
|
||||
'transition-[max-height] duration-200',
|
||||
)}
|
||||
/>
|
||||
</>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<div
|
||||
className={cn(
|
||||
'items-between flex gap-2 pb-2',
|
||||
isRTL ? 'flex-row-reverse' : 'flex-row',
|
||||
)}
|
||||
</FileFormWrapper>
|
||||
{SpeechToText && (
|
||||
<AudioRecorder
|
||||
isRTL={isRTL}
|
||||
methods={methods}
|
||||
ask={submitMessage}
|
||||
textAreaRef={textAreaRef}
|
||||
disabled={!!disableInputs}
|
||||
isSubmitting={isSubmitting}
|
||||
>
|
||||
<div className={`${isRTL ? 'mr-2' : 'ml-2'}`}>
|
||||
<AttachFileChat disableInputs={disableInputs} />
|
||||
</div>
|
||||
<BadgeRow
|
||||
onChange={(newBadges) => setBadges(newBadges)}
|
||||
isInChat={
|
||||
Array.isArray(conversation?.messages) && conversation.messages.length >= 1
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{TextToSpeech && automaticPlayback && <StreamAudio index={index} />}
|
||||
</div>
|
||||
<div
|
||||
className={cn(
|
||||
'mb-[5px] ml-[8px] flex flex-col items-end justify-end',
|
||||
isRTL && 'order-first mr-[8px]',
|
||||
)}
|
||||
style={{ alignSelf: 'flex-end' }}
|
||||
>
|
||||
{(isSubmitting || isSubmittingAdded) && (showStopButton || showStopAdded) ? (
|
||||
<StopButton stop={handleStopGenerating} setShowStopButton={setShowStopButton} />
|
||||
) : (
|
||||
endpoint && (
|
||||
<SendButton
|
||||
ref={submitButtonRef}
|
||||
control={methods.control}
|
||||
disabled={!!(filesLoading || isSubmitting || disableInputs)}
|
||||
<div className="mx-auto flex" />
|
||||
{SpeechToText && (
|
||||
<AudioRecorder
|
||||
methods={methods}
|
||||
ask={submitMessage}
|
||||
textAreaRef={textAreaRef}
|
||||
disabled={disableInputs}
|
||||
isSubmitting={isSubmitting}
|
||||
/>
|
||||
)
|
||||
)}
|
||||
)}
|
||||
<div className={`${isRTL ? 'ml-2' : 'mr-2'}`}>
|
||||
{(isSubmitting || isSubmittingAdded) && (showStopButton || showStopAdded) ? (
|
||||
<StopButton stop={handleStopGenerating} setShowStopButton={setShowStopButton} />
|
||||
) : (
|
||||
endpoint && (
|
||||
<SendButton
|
||||
ref={submitButtonRef}
|
||||
control={methods.control}
|
||||
disabled={filesLoading || isSubmitting || disableInputs}
|
||||
/>
|
||||
)
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{TextToSpeech && automaticPlayback && <StreamAudio index={index} />}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
export default memo(ChatForm);
|
||||
export default ChatForm;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import React from 'react';
|
||||
import { Minimize2 } from 'lucide-react';
|
||||
import { ChevronDown, ChevronUp } from 'lucide-react';
|
||||
import { TooltipAnchor } from '~/components/ui';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
|
@ -18,23 +18,37 @@ const CollapseChat = ({
|
|||
return null;
|
||||
}
|
||||
|
||||
if (isCollapsed) {
|
||||
return null;
|
||||
}
|
||||
const description = isCollapsed
|
||||
? localize('com_ui_expand_chat')
|
||||
: localize('com_ui_collapse_chat');
|
||||
|
||||
return (
|
||||
<TooltipAnchor
|
||||
role="button"
|
||||
description={localize('com_ui_collapse_chat')}
|
||||
aria-label={localize('com_ui_collapse_chat')}
|
||||
onClick={() => setIsCollapsed(true)}
|
||||
className={cn(
|
||||
'absolute right-2 top-2 z-10 size-[35px] rounded-full p-2 transition-colors',
|
||||
'hover:bg-surface-hover focus:outline-none focus:ring-2 focus:ring-primary focus:ring-opacity-50',
|
||||
)}
|
||||
>
|
||||
<Minimize2 className="h-full w-full" />
|
||||
</TooltipAnchor>
|
||||
<div className="relative ml-auto items-end justify-end">
|
||||
<TooltipAnchor
|
||||
description={description}
|
||||
render={
|
||||
<button
|
||||
aria-label={description}
|
||||
onClick={(event) => {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
setIsCollapsed((prev) => !prev);
|
||||
}}
|
||||
className={cn(
|
||||
// 'absolute right-1.5 top-1.5',
|
||||
'z-10 size-5 rounded-full transition-colors',
|
||||
'focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-primary focus-visible:ring-opacity-50',
|
||||
)}
|
||||
>
|
||||
{isCollapsed ? (
|
||||
<ChevronDown className="h-full w-full" />
|
||||
) : (
|
||||
<ChevronUp className="h-full w-full" />
|
||||
)}
|
||||
</button>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
|
|
|
|||
85
client/src/components/Chat/Input/ConversationStarters.tsx
Normal file
85
client/src/components/Chat/Input/ConversationStarters.tsx
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
import { useMemo, useCallback } from 'react';
|
||||
import { EModelEndpoint, Constants } from 'librechat-data-provider';
|
||||
import { useChatContext, useAgentsMapContext, useAssistantsMapContext } from '~/Providers';
|
||||
import { useGetAssistantDocsQuery, useGetEndpointsQuery } from '~/data-provider';
|
||||
import { getIconEndpoint, getEntity } from '~/utils';
|
||||
import { useSubmitMessage } from '~/hooks';
|
||||
|
||||
const ConversationStarters = () => {
|
||||
const { conversation } = useChatContext();
|
||||
const agentsMap = useAgentsMapContext();
|
||||
const assistantMap = useAssistantsMapContext();
|
||||
const { data: endpointsConfig } = useGetEndpointsQuery();
|
||||
|
||||
const endpointType = useMemo(() => {
|
||||
let ep = conversation?.endpoint ?? '';
|
||||
if (
|
||||
[
|
||||
EModelEndpoint.chatGPTBrowser,
|
||||
EModelEndpoint.azureOpenAI,
|
||||
EModelEndpoint.gptPlugins,
|
||||
].includes(ep as EModelEndpoint)
|
||||
) {
|
||||
ep = EModelEndpoint.openAI;
|
||||
}
|
||||
return getIconEndpoint({
|
||||
endpointsConfig,
|
||||
iconURL: conversation?.iconURL,
|
||||
endpoint: ep,
|
||||
});
|
||||
}, [conversation?.endpoint, conversation?.iconURL, endpointsConfig]);
|
||||
|
||||
const { data: documentsMap = new Map() } = useGetAssistantDocsQuery(endpointType, {
|
||||
select: (data) => new Map(data.map((dbA) => [dbA.assistant_id, dbA])),
|
||||
});
|
||||
|
||||
const { entity, isAgent } = getEntity({
|
||||
endpoint: endpointType,
|
||||
agentsMap,
|
||||
assistantMap,
|
||||
agent_id: conversation?.agent_id,
|
||||
assistant_id: conversation?.assistant_id,
|
||||
});
|
||||
|
||||
const conversation_starters = useMemo(() => {
|
||||
if (entity?.conversation_starters?.length) {
|
||||
return entity.conversation_starters;
|
||||
}
|
||||
|
||||
if (isAgent) {
|
||||
return [];
|
||||
}
|
||||
|
||||
return documentsMap.get(entity?.id ?? '')?.conversation_starters ?? [];
|
||||
}, [documentsMap, isAgent, entity]);
|
||||
|
||||
const { submitMessage } = useSubmitMessage();
|
||||
const sendConversationStarter = useCallback(
|
||||
(text: string) => submitMessage({ text }),
|
||||
[submitMessage],
|
||||
);
|
||||
|
||||
if (!conversation_starters.length) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="mt-8 flex flex-wrap justify-center gap-3 px-4">
|
||||
{conversation_starters
|
||||
.slice(0, Constants.MAX_CONVO_STARTERS)
|
||||
.map((text: string, index: number) => (
|
||||
<button
|
||||
key={index}
|
||||
onClick={() => sendConversationStarter(text)}
|
||||
className="relative flex w-40 cursor-pointer flex-col gap-2 rounded-2xl border border-border-medium px-3 pb-4 pt-3 text-start align-top text-[15px] shadow-[0_0_2px_0_rgba(0,0,0,0.05),0_4px_6px_0_rgba(0,0,0,0.02)] transition-colors duration-300 ease-in-out fade-in hover:bg-surface-tertiary"
|
||||
>
|
||||
<p className="break-word line-clamp-3 overflow-hidden text-balance break-all text-text-secondary">
|
||||
{text}
|
||||
</p>
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ConversationStarters;
|
||||
87
client/src/components/Chat/Input/EditBadges.tsx
Normal file
87
client/src/components/Chat/Input/EditBadges.tsx
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
import React, { useCallback } from 'react';
|
||||
import { Edit3, Check, X } from 'lucide-react';
|
||||
import type { LucideIcon } from 'lucide-react';
|
||||
import type { BadgeItem } from '~/common';
|
||||
import { useChatBadges, useLocalize } from '~/hooks';
|
||||
import { Button, Badge } from '~/components/ui';
|
||||
|
||||
interface EditBadgesProps {
|
||||
isEditingChatBadges: boolean;
|
||||
handleCancelBadges: () => void;
|
||||
handleSaveBadges: () => void;
|
||||
setBadges: React.Dispatch<React.SetStateAction<Pick<BadgeItem, 'id'>[]>>;
|
||||
}
|
||||
|
||||
const EditBadgesComponent = ({
|
||||
isEditingChatBadges,
|
||||
handleCancelBadges,
|
||||
handleSaveBadges,
|
||||
setBadges,
|
||||
}: EditBadgesProps) => {
|
||||
const localize = useLocalize();
|
||||
const allBadges = useChatBadges() || [];
|
||||
const unavailableBadges = allBadges.filter((badge) => !badge.isAvailable);
|
||||
|
||||
const handleRestoreBadge = useCallback(
|
||||
(badgeId: string) => {
|
||||
setBadges((prev: Pick<BadgeItem, 'id'>[]) => [...prev, { id: badgeId }]);
|
||||
},
|
||||
[setBadges],
|
||||
);
|
||||
|
||||
if (!isEditingChatBadges) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="m-1.5 flex flex-col overflow-hidden rounded-b-lg rounded-t-2xl bg-surface-secondary-alt">
|
||||
<div className="flex items-center gap-4 py-2 pl-3 pr-1.5 text-sm">
|
||||
<span className="mt-0 flex size-6 flex-shrink-0 items-center justify-center">
|
||||
<div className="icon-md">
|
||||
<Edit3 className="icon-md" aria-hidden="true" />
|
||||
</div>
|
||||
</span>
|
||||
<span className="text-token-text-secondary line-clamp-3 flex-1 py-0.5 font-semibold">
|
||||
{localize('com_ui_save_badge_changes')}
|
||||
</span>
|
||||
<div className="flex h-8 gap-2">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="destructive"
|
||||
aria-label="Cancel"
|
||||
onClick={handleCancelBadges}
|
||||
className="h-8"
|
||||
>
|
||||
<X className="icon-md" aria-hidden="true" />
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="submit"
|
||||
aria-label="Save changes"
|
||||
onClick={handleSaveBadges}
|
||||
className="h-8 rounded-b-lg rounded-tr-xl"
|
||||
>
|
||||
<Check className="icon-md" aria-hidden="true" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
{unavailableBadges && unavailableBadges.length > 0 && (
|
||||
<div className="flex flex-wrap items-center gap-2 p-2">
|
||||
{unavailableBadges.map((badge) => (
|
||||
<div key={badge.id} className="badge-icon">
|
||||
<Badge
|
||||
icon={badge.icon as unknown as LucideIcon}
|
||||
label={badge.label}
|
||||
isAvailable={false}
|
||||
isEditing={true}
|
||||
onBadgeAction={() => handleRestoreBadge(badge.id)}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default React.memo(EditBadgesComponent);
|
||||
|
|
@ -1,55 +1,52 @@
|
|||
import React, { useRef } from 'react';
|
||||
import { FileUpload, TooltipAnchor } from '~/components/ui';
|
||||
import { AttachmentIcon } from '~/components/svg';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { FileUpload, TooltipAnchor, AttachmentIcon } from '~/components';
|
||||
import { useLocalize, useFileHandling } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
const AttachFile = ({
|
||||
isRTL,
|
||||
disabled,
|
||||
handleFileChange,
|
||||
}: {
|
||||
isRTL: boolean;
|
||||
disabled?: boolean | null;
|
||||
handleFileChange: (event: React.ChangeEvent<HTMLInputElement>) => void;
|
||||
}) => {
|
||||
const AttachFile = ({ disabled }: { disabled?: boolean | null }) => {
|
||||
const localize = useLocalize();
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const isUploadDisabled = disabled ?? false;
|
||||
|
||||
const { handleFileChange } = useFileHandling();
|
||||
|
||||
return (
|
||||
<FileUpload ref={inputRef} handleFileChange={handleFileChange}>
|
||||
<TooltipAnchor
|
||||
role="button"
|
||||
id="attach-file"
|
||||
aria-label={localize('com_sidepanel_attach_files')}
|
||||
disabled={isUploadDisabled}
|
||||
className={cn(
|
||||
'absolute flex size-[35px] items-center justify-center rounded-full p-1 transition-colors hover:bg-surface-hover focus:outline-none focus:ring-2 focus:ring-primary focus:ring-opacity-50',
|
||||
isRTL ? 'bottom-2 right-2' : 'bottom-2 left-2',
|
||||
)}
|
||||
description={localize('com_sidepanel_attach_files')}
|
||||
onKeyDownCapture={(e) => {
|
||||
if (!inputRef.current) {
|
||||
return;
|
||||
}
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
inputRef.current.value = '';
|
||||
inputRef.current.click();
|
||||
}
|
||||
}}
|
||||
onClick={() => {
|
||||
if (!inputRef.current) {
|
||||
return;
|
||||
}
|
||||
inputRef.current.value = '';
|
||||
inputRef.current.click();
|
||||
}}
|
||||
>
|
||||
<div className="flex w-full items-center justify-center gap-2">
|
||||
<AttachmentIcon />
|
||||
</div>
|
||||
</TooltipAnchor>
|
||||
id="attach-file"
|
||||
disabled={isUploadDisabled}
|
||||
render={
|
||||
<button
|
||||
type="button"
|
||||
aria-label={localize('com_sidepanel_attach_files')}
|
||||
disabled={isUploadDisabled}
|
||||
className={cn(
|
||||
'flex size-9 items-center justify-center rounded-full p-1 transition-colors hover:bg-surface-hover focus:outline-none focus:ring-2 focus:ring-primary focus:ring-opacity-50',
|
||||
)}
|
||||
onKeyDownCapture={(e) => {
|
||||
if (!inputRef.current) {
|
||||
return;
|
||||
}
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
inputRef.current.value = '';
|
||||
inputRef.current.click();
|
||||
}
|
||||
}}
|
||||
onClick={() => {
|
||||
if (!inputRef.current) {
|
||||
return;
|
||||
}
|
||||
inputRef.current.value = '';
|
||||
inputRef.current.click();
|
||||
}}
|
||||
>
|
||||
<div className="flex w-full items-center justify-center gap-2">
|
||||
<AttachmentIcon />
|
||||
</div>
|
||||
</button>
|
||||
}
|
||||
/>
|
||||
</FileUpload>
|
||||
);
|
||||
};
|
||||
|
|
|
|||
44
client/src/components/Chat/Input/Files/AttachFileChat.tsx
Normal file
44
client/src/components/Chat/Input/Files/AttachFileChat.tsx
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
import { memo, useMemo } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import {
|
||||
supportsFiles,
|
||||
mergeFileConfig,
|
||||
isAgentsEndpoint,
|
||||
EndpointFileConfig,
|
||||
fileConfig as defaultFileConfig,
|
||||
} from 'librechat-data-provider';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import { useGetFileConfig } from '~/data-provider';
|
||||
import AttachFileMenu from './AttachFileMenu';
|
||||
import AttachFile from './AttachFile';
|
||||
import store from '~/store';
|
||||
|
||||
function AttachFileChat({ disableInputs }: { disableInputs: boolean }) {
|
||||
const { conversation } = useChatContext();
|
||||
|
||||
const { endpoint: _endpoint, endpointType } = conversation ?? { endpoint: null };
|
||||
|
||||
const isAgents = useMemo(() => isAgentsEndpoint(_endpoint), [_endpoint]);
|
||||
|
||||
const { data: fileConfig = defaultFileConfig } = useGetFileConfig({
|
||||
select: (data) => mergeFileConfig(data),
|
||||
});
|
||||
|
||||
const endpointFileConfig = fileConfig.endpoints[_endpoint ?? ''] as
|
||||
| EndpointFileConfig
|
||||
| undefined;
|
||||
|
||||
const endpointSupportsFiles: boolean = supportsFiles[endpointType ?? _endpoint ?? ''] ?? false;
|
||||
const isUploadDisabled = (disableInputs || endpointFileConfig?.disabled) ?? false;
|
||||
|
||||
if (isAgents) {
|
||||
return <AttachFileMenu disabled={disableInputs} />;
|
||||
}
|
||||
if (endpointSupportsFiles && !isUploadDisabled) {
|
||||
return <AttachFile disabled={disableInputs} />;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
export default memo(AttachFileChat);
|
||||
|
|
@ -2,25 +2,23 @@ import * as Ariakit from '@ariakit/react';
|
|||
import React, { useRef, useState, useMemo } from 'react';
|
||||
import { EToolResources, EModelEndpoint } from 'librechat-data-provider';
|
||||
import { FileSearch, ImageUpIcon, TerminalSquareIcon, FileType2Icon } from 'lucide-react';
|
||||
import { FileUpload, TooltipAnchor, DropdownPopup } from '~/components/ui';
|
||||
import { FileUpload, TooltipAnchor, DropdownPopup, AttachmentIcon } from '~/components';
|
||||
import { useGetEndpointsQuery } from '~/data-provider';
|
||||
import { AttachmentIcon } from '~/components/svg';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { useLocalize, useFileHandling } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
interface AttachFileProps {
|
||||
isRTL: boolean;
|
||||
disabled?: boolean | null;
|
||||
handleFileChange: (event: React.ChangeEvent<HTMLInputElement>, toolResource?: string) => void;
|
||||
}
|
||||
|
||||
const AttachFile = ({ isRTL, disabled, handleFileChange }: AttachFileProps) => {
|
||||
const AttachFile = ({ disabled }: AttachFileProps) => {
|
||||
const localize = useLocalize();
|
||||
const isUploadDisabled = disabled ?? false;
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const [isPopoverActive, setIsPopoverActive] = useState(false);
|
||||
const [toolResource, setToolResource] = useState<EToolResources | undefined>();
|
||||
const { data: endpointsConfig } = useGetEndpointsQuery();
|
||||
const { handleFileChange } = useFileHandling();
|
||||
|
||||
const capabilities = useMemo(
|
||||
() => endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [],
|
||||
|
|
@ -93,8 +91,7 @@ const AttachFile = ({ isRTL, disabled, handleFileChange }: AttachFileProps) => {
|
|||
id="attach-file-menu-button"
|
||||
aria-label="Attach File Options"
|
||||
className={cn(
|
||||
'absolute flex size-[35px] items-center justify-center rounded-full p-1 transition-colors hover:bg-surface-hover focus:outline-none focus:ring-2 focus:ring-primary focus:ring-opacity-50',
|
||||
isRTL ? 'bottom-2 right-2' : 'bottom-2 left-1 md:left-2',
|
||||
'flex size-9 items-center justify-center rounded-full p-1 transition-colors hover:bg-surface-hover focus:outline-none focus:ring-2 focus:ring-primary focus:ring-opacity-50',
|
||||
)}
|
||||
>
|
||||
<div className="flex w-full items-center justify-center gap-2">
|
||||
|
|
@ -115,17 +112,15 @@ const AttachFile = ({ isRTL, disabled, handleFileChange }: AttachFileProps) => {
|
|||
handleFileChange(e, toolResource);
|
||||
}}
|
||||
>
|
||||
<div className="relative select-none">
|
||||
<DropdownPopup
|
||||
menuId="attach-file-menu"
|
||||
isOpen={isPopoverActive}
|
||||
setIsOpen={setIsPopoverActive}
|
||||
modal={true}
|
||||
trigger={menuTrigger}
|
||||
items={dropdownItems}
|
||||
iconClassName="mr-0"
|
||||
/>
|
||||
</div>
|
||||
<DropdownPopup
|
||||
menuId="attach-file-menu"
|
||||
isOpen={isPopoverActive}
|
||||
setIsOpen={setIsPopoverActive}
|
||||
modal={true}
|
||||
trigger={menuTrigger}
|
||||
items={dropdownItems}
|
||||
iconClassName="mr-0"
|
||||
/>
|
||||
</FileUpload>
|
||||
);
|
||||
};
|
||||
|
|
|
|||
30
client/src/components/Chat/Input/Files/FileFormChat.tsx
Normal file
30
client/src/components/Chat/Input/Files/FileFormChat.tsx
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
import { memo } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import { useFileHandling } from '~/hooks';
|
||||
import FileRow from './FileRow';
|
||||
import store from '~/store';
|
||||
|
||||
function FileFormChat({ disableInputs }: { disableInputs: boolean }) {
|
||||
const chatDirection = useRecoilValue(store.chatDirection).toLowerCase();
|
||||
const { files, setFiles, conversation, setFilesLoading } = useChatContext();
|
||||
const { endpoint: _endpoint } = conversation ?? { endpoint: null };
|
||||
const { abortUpload } = useFileHandling();
|
||||
|
||||
const isRTL = chatDirection === 'rtl';
|
||||
|
||||
return (
|
||||
<>
|
||||
<FileRow
|
||||
files={files}
|
||||
setFiles={setFiles}
|
||||
abortUpload={abortUpload}
|
||||
setFilesLoading={setFilesLoading}
|
||||
isRTL={isRTL}
|
||||
Wrapper={({ children }) => <div className="mx-2 mt-2 flex flex-wrap gap-2">{children}</div>}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default memo(FileFormChat);
|
||||
|
|
@ -1,80 +0,0 @@
|
|||
import { memo, useMemo } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import {
|
||||
supportsFiles,
|
||||
mergeFileConfig,
|
||||
isAgentsEndpoint,
|
||||
EndpointFileConfig,
|
||||
fileConfig as defaultFileConfig,
|
||||
} from 'librechat-data-provider';
|
||||
import { useGetFileConfig } from '~/data-provider';
|
||||
import AttachFileMenu from './AttachFileMenu';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import { useFileHandling } from '~/hooks';
|
||||
import AttachFile from './AttachFile';
|
||||
import FileRow from './FileRow';
|
||||
import store from '~/store';
|
||||
|
||||
function FileFormWrapper({
|
||||
children,
|
||||
disableInputs,
|
||||
}: {
|
||||
disableInputs: boolean;
|
||||
children?: React.ReactNode;
|
||||
}) {
|
||||
const chatDirection = useRecoilValue(store.chatDirection).toLowerCase();
|
||||
const { files, setFiles, conversation, setFilesLoading } = useChatContext();
|
||||
const { endpoint: _endpoint, endpointType } = conversation ?? { endpoint: null };
|
||||
const isAgents = useMemo(() => isAgentsEndpoint(_endpoint), [_endpoint]);
|
||||
|
||||
const { handleFileChange, abortUpload } = useFileHandling();
|
||||
|
||||
const { data: fileConfig = defaultFileConfig } = useGetFileConfig({
|
||||
select: (data) => mergeFileConfig(data),
|
||||
});
|
||||
|
||||
const isRTL = chatDirection === 'rtl';
|
||||
|
||||
const endpointFileConfig = fileConfig.endpoints[_endpoint ?? ''] as
|
||||
| EndpointFileConfig
|
||||
| undefined;
|
||||
|
||||
const endpointSupportsFiles: boolean = supportsFiles[endpointType ?? _endpoint ?? ''] ?? false;
|
||||
const isUploadDisabled = (disableInputs || endpointFileConfig?.disabled) ?? false;
|
||||
|
||||
const renderAttachFile = () => {
|
||||
if (isAgents) {
|
||||
return (
|
||||
<AttachFileMenu
|
||||
isRTL={isRTL}
|
||||
disabled={disableInputs}
|
||||
handleFileChange={handleFileChange}
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (endpointSupportsFiles && !isUploadDisabled) {
|
||||
return (
|
||||
<AttachFile isRTL={isRTL} disabled={disableInputs} handleFileChange={handleFileChange} />
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<FileRow
|
||||
files={files}
|
||||
setFiles={setFiles}
|
||||
abortUpload={abortUpload}
|
||||
setFilesLoading={setFilesLoading}
|
||||
isRTL={isRTL}
|
||||
Wrapper={({ children }) => <div className="mx-2 mt-2 flex flex-wrap gap-2">{children}</div>}
|
||||
/>
|
||||
{children}
|
||||
{renderAttachFile()}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default memo(FileFormWrapper);
|
||||
|
|
@ -1,11 +1,15 @@
|
|||
import { useLocalize } from '~/hooks';
|
||||
|
||||
export default function RemoveFile({ onRemove }: { onRemove: () => void }) {
|
||||
const localize = useLocalize();
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
className="absolute right-1 top-1 -translate-y-1/2 translate-x-1/2 rounded-full bg-surface-secondary p-0.5 transition-colors duration-200 hover:bg-surface-primary z-50"
|
||||
className="absolute right-1 top-1 -translate-y-1/2 translate-x-1/2 rounded-full bg-surface-secondary p-0.5 transition-colors duration-200 hover:bg-surface-primary"
|
||||
onClick={onRemove}
|
||||
aria-label={localize('com_ui_attach_remove')}
|
||||
>
|
||||
<span>
|
||||
<span aria-hidden="true">
|
||||
<svg
|
||||
stroke="currentColor"
|
||||
fill="none"
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ const sourceToEndpoint = {
|
|||
|
||||
const sourceToClassname = {
|
||||
[FileSources.openai]: 'bg-white/75 dark:bg-black/65',
|
||||
[FileSources.azure]: 'azure-bg-color opacity-85',
|
||||
[FileSources.azure]: 'azure-bg-color',
|
||||
[FileSources.azure_blob]: 'azure-bg-color',
|
||||
[FileSources.execute_code]: 'bg-black text-white opacity-85',
|
||||
[FileSources.text]: 'bg-blue-500 dark:bg-blue-900 opacity-85 text-white',
|
||||
[FileSources.vectordb]: 'bg-yellow-700 dark:bg-yellow-900 opacity-85 text-white',
|
||||
|
|
|
|||
|
|
@ -2,17 +2,12 @@ import { useRecoilState } from 'recoil';
|
|||
import { Settings2 } from 'lucide-react';
|
||||
import { useState, useEffect, useMemo } from 'react';
|
||||
import { Root, Anchor } from '@radix-ui/react-popover';
|
||||
import {
|
||||
EModelEndpoint,
|
||||
isParamEndpoint,
|
||||
isAgentsEndpoint,
|
||||
tConvoUpdateSchema,
|
||||
} from 'librechat-data-provider';
|
||||
import { EModelEndpoint, isParamEndpoint, tConvoUpdateSchema } from 'librechat-data-provider';
|
||||
import { useUserKeyQuery } from 'librechat-data-provider/react-query';
|
||||
import type { TPreset, TInterfaceConfig } from 'librechat-data-provider';
|
||||
import { EndpointSettings, SaveAsPresetDialog, AlternativeSettings } from '~/components/Endpoints';
|
||||
import { useSetIndexOptions, useMediaQuery, useLocalize } from '~/hooks';
|
||||
import { PluginStoreDialog, TooltipAnchor } from '~/components';
|
||||
import { ModelSelect } from '~/components/Input/ModelSelect';
|
||||
import { useSetIndexOptions, useLocalize } from '~/hooks';
|
||||
import { useGetEndpointsQuery } from '~/data-provider';
|
||||
import OptionsPopover from './OptionsPopover';
|
||||
import PopoverButtons from './PopoverButtons';
|
||||
|
|
@ -26,6 +21,7 @@ export default function HeaderOptions({
|
|||
interfaceConfig?: Partial<TInterfaceConfig>;
|
||||
}) {
|
||||
const { data: endpointsConfig } = useGetEndpointsQuery();
|
||||
|
||||
const [saveAsDialogShow, setSaveAsDialogShow] = useState<boolean>(false);
|
||||
const [showPluginStoreDialog, setShowPluginStoreDialog] = useRecoilState(
|
||||
store.showPluginStoreDialog,
|
||||
|
|
@ -35,6 +31,15 @@ export default function HeaderOptions({
|
|||
const { showPopover, conversation, setShowPopover } = useChatContext();
|
||||
const { setOption } = useSetIndexOptions();
|
||||
const { endpoint, conversationId } = conversation ?? {};
|
||||
const { data: keyExpiry = { expiresAt: undefined } } = useUserKeyQuery(endpoint ?? '');
|
||||
const userProvidesKey = useMemo(
|
||||
() => !!(endpointsConfig?.[endpoint ?? '']?.userProvide ?? false),
|
||||
[endpointsConfig, endpoint],
|
||||
);
|
||||
const keyProvided = useMemo(
|
||||
() => (userProvidesKey ? !!(keyExpiry.expiresAt ?? '') : true),
|
||||
[keyExpiry.expiresAt, userProvidesKey],
|
||||
);
|
||||
|
||||
const noSettings = useMemo<{ [key: string]: boolean }>(
|
||||
() => ({
|
||||
|
|
@ -71,14 +76,6 @@ export default function HeaderOptions({
|
|||
<div className="my-auto lg:max-w-2xl xl:max-w-3xl">
|
||||
<span className="flex w-full flex-col items-center justify-center gap-0 md:order-none md:m-auto md:gap-2">
|
||||
<div className="z-[61] flex w-full items-center justify-center gap-2">
|
||||
{interfaceConfig?.modelSelect === true && !isAgentsEndpoint(endpoint) && (
|
||||
<ModelSelect
|
||||
conversation={conversation}
|
||||
setOption={setOption}
|
||||
showAbove={false}
|
||||
popover={true}
|
||||
/>
|
||||
)}
|
||||
{!noSettings[endpoint] &&
|
||||
interfaceConfig?.parameters === true &&
|
||||
paramEndpoint === false && (
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ export default function Mention({
|
|||
includeAssistants?: boolean;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const assistantMap = useAssistantsMapContext();
|
||||
const assistantsMap = useAssistantsMapContext();
|
||||
const {
|
||||
options,
|
||||
presets,
|
||||
|
|
@ -37,11 +37,11 @@ export default function Mention({
|
|||
modelsConfig,
|
||||
endpointsConfig,
|
||||
assistantListMap,
|
||||
} = useMentions({ assistantMap: assistantMap || {}, includeAssistants });
|
||||
} = useMentions({ assistantMap: assistantsMap || {}, includeAssistants });
|
||||
const { onSelectMention } = useSelectMention({
|
||||
presets,
|
||||
modelSpecs,
|
||||
assistantMap,
|
||||
assistantsMap,
|
||||
endpointsConfig,
|
||||
newConversation,
|
||||
});
|
||||
|
|
@ -65,7 +65,7 @@ export default function Mention({
|
|||
setSearchValue('');
|
||||
setOpen(false);
|
||||
setShowMentionPopover(false);
|
||||
onSelectMention(mention);
|
||||
onSelectMention?.(mention);
|
||||
|
||||
if (textAreaRef.current) {
|
||||
removeCharIfLast(textAreaRef.current, commandChar);
|
||||
|
|
@ -158,11 +158,11 @@ export default function Mention({
|
|||
};
|
||||
|
||||
return (
|
||||
<div className="absolute bottom-14 z-10 w-full space-y-2">
|
||||
<div className="absolute bottom-28 z-10 w-full space-y-2">
|
||||
<div className="popover border-token-border-light rounded-2xl border bg-white p-2 shadow-lg dark:bg-gray-700">
|
||||
<input
|
||||
// The user expects focus to transition to the input field when the popover is opened
|
||||
|
||||
// eslint-disable-next-line jsx-a11y/no-autofocus
|
||||
autoFocus
|
||||
ref={inputRef}
|
||||
placeholder={localize(placeholder)}
|
||||
|
|
|
|||
|
|
@ -69,7 +69,9 @@ function PromptsCommand({
|
|||
label: `${group.command != null && group.command ? `/${group.command} - ` : ''}${
|
||||
group.name
|
||||
}: ${
|
||||
(group.oneliner?.length ?? 0) > 0 ? group.oneliner : group.productionPrompt?.prompt ?? ''
|
||||
(group.oneliner?.length ?? 0) > 0
|
||||
? group.oneliner
|
||||
: (group.productionPrompt?.prompt ?? '')
|
||||
}`,
|
||||
icon: <CategoryIcon category={group.category ?? ''} className="h-5 w-5" />,
|
||||
}));
|
||||
|
|
@ -195,11 +197,11 @@ function PromptsCommand({
|
|||
variableGroup={variableGroup}
|
||||
setVariableDialogOpen={setVariableDialogOpen}
|
||||
>
|
||||
<div className="absolute bottom-14 z-10 w-full space-y-2">
|
||||
<div className="absolute bottom-28 z-10 w-full space-y-2">
|
||||
<div className="popover border-token-border-light rounded-2xl border bg-surface-tertiary-alt p-2 shadow-lg">
|
||||
<input
|
||||
// The user expects focus to transition to the input field when the popover is opened
|
||||
// eslint-disable-next-line jsx-a11y/no-autofocus
|
||||
|
||||
autoFocus
|
||||
ref={inputRef}
|
||||
placeholder={localize('com_ui_command_usage_placeholder')}
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ const SubmitButton = React.memo(
|
|||
id="send-button"
|
||||
disabled={props.disabled}
|
||||
className={cn(
|
||||
'rounded-full bg-text-primary p-2 text-text-primary outline-offset-4 transition-all duration-200 disabled:cursor-not-allowed disabled:text-text-secondary disabled:opacity-10',
|
||||
'rounded-full bg-text-primary p-1.5 text-text-primary outline-offset-4 transition-all duration-200 disabled:cursor-not-allowed disabled:text-text-secondary disabled:opacity-10',
|
||||
)}
|
||||
data-testid="send-button"
|
||||
type="submit"
|
||||
|
|
@ -34,7 +34,7 @@ const SubmitButton = React.memo(
|
|||
</span>
|
||||
</button>
|
||||
}
|
||||
></TooltipAnchor>
|
||||
/>
|
||||
);
|
||||
}),
|
||||
);
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ export default function StopButton({ stop, setShowStopButton }) {
|
|||
<button
|
||||
type="button"
|
||||
className={cn(
|
||||
'rounded-full bg-text-primary p-2 text-text-primary outline-offset-4 transition-all duration-200 disabled:cursor-not-allowed disabled:text-text-secondary disabled:opacity-10',
|
||||
'rounded-full bg-text-primary p-1.5 text-text-primary outline-offset-4 transition-all duration-200 disabled:cursor-not-allowed disabled:text-text-secondary disabled:opacity-10',
|
||||
)}
|
||||
aria-label={localize('com_nav_stop_generating')}
|
||||
onClick={(e) => {
|
||||
|
|
|
|||
|
|
@ -1,38 +0,0 @@
|
|||
import { MessageCircleDashed, X } from 'lucide-react';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
interface TemporaryChatProps {
|
||||
isTemporaryChat: boolean;
|
||||
setIsTemporaryChat: (value: boolean) => void;
|
||||
}
|
||||
|
||||
export const TemporaryChat = ({ isTemporaryChat, setIsTemporaryChat }: TemporaryChatProps) => {
|
||||
const localize = useLocalize();
|
||||
|
||||
if (!isTemporaryChat) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="divide-token-border-light m-1.5 flex flex-col divide-y overflow-hidden rounded-b-lg rounded-t-2xl bg-surface-secondary-alt">
|
||||
<div className="flex items-start gap-4 py-2.5 pl-3 pr-1.5 text-sm">
|
||||
<span className="mt-0 flex h-6 w-6 flex-shrink-0 items-center justify-center">
|
||||
<div className="icon-md">
|
||||
<MessageCircleDashed className="icon-md" aria-hidden="true" />
|
||||
</div>
|
||||
</span>
|
||||
<span className="text-token-text-secondary line-clamp-3 flex-1 py-0.5 font-semibold">
|
||||
{localize('com_ui_temporary_chat')}
|
||||
</span>
|
||||
<button
|
||||
className="text-token-text-secondary flex-shrink-0"
|
||||
type="button"
|
||||
aria-label="Close temporary chat"
|
||||
onClick={() => setIsTemporaryChat(false)}
|
||||
>
|
||||
<X className="pr-1" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
@ -13,7 +13,7 @@ export default function TextareaHeader({
|
|||
return null;
|
||||
}
|
||||
return (
|
||||
<div className="divide-token-border-light m-1.5 flex flex-col divide-y overflow-hidden rounded-b-lg rounded-t-2xl bg-surface-secondary-alt">
|
||||
<div className="m-1.5 flex flex-col divide-y overflow-hidden rounded-b-lg rounded-t-2xl bg-surface-secondary-alt">
|
||||
<AddedConvo addedConvo={addedConvo} setAddedConvo={setAddedConvo} />
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -1,47 +1,50 @@
|
|||
import { useMemo } from 'react';
|
||||
import { EModelEndpoint, Constants } from 'librechat-data-provider';
|
||||
import type * as t from 'librechat-data-provider';
|
||||
import type { ReactNode } from 'react';
|
||||
import { useMemo, useCallback, useState, useEffect, useRef } from 'react';
|
||||
import { easings } from '@react-spring/web';
|
||||
import { EModelEndpoint } from 'librechat-data-provider';
|
||||
import { useChatContext, useAgentsMapContext, useAssistantsMapContext } from '~/Providers';
|
||||
import {
|
||||
useGetAssistantDocsQuery,
|
||||
useGetEndpointsQuery,
|
||||
useGetStartupConfig,
|
||||
} from '~/data-provider';
|
||||
import { useGetEndpointsQuery, useGetStartupConfig } from '~/data-provider';
|
||||
import { BirthdayIcon, TooltipAnchor, SplitText } from '~/components';
|
||||
import ConvoIcon from '~/components/Endpoints/ConvoIcon';
|
||||
import { getIconEndpoint, getEntity, cn } from '~/utils';
|
||||
import { useLocalize, useSubmitMessage } from '~/hooks';
|
||||
import { TooltipAnchor } from '~/components/ui';
|
||||
import { BirthdayIcon } from '~/components/svg';
|
||||
import ConvoStarter from './ConvoStarter';
|
||||
import { useLocalize, useAuthContext } from '~/hooks';
|
||||
import { getIconEndpoint, getEntity } from '~/utils';
|
||||
|
||||
export default function Landing({ Header }: { Header?: ReactNode }) {
|
||||
const containerClassName =
|
||||
'shadow-stroke relative flex h-full items-center justify-center rounded-full bg-white text-black';
|
||||
|
||||
export default function Landing({ centerFormOnLanding }: { centerFormOnLanding: boolean }) {
|
||||
const { conversation } = useChatContext();
|
||||
const agentsMap = useAgentsMapContext();
|
||||
const assistantMap = useAssistantsMapContext();
|
||||
const { data: startupConfig } = useGetStartupConfig();
|
||||
const { data: endpointsConfig } = useGetEndpointsQuery();
|
||||
|
||||
const { user } = useAuthContext();
|
||||
const localize = useLocalize();
|
||||
|
||||
let { endpoint = '' } = conversation ?? {};
|
||||
const [textHasMultipleLines, setTextHasMultipleLines] = useState(false);
|
||||
const [lineCount, setLineCount] = useState(1);
|
||||
const [contentHeight, setContentHeight] = useState(0);
|
||||
const contentRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
if (
|
||||
endpoint === EModelEndpoint.chatGPTBrowser ||
|
||||
endpoint === EModelEndpoint.azureOpenAI ||
|
||||
endpoint === EModelEndpoint.gptPlugins
|
||||
) {
|
||||
endpoint = EModelEndpoint.openAI;
|
||||
}
|
||||
|
||||
const iconURL = conversation?.iconURL;
|
||||
endpoint = getIconEndpoint({ endpointsConfig, iconURL, endpoint });
|
||||
const { data: documentsMap = new Map() } = useGetAssistantDocsQuery(endpoint, {
|
||||
select: (data) => new Map(data.map((dbA) => [dbA.assistant_id, dbA])),
|
||||
});
|
||||
const endpointType = useMemo(() => {
|
||||
let ep = conversation?.endpoint ?? '';
|
||||
if (
|
||||
[
|
||||
EModelEndpoint.chatGPTBrowser,
|
||||
EModelEndpoint.azureOpenAI,
|
||||
EModelEndpoint.gptPlugins,
|
||||
].includes(ep as EModelEndpoint)
|
||||
) {
|
||||
ep = EModelEndpoint.openAI;
|
||||
}
|
||||
return getIconEndpoint({
|
||||
endpointsConfig,
|
||||
iconURL: conversation?.iconURL,
|
||||
endpoint: ep,
|
||||
});
|
||||
}, [conversation?.endpoint, conversation?.iconURL, endpointsConfig]);
|
||||
|
||||
const { entity, isAgent, isAssistant } = getEntity({
|
||||
endpoint,
|
||||
endpoint: endpointType,
|
||||
agentsMap,
|
||||
assistantMap,
|
||||
agent_id: conversation?.agent_id,
|
||||
|
|
@ -50,102 +53,144 @@ export default function Landing({ Header }: { Header?: ReactNode }) {
|
|||
|
||||
const name = entity?.name ?? '';
|
||||
const description = entity?.description ?? '';
|
||||
const avatar = isAgent
|
||||
? (entity as t.Agent | undefined)?.avatar?.filepath ?? ''
|
||||
: ((entity as t.Assistant | undefined)?.metadata?.avatar as string | undefined) ?? '';
|
||||
const conversation_starters = useMemo(() => {
|
||||
/* The user made updates, use client-side cache, or they exist in an Agent */
|
||||
if (entity && (entity.conversation_starters?.length ?? 0) > 0) {
|
||||
return entity.conversation_starters;
|
||||
}
|
||||
if (isAgent) {
|
||||
return entity?.conversation_starters ?? [];
|
||||
|
||||
const getGreeting = useCallback(() => {
|
||||
if (typeof startupConfig?.interface?.customWelcome === 'string') {
|
||||
const customWelcome = startupConfig.interface.customWelcome;
|
||||
// Replace {{user.name}} with actual user name if available
|
||||
if (user?.name && customWelcome.includes('{{user.name}}')) {
|
||||
return customWelcome.replace(/{{user.name}}/g, user.name);
|
||||
}
|
||||
return customWelcome;
|
||||
}
|
||||
|
||||
/* If none in cache, we use the latest assistant docs */
|
||||
const entityDocs = documentsMap.get(entity?.id ?? '');
|
||||
return entityDocs?.conversation_starters ?? [];
|
||||
}, [documentsMap, isAgent, entity]);
|
||||
const now = new Date();
|
||||
const hours = now.getHours();
|
||||
|
||||
const containerClassName =
|
||||
'shadow-stroke relative flex h-full items-center justify-center rounded-full bg-white text-black';
|
||||
const dayOfWeek = now.getDay();
|
||||
const isWeekend = dayOfWeek === 0 || dayOfWeek === 6;
|
||||
|
||||
const { submitMessage } = useSubmitMessage();
|
||||
const sendConversationStarter = (text: string) => submitMessage({ text });
|
||||
// Early morning (midnight to 4:59 AM)
|
||||
if (hours >= 0 && hours < 5) {
|
||||
return localize('com_ui_late_night');
|
||||
}
|
||||
// Morning (6 AM to 11:59 AM)
|
||||
else if (hours < 12) {
|
||||
if (isWeekend) {
|
||||
return localize('com_ui_weekend_morning');
|
||||
}
|
||||
return localize('com_ui_good_morning');
|
||||
}
|
||||
// Afternoon (12 PM to 4:59 PM)
|
||||
else if (hours < 17) {
|
||||
return localize('com_ui_good_afternoon');
|
||||
}
|
||||
// Evening (5 PM to 8:59 PM)
|
||||
else {
|
||||
return localize('com_ui_good_evening');
|
||||
}
|
||||
}, [localize, startupConfig?.interface?.customWelcome, user?.name]);
|
||||
|
||||
const getWelcomeMessage = () => {
|
||||
const greeting = conversation?.greeting ?? '';
|
||||
if (greeting) {
|
||||
return greeting;
|
||||
const handleLineCountChange = useCallback((count: number) => {
|
||||
setTextHasMultipleLines(count > 1);
|
||||
setLineCount(count);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (contentRef.current) {
|
||||
setContentHeight(contentRef.current.offsetHeight);
|
||||
}
|
||||
}, [lineCount, description]);
|
||||
|
||||
const getDynamicMargin = useMemo(() => {
|
||||
let margin = 'mb-0';
|
||||
|
||||
if (lineCount > 2 || (description && description.length > 100)) {
|
||||
margin = 'mb-10';
|
||||
} else if (lineCount > 1 || (description && description.length > 0)) {
|
||||
margin = 'mb-6';
|
||||
} else if (textHasMultipleLines) {
|
||||
margin = 'mb-4';
|
||||
}
|
||||
|
||||
if (isAssistant) {
|
||||
return localize('com_nav_welcome_assistant');
|
||||
if (contentHeight > 200) {
|
||||
margin = 'mb-16';
|
||||
} else if (contentHeight > 150) {
|
||||
margin = 'mb-12';
|
||||
}
|
||||
|
||||
if (isAgent) {
|
||||
return localize('com_nav_welcome_agent');
|
||||
}
|
||||
|
||||
return typeof startupConfig?.interface?.customWelcome === 'string'
|
||||
? startupConfig?.interface?.customWelcome
|
||||
: localize('com_nav_welcome_message');
|
||||
};
|
||||
return margin;
|
||||
}, [lineCount, description, textHasMultipleLines, contentHeight]);
|
||||
|
||||
return (
|
||||
<div className="relative h-full">
|
||||
<div className="absolute left-0 right-0">{Header != null ? Header : null}</div>
|
||||
<div className="flex h-full flex-col items-center justify-center">
|
||||
<div className={cn('relative h-12 w-12', name && avatar ? 'mb-0' : 'mb-3')}>
|
||||
<ConvoIcon
|
||||
agentsMap={agentsMap}
|
||||
assistantMap={assistantMap}
|
||||
conversation={conversation}
|
||||
endpointsConfig={endpointsConfig}
|
||||
containerClassName={containerClassName}
|
||||
context="landing"
|
||||
className="h-2/3 w-2/3"
|
||||
size={41}
|
||||
/>
|
||||
{startupConfig?.showBirthdayIcon === true ? (
|
||||
<TooltipAnchor
|
||||
className="absolute bottom-8 right-2.5"
|
||||
description={localize('com_ui_happy_birthday')}
|
||||
>
|
||||
<BirthdayIcon />
|
||||
</TooltipAnchor>
|
||||
) : null}
|
||||
</div>
|
||||
{name ? (
|
||||
<div className="flex flex-col items-center gap-0 p-2">
|
||||
<div className="text-center text-2xl font-medium dark:text-white">{name}</div>
|
||||
<div className="max-w-md text-center text-sm font-normal text-text-primary ">
|
||||
{description ||
|
||||
(typeof startupConfig?.interface?.customWelcome === 'string'
|
||||
? startupConfig?.interface?.customWelcome
|
||||
: localize('com_nav_welcome_message'))}
|
||||
</div>
|
||||
{/* <div className="mt-1 flex items-center gap-1 text-token-text-tertiary">
|
||||
<div className="text-sm text-token-text-tertiary">By Daniel Avila</div>
|
||||
</div> */}
|
||||
<div
|
||||
className={`flex h-full transform-gpu flex-col items-center justify-center pb-16 transition-all duration-200 ${centerFormOnLanding ? 'max-h-full sm:max-h-0' : 'max-h-full'} ${getDynamicMargin}`}
|
||||
>
|
||||
<div ref={contentRef} className="flex flex-col items-center gap-0 p-2">
|
||||
<div
|
||||
className={`flex ${textHasMultipleLines ? 'flex-col' : 'flex-col md:flex-row'} items-center justify-center gap-4`}
|
||||
>
|
||||
<div className={`relative size-10 justify-center ${textHasMultipleLines ? 'mb-2' : ''}`}>
|
||||
<ConvoIcon
|
||||
agentsMap={agentsMap}
|
||||
assistantMap={assistantMap}
|
||||
conversation={conversation}
|
||||
endpointsConfig={endpointsConfig}
|
||||
containerClassName={containerClassName}
|
||||
context="landing"
|
||||
className="h-2/3 w-2/3"
|
||||
size={41}
|
||||
/>
|
||||
{startupConfig?.showBirthdayIcon && (
|
||||
<TooltipAnchor
|
||||
className="absolute bottom-[27px] right-2"
|
||||
description={localize('com_ui_happy_birthday')}
|
||||
>
|
||||
<BirthdayIcon />
|
||||
</TooltipAnchor>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<h2 className="mb-5 max-w-[75vh] px-12 text-center text-lg font-medium dark:text-white md:px-0 md:text-2xl">
|
||||
{getWelcomeMessage()}
|
||||
</h2>
|
||||
)}
|
||||
<div className="mt-8 flex flex-wrap justify-center gap-3 px-4">
|
||||
{conversation_starters.length > 0 &&
|
||||
conversation_starters
|
||||
.slice(0, Constants.MAX_CONVO_STARTERS)
|
||||
.map((text: string, index: number) => (
|
||||
<ConvoStarter
|
||||
key={index}
|
||||
text={text}
|
||||
onClick={() => sendConversationStarter(text)}
|
||||
/>
|
||||
))}
|
||||
{((isAgent || isAssistant) && name) || name ? (
|
||||
<div className="flex flex-col items-center gap-0 p-2">
|
||||
<SplitText
|
||||
key={`split-text-${name}`}
|
||||
text={name}
|
||||
className="text-4xl font-medium text-text-primary"
|
||||
delay={50}
|
||||
textAlign="center"
|
||||
animationFrom={{ opacity: 0, transform: 'translate3d(0,50px,0)' }}
|
||||
animationTo={{ opacity: 1, transform: 'translate3d(0,0,0)' }}
|
||||
easing={easings.easeOutCubic}
|
||||
threshold={0}
|
||||
rootMargin="0px"
|
||||
onLineCountChange={handleLineCountChange}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<SplitText
|
||||
key={`split-text-${getGreeting()}${user?.name ? '-user' : ''}`}
|
||||
text={
|
||||
typeof startupConfig?.interface?.customWelcome === 'string'
|
||||
? getGreeting()
|
||||
: getGreeting() + (user?.name ? ', ' + user.name : '')
|
||||
}
|
||||
className="text-2xl font-medium text-text-primary sm:text-4xl"
|
||||
delay={50}
|
||||
textAlign="center"
|
||||
animationFrom={{ opacity: 0, transform: 'translate3d(0,50px,0)' }}
|
||||
animationTo={{ opacity: 1, transform: 'translate3d(0,0,0)' }}
|
||||
easing={easings.easeOutCubic}
|
||||
threshold={0}
|
||||
rootMargin="0px"
|
||||
onLineCountChange={handleLineCountChange}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
{(isAgent || isAssistant) && description && (
|
||||
<div className="animate-fadeIn mt-2 max-w-md text-center text-sm font-normal text-text-primary">
|
||||
{description}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ const BookmarkMenu: FC = () => {
|
|||
id="bookmark-menu-button"
|
||||
aria-label={localize('com_ui_bookmarks_add')}
|
||||
className={cn(
|
||||
'mt-text-sm flex size-10 items-center justify-center gap-2 rounded-lg border border-border-light text-sm transition-colors duration-200 hover:bg-surface-hover',
|
||||
'mt-text-sm flex size-10 flex-shrink-0 items-center justify-center gap-2 rounded-lg border border-border-light text-sm transition-colors duration-200 hover:bg-surface-hover',
|
||||
isMenuOpen ? 'bg-surface-hover' : '',
|
||||
)}
|
||||
data-testid="bookmark-menu"
|
||||
|
|
|
|||
247
client/src/components/Chat/Menus/Endpoints/CustomMenu.tsx
Normal file
247
client/src/components/Chat/Menus/Endpoints/CustomMenu.tsx
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
import * as React from 'react';
|
||||
import * as Ariakit from '@ariakit/react';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
export interface CustomMenuProps extends Ariakit.MenuButtonProps<'div'> {
|
||||
label?: React.ReactNode;
|
||||
values?: Record<string, any>;
|
||||
onValuesChange?: (values: Record<string, any>) => void;
|
||||
searchValue?: string;
|
||||
onSearch?: (value: string) => void;
|
||||
combobox?: Ariakit.ComboboxProps['render'];
|
||||
trigger?: Ariakit.MenuButtonProps['render'];
|
||||
defaultOpen?: boolean;
|
||||
}
|
||||
|
||||
export const CustomMenu = React.forwardRef<HTMLDivElement, CustomMenuProps>(function CustomMenu(
|
||||
{
|
||||
label,
|
||||
children,
|
||||
values,
|
||||
onValuesChange,
|
||||
searchValue,
|
||||
onSearch,
|
||||
combobox,
|
||||
trigger,
|
||||
defaultOpen,
|
||||
...props
|
||||
},
|
||||
ref,
|
||||
) {
|
||||
const parent = Ariakit.useMenuContext();
|
||||
const searchable = searchValue != null || !!onSearch || !!combobox;
|
||||
|
||||
const menuStore = Ariakit.useMenuStore({
|
||||
showTimeout: 100,
|
||||
placement: parent ? 'right' : 'left',
|
||||
defaultOpen: defaultOpen,
|
||||
});
|
||||
|
||||
const element = (
|
||||
<Ariakit.MenuProvider store={menuStore} values={values} setValues={onValuesChange}>
|
||||
<Ariakit.MenuButton
|
||||
ref={ref}
|
||||
{...props}
|
||||
className={cn(
|
||||
!parent &&
|
||||
'flex h-10 w-full items-center justify-center gap-2 rounded-xl border border-border-light px-3 py-2 text-sm text-text-primary focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-white',
|
||||
menuStore.useState('open')
|
||||
? 'bg-surface-tertiary hover:bg-surface-tertiary'
|
||||
: 'bg-surface-secondary hover:bg-surface-tertiary',
|
||||
props.className,
|
||||
)}
|
||||
render={parent ? <CustomMenuItem render={trigger} /> : trigger}
|
||||
>
|
||||
<span className="flex-1">{label}</span>
|
||||
<Ariakit.MenuButtonArrow className="stroke-1 text-base opacity-75" />
|
||||
</Ariakit.MenuButton>
|
||||
<Ariakit.Menu
|
||||
open={menuStore.useState('open')}
|
||||
portal
|
||||
overlap
|
||||
unmountOnHide
|
||||
gutter={parent ? -4 : 4}
|
||||
className={cn(
|
||||
`${parent ? 'animate-popover-left ml-3' : 'animate-popover'} outline-none! z-50 flex max-h-[min(450px,var(--popover-available-height))] w-full`,
|
||||
'w-[var(--menu-width,auto)] min-w-[300px] flex-col overflow-auto rounded-xl border border-border-light',
|
||||
'bg-surface-secondary px-3 py-2 text-sm text-text-primary shadow-lg',
|
||||
'max-w-[calc(100vw-4rem)] sm:max-h-[calc(65vh)] sm:max-w-[400px]',
|
||||
searchable && 'p-0',
|
||||
)}
|
||||
>
|
||||
<SearchableContext.Provider value={searchable}>
|
||||
{searchable ? (
|
||||
<>
|
||||
<div className="sticky top-0 z-10 bg-inherit p-1">
|
||||
<Ariakit.Combobox
|
||||
autoSelect
|
||||
render={combobox}
|
||||
className={cn(
|
||||
'h-10 w-full rounded-lg border-none bg-transparent px-2 text-base',
|
||||
'sm:h-8 sm:text-sm',
|
||||
'focus:outline-none focus:ring-0 focus-visible:ring-2 focus-visible:ring-white',
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<Ariakit.ComboboxList className="p-0.5 pt-0">{children}</Ariakit.ComboboxList>
|
||||
</>
|
||||
) : (
|
||||
children
|
||||
)}
|
||||
</SearchableContext.Provider>
|
||||
</Ariakit.Menu>
|
||||
</Ariakit.MenuProvider>
|
||||
);
|
||||
|
||||
if (searchable) {
|
||||
return (
|
||||
<Ariakit.ComboboxProvider
|
||||
resetValueOnHide
|
||||
includesBaseElement={false}
|
||||
value={searchValue}
|
||||
setValue={onSearch}
|
||||
>
|
||||
{element}
|
||||
</Ariakit.ComboboxProvider>
|
||||
);
|
||||
}
|
||||
|
||||
return element;
|
||||
});
|
||||
|
||||
export const CustomMenuSeparator = React.forwardRef<HTMLHRElement, Ariakit.MenuSeparatorProps>(
|
||||
function CustomMenuSeparator(props, ref) {
|
||||
return (
|
||||
<Ariakit.MenuSeparator
|
||||
ref={ref}
|
||||
{...props}
|
||||
className={cn(
|
||||
'my-0.5 h-0 w-full border-t border-slate-200 dark:border-slate-700',
|
||||
props.className,
|
||||
)}
|
||||
/>
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
export interface CustomMenuGroupProps extends Ariakit.MenuGroupProps {
|
||||
label?: React.ReactNode;
|
||||
}
|
||||
|
||||
export const CustomMenuGroup = React.forwardRef<HTMLDivElement, CustomMenuGroupProps>(
|
||||
function CustomMenuGroup({ label, ...props }, ref) {
|
||||
return (
|
||||
<Ariakit.MenuGroup ref={ref} {...props} className={cn('', props.className)}>
|
||||
{label && (
|
||||
<Ariakit.MenuGroupLabel className="cursor-default p-2 text-sm font-medium opacity-60 sm:py-1 sm:text-xs">
|
||||
{label}
|
||||
</Ariakit.MenuGroupLabel>
|
||||
)}
|
||||
{props.children}
|
||||
</Ariakit.MenuGroup>
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
const SearchableContext = React.createContext(false);
|
||||
|
||||
export interface CustomMenuItemProps extends Omit<Ariakit.ComboboxItemProps, 'store'> {
|
||||
name?: string;
|
||||
}
|
||||
|
||||
export const CustomMenuItem = React.forwardRef<HTMLDivElement, CustomMenuItemProps>(
|
||||
function CustomMenuItem({ name, value, ...props }, ref) {
|
||||
const menu = Ariakit.useMenuContext();
|
||||
const searchable = React.useContext(SearchableContext);
|
||||
const defaultProps: CustomMenuItemProps = {
|
||||
ref,
|
||||
focusOnHover: true,
|
||||
blurOnHoverEnd: false,
|
||||
...props,
|
||||
className: cn(
|
||||
'flex cursor-default items-center gap-2 rounded-lg p-2 outline-none! scroll-m-1 scroll-mt-[calc(var(--combobox-height,0px)+var(--label-height,4px))] aria-disabled:opacity-25 data-[active-item]:bg-black/[0.075] data-[active-item]:text-black dark:data-[active-item]:bg-white/10 dark:data-[active-item]:text-white sm:py-1 sm:text-sm min-w-0 w-full',
|
||||
props.className,
|
||||
),
|
||||
};
|
||||
|
||||
const checkable = Ariakit.useStoreState(menu, (state) => {
|
||||
if (!name) {
|
||||
return false;
|
||||
}
|
||||
if (value == null) {
|
||||
return false;
|
||||
}
|
||||
return state?.values[name] != null;
|
||||
});
|
||||
|
||||
const checked = Ariakit.useStoreState(menu, (state) => {
|
||||
if (!name) {
|
||||
return false;
|
||||
}
|
||||
return state?.values[name] === value;
|
||||
});
|
||||
|
||||
// If the item is checkable, we render a checkmark icon next to the label.
|
||||
if (checkable) {
|
||||
defaultProps.children = (
|
||||
<React.Fragment>
|
||||
<span className="flex-1">{defaultProps.children}</span>
|
||||
<Ariakit.MenuItemCheck checked={checked} />
|
||||
{searchable && (
|
||||
// When an item is displayed in a search menu as a role=option
|
||||
// element instead of a role=menuitemradio, we can't depend on the
|
||||
// aria-checked attribute. Although NVDA and JAWS announce it
|
||||
// accurately, VoiceOver doesn't. TalkBack does announce the checked
|
||||
// state, but misleadingly implies that a double tap will change the
|
||||
// state, which isn't the case. Therefore, we use a visually hidden
|
||||
// element to indicate whether the item is checked or not, ensuring
|
||||
// cross-browser/AT compatibility.
|
||||
<Ariakit.VisuallyHidden>{checked ? 'checked' : 'not checked'}</Ariakit.VisuallyHidden>
|
||||
)}
|
||||
</React.Fragment>
|
||||
);
|
||||
}
|
||||
|
||||
// If the item is not rendered in a search menu (listbox), we can render it
|
||||
// as a MenuItem/MenuItemRadio.
|
||||
if (!searchable) {
|
||||
if (name != null && value != null) {
|
||||
const radioProps = { ...defaultProps, name, value, hideOnClick: true };
|
||||
return <Ariakit.MenuItemRadio {...radioProps} />;
|
||||
}
|
||||
return <Ariakit.MenuItem {...defaultProps} />;
|
||||
}
|
||||
|
||||
return (
|
||||
<Ariakit.ComboboxItem
|
||||
{...defaultProps}
|
||||
setValueOnClick={false}
|
||||
value={checkable ? value : undefined}
|
||||
selectValueOnClick={() => {
|
||||
if (name == null || value == null) {
|
||||
return false;
|
||||
}
|
||||
// By default, clicking on a ComboboxItem will update the
|
||||
// selectedValue state of the combobox. However, since we're sharing
|
||||
// state between combobox and menu, we also need to update the menu's
|
||||
// values state.
|
||||
menu?.setValue(name, value);
|
||||
return true;
|
||||
}}
|
||||
hideOnClick={(event) => {
|
||||
// Make sure that clicking on a combobox item that opens a nested
|
||||
// menu/dialog does not close the menu.
|
||||
const expandable = event.currentTarget.hasAttribute('aria-expanded');
|
||||
if (expandable) {
|
||||
return false;
|
||||
}
|
||||
// By default, clicking on a ComboboxItem only closes its own popover.
|
||||
// However, since we're in a menu context, we also close all parent
|
||||
// menus.
|
||||
menu?.hideAll();
|
||||
return true;
|
||||
}}
|
||||
/>
|
||||
);
|
||||
},
|
||||
);
|
||||
34
client/src/components/Chat/Menus/Endpoints/DialogManager.tsx
Normal file
34
client/src/components/Chat/Menus/Endpoints/DialogManager.tsx
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import React from 'react';
|
||||
import { EModelEndpoint } from 'librechat-data-provider';
|
||||
import { SetKeyDialog } from '~/components/Input/SetKeyDialog';
|
||||
import { getEndpointField } from '~/utils';
|
||||
|
||||
interface DialogManagerProps {
|
||||
keyDialogOpen: boolean;
|
||||
keyDialogEndpoint?: EModelEndpoint;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
endpointsConfig: Record<string, any>;
|
||||
}
|
||||
|
||||
const DialogManager = ({
|
||||
keyDialogOpen,
|
||||
keyDialogEndpoint,
|
||||
onOpenChange,
|
||||
endpointsConfig,
|
||||
}: DialogManagerProps) => {
|
||||
return (
|
||||
<>
|
||||
{keyDialogEndpoint && (
|
||||
<SetKeyDialog
|
||||
open={keyDialogOpen}
|
||||
endpoint={keyDialogEndpoint}
|
||||
endpointType={getEndpointField(endpointsConfig, keyDialogEndpoint, 'type')}
|
||||
onOpenChange={onOpenChange}
|
||||
userProvideURL={getEndpointField(endpointsConfig, keyDialogEndpoint, 'userProvideURL')}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default DialogManager;
|
||||
|
|
@ -1,221 +0,0 @@
|
|||
import { useState } from 'react';
|
||||
import { Settings } from 'lucide-react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { EModelEndpoint } from 'librechat-data-provider';
|
||||
import type { TConversation } from 'librechat-data-provider';
|
||||
import type { FC } from 'react';
|
||||
import { cn, getConvoSwitchLogic, getEndpointField, getIconKey } from '~/utils';
|
||||
import { useLocalize, useUserKey, useDefaultConvo } from '~/hooks';
|
||||
import { SetKeyDialog } from '~/components/Input/SetKeyDialog';
|
||||
import { useGetEndpointsQuery } from '~/data-provider';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import { icons } from './Icons';
|
||||
import store from '~/store';
|
||||
|
||||
type MenuItemProps = {
|
||||
title: string;
|
||||
value: EModelEndpoint;
|
||||
selected: boolean;
|
||||
description?: string;
|
||||
userProvidesKey: boolean;
|
||||
// iconPath: string;
|
||||
// hoverContent?: string;
|
||||
};
|
||||
|
||||
const MenuItem: FC<MenuItemProps> = ({
|
||||
title,
|
||||
value: endpoint,
|
||||
description,
|
||||
selected,
|
||||
userProvidesKey,
|
||||
...rest
|
||||
}) => {
|
||||
const modularChat = useRecoilValue(store.modularChat);
|
||||
const [isDialogOpen, setDialogOpen] = useState(false);
|
||||
const { data: endpointsConfig } = useGetEndpointsQuery();
|
||||
const { conversation, newConversation } = useChatContext();
|
||||
const getDefaultConversation = useDefaultConvo();
|
||||
|
||||
const { getExpiry } = useUserKey(endpoint);
|
||||
const localize = useLocalize();
|
||||
const expiryTime = getExpiry() ?? '';
|
||||
|
||||
const onSelectEndpoint = (newEndpoint?: EModelEndpoint) => {
|
||||
if (!newEndpoint) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!expiryTime) {
|
||||
setDialogOpen(true);
|
||||
}
|
||||
|
||||
const {
|
||||
template,
|
||||
shouldSwitch,
|
||||
isNewModular,
|
||||
newEndpointType,
|
||||
isCurrentModular,
|
||||
isExistingConversation,
|
||||
} = getConvoSwitchLogic({
|
||||
newEndpoint,
|
||||
modularChat,
|
||||
conversation,
|
||||
endpointsConfig,
|
||||
});
|
||||
|
||||
const isModular = isCurrentModular && isNewModular && shouldSwitch;
|
||||
if (isExistingConversation && isModular) {
|
||||
template.endpointType = newEndpointType;
|
||||
|
||||
const currentConvo = getDefaultConversation({
|
||||
/* target endpointType is necessary to avoid endpoint mixing */
|
||||
conversation: { ...(conversation ?? {}), endpointType: template.endpointType },
|
||||
preset: template,
|
||||
});
|
||||
|
||||
/* We don't reset the latest message, only when changing settings mid-converstion */
|
||||
newConversation({
|
||||
template: currentConvo,
|
||||
preset: currentConvo,
|
||||
keepLatestMessage: true,
|
||||
keepAddedConvos: true,
|
||||
});
|
||||
return;
|
||||
}
|
||||
newConversation({
|
||||
template: { ...(template as Partial<TConversation>) },
|
||||
keepAddedConvos: isModular,
|
||||
});
|
||||
};
|
||||
|
||||
const endpointType = getEndpointField(endpointsConfig, endpoint, 'type');
|
||||
const iconKey = getIconKey({ endpoint, endpointsConfig, endpointType });
|
||||
const Icon = icons[iconKey];
|
||||
|
||||
return (
|
||||
<>
|
||||
<div
|
||||
role="option"
|
||||
aria-selected={selected}
|
||||
className={cn(
|
||||
'group m-1.5 flex max-h-[40px] cursor-pointer gap-2 rounded px-5 py-2.5 !pr-3 text-sm !opacity-100 hover:bg-surface-hover',
|
||||
'radix-disabled:pointer-events-none radix-disabled:opacity-50',
|
||||
)}
|
||||
tabIndex={0}
|
||||
{...rest}
|
||||
onClick={() => onSelectEndpoint(endpoint)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
e.preventDefault();
|
||||
onSelectEndpoint(endpoint);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="flex grow items-center justify-between gap-2">
|
||||
<div>
|
||||
<div className="flex items-center gap-2">
|
||||
{Icon != null && (
|
||||
<Icon
|
||||
size={18}
|
||||
endpoint={endpoint}
|
||||
context={'menu-item'}
|
||||
className="icon-md shrink-0 dark:text-white"
|
||||
iconURL={getEndpointField(endpointsConfig, endpoint, 'iconURL')}
|
||||
/>
|
||||
)}
|
||||
<div>
|
||||
{title}
|
||||
<div className="text-token-text-tertiary">{description}</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
{userProvidesKey ? (
|
||||
<div className="text-token-text-primary" key={`set-key-${endpoint}`}>
|
||||
<button
|
||||
tabIndex={0}
|
||||
aria-label={`${localize('com_endpoint_config_key')} for ${title}`}
|
||||
className={cn(
|
||||
'invisible flex gap-x-1 group-focus-within:visible group-hover:visible',
|
||||
selected ? 'visible' : '',
|
||||
expiryTime ? 'text-token-text-primary w-full rounded-lg p-2' : '',
|
||||
)}
|
||||
onClick={(e) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setDialogOpen(true);
|
||||
}}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setDialogOpen(true);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
'invisible group-focus-within:visible group-hover:visible',
|
||||
expiryTime ? 'text-xs' : '',
|
||||
)}
|
||||
>
|
||||
{localize('com_endpoint_config_key')}
|
||||
</div>
|
||||
<Settings className={cn(expiryTime ? 'icon-sm' : 'icon-md stroke-1')} />
|
||||
</button>
|
||||
</div>
|
||||
) : null}
|
||||
{selected && (
|
||||
<svg
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="icon-md block group-hover:hidden"
|
||||
>
|
||||
<path
|
||||
fillRule="evenodd"
|
||||
clipRule="evenodd"
|
||||
d="M2 12C2 6.47715 6.47715 2 12 2C17.5228 2 22 6.47715 22 12C22 17.5228 17.5228 22 12 22C6.47715 22 2 17.5228 2 12ZM16.0755 7.93219C16.5272 8.25003 16.6356 8.87383 16.3178 9.32549L11.5678 16.0755C11.3931 16.3237 11.1152 16.4792 10.8123 16.4981C10.5093 16.517 10.2142 16.3973 10.0101 16.1727L7.51006 13.4227C7.13855 13.014 7.16867 12.3816 7.57733 12.0101C7.98598 11.6386 8.61843 11.6687 8.98994 12.0773L10.6504 13.9039L14.6822 8.17451C15 7.72284 15.6238 7.61436 16.0755 7.93219Z"
|
||||
fill="currentColor"
|
||||
/>
|
||||
</svg>
|
||||
)}
|
||||
{(!userProvidesKey || expiryTime) && (
|
||||
<div className="text-token-text-primary hidden gap-x-1 group-hover:flex ">
|
||||
{!userProvidesKey && <div className="">{localize('com_ui_new_chat')}</div>}
|
||||
<svg
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="icon-md"
|
||||
>
|
||||
<path
|
||||
fillRule="evenodd"
|
||||
clipRule="evenodd"
|
||||
d="M16.7929 2.79289C18.0118 1.57394 19.9882 1.57394 21.2071 2.79289C22.4261 4.01184 22.4261 5.98815 21.2071 7.20711L12.7071 15.7071C12.5196 15.8946 12.2652 16 12 16H9C8.44772 16 8 15.5523 8 15V12C8 11.7348 8.10536 11.4804 8.29289 11.2929L16.7929 2.79289ZM19.7929 4.20711C19.355 3.7692 18.645 3.7692 18.2071 4.2071L10 12.4142V14H11.5858L19.7929 5.79289C20.2308 5.35499 20.2308 4.64501 19.7929 4.20711ZM6 5C5.44772 5 5 5.44771 5 6V18C5 18.5523 5.44772 19 6 19H18C18.5523 19 19 18.5523 19 18V14C19 13.4477 19.4477 13 20 13C20.5523 13 21 13.4477 21 14V18C21 19.6569 19.6569 21 18 21H6C4.34315 21 3 19.6569 3 18V6C3 4.34314 4.34315 3 6 3H10C10.5523 3 11 3.44771 11 4C11 4.55228 10.5523 5 10 5H6Z"
|
||||
fill="currentColor"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{userProvidesKey && (
|
||||
<SetKeyDialog
|
||||
open={isDialogOpen}
|
||||
endpoint={endpoint}
|
||||
endpointType={endpointType}
|
||||
onOpenChange={setDialogOpen}
|
||||
userProvideURL={getEndpointField(endpointsConfig, endpoint, 'userProvideURL')}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default MenuItem;
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
import type { FC } from 'react';
|
||||
import { Close } from '@radix-ui/react-popover';
|
||||
import {
|
||||
EModelEndpoint,
|
||||
alternateName,
|
||||
PermissionTypes,
|
||||
Permissions,
|
||||
} from 'librechat-data-provider';
|
||||
import { useGetEndpointsQuery } from '~/data-provider';
|
||||
import MenuSeparator from '../UI/MenuSeparator';
|
||||
import { getEndpointField } from '~/utils';
|
||||
import { useHasAccess } from '~/hooks';
|
||||
import MenuItem from './MenuItem';
|
||||
|
||||
const EndpointItems: FC<{
|
||||
endpoints: Array<EModelEndpoint | undefined>;
|
||||
selected: EModelEndpoint | '';
|
||||
}> = ({ endpoints = [], selected }) => {
|
||||
const hasAccessToAgents = useHasAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permission: Permissions.USE,
|
||||
});
|
||||
const { data: endpointsConfig } = useGetEndpointsQuery();
|
||||
return (
|
||||
<>
|
||||
{endpoints.map((endpoint, i) => {
|
||||
if (!endpoint) {
|
||||
return null;
|
||||
} else if (!endpointsConfig?.[endpoint]) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (endpoint === EModelEndpoint.agents && !hasAccessToAgents) {
|
||||
return null;
|
||||
}
|
||||
const userProvidesKey: boolean | null | undefined =
|
||||
getEndpointField(endpointsConfig, endpoint, 'userProvide') ?? false;
|
||||
return (
|
||||
<Close asChild key={`endpoint-${endpoint}`}>
|
||||
<div key={`endpoint-${endpoint}`}>
|
||||
<MenuItem
|
||||
key={`endpoint-item-${endpoint}`}
|
||||
title={alternateName[endpoint] || endpoint}
|
||||
value={endpoint}
|
||||
selected={selected === endpoint}
|
||||
data-testid={`endpoint-item-${endpoint}`}
|
||||
userProvidesKey={!!userProvidesKey}
|
||||
// description="With DALL·E, browsing and analysis"
|
||||
/>
|
||||
{i !== endpoints.length - 1 && <MenuSeparator />}
|
||||
</div>
|
||||
</Close>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default EndpointItems;
|
||||
107
client/src/components/Chat/Menus/Endpoints/ModelSelector.tsx
Normal file
107
client/src/components/Chat/Menus/Endpoints/ModelSelector.tsx
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
import React, { useMemo } from 'react';
|
||||
import type { ModelSelectorProps } from '~/common';
|
||||
import { ModelSelectorProvider, useModelSelectorContext } from './ModelSelectorContext';
|
||||
import { renderModelSpecs, renderEndpoints, renderSearchResults } from './components';
|
||||
import { getSelectedIcon, getDisplayValue } from './utils';
|
||||
import { CustomMenu as Menu } from './CustomMenu';
|
||||
import DialogManager from './DialogManager';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
function ModelSelectorContent() {
|
||||
const localize = useLocalize();
|
||||
|
||||
const {
|
||||
// LibreChat
|
||||
modelSpecs,
|
||||
mappedEndpoints,
|
||||
endpointsConfig,
|
||||
// State
|
||||
searchValue,
|
||||
searchResults,
|
||||
selectedValues,
|
||||
|
||||
// Functions
|
||||
setSearchValue,
|
||||
setSelectedValues,
|
||||
// Dialog
|
||||
keyDialogOpen,
|
||||
onOpenChange,
|
||||
keyDialogEndpoint,
|
||||
} = useModelSelectorContext();
|
||||
|
||||
const selectedIcon = useMemo(
|
||||
() =>
|
||||
getSelectedIcon({
|
||||
mappedEndpoints: mappedEndpoints ?? [],
|
||||
selectedValues,
|
||||
modelSpecs,
|
||||
endpointsConfig,
|
||||
}),
|
||||
[mappedEndpoints, selectedValues, modelSpecs, endpointsConfig],
|
||||
);
|
||||
const selectedDisplayValue = useMemo(
|
||||
() =>
|
||||
getDisplayValue({
|
||||
localize,
|
||||
modelSpecs,
|
||||
selectedValues,
|
||||
mappedEndpoints,
|
||||
}),
|
||||
[localize, modelSpecs, selectedValues, mappedEndpoints],
|
||||
);
|
||||
|
||||
const trigger = (
|
||||
<button
|
||||
className="my-1 flex h-10 w-full max-w-[70vw] items-center justify-center gap-2 rounded-xl border border-border-light bg-surface-secondary px-3 py-2 text-sm text-text-primary hover:bg-surface-tertiary focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-white"
|
||||
aria-label={localize('com_ui_select_model')}
|
||||
>
|
||||
{selectedIcon && React.isValidElement(selectedIcon) && (
|
||||
<div className="flex flex-shrink-0 items-center justify-center overflow-hidden">
|
||||
{selectedIcon}
|
||||
</div>
|
||||
)}
|
||||
<span className="flex-grow truncate text-left">{selectedDisplayValue}</span>
|
||||
</button>
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="relative flex w-full max-w-md flex-col items-center gap-2">
|
||||
<Menu
|
||||
values={selectedValues}
|
||||
onValuesChange={(values: Record<string, any>) => {
|
||||
setSelectedValues({
|
||||
endpoint: values.endpoint || '',
|
||||
model: values.model || '',
|
||||
modelSpec: values.modelSpec || '',
|
||||
});
|
||||
}}
|
||||
onSearch={(value) => setSearchValue(value)}
|
||||
combobox={<input placeholder={localize('com_endpoint_search_models')} />}
|
||||
trigger={trigger}
|
||||
>
|
||||
{searchResults ? (
|
||||
renderSearchResults(searchResults, localize, searchValue)
|
||||
) : (
|
||||
<>
|
||||
{renderModelSpecs(modelSpecs, selectedValues.modelSpec || '')}
|
||||
{renderEndpoints(mappedEndpoints ?? [])}
|
||||
</>
|
||||
)}
|
||||
</Menu>
|
||||
<DialogManager
|
||||
keyDialogOpen={keyDialogOpen}
|
||||
onOpenChange={onOpenChange}
|
||||
endpointsConfig={endpointsConfig || {}}
|
||||
keyDialogEndpoint={keyDialogEndpoint || undefined}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default function ModelSelector({ startupConfig }: ModelSelectorProps) {
|
||||
return (
|
||||
<ModelSelectorProvider startupConfig={startupConfig}>
|
||||
<ModelSelectorContent />
|
||||
</ModelSelectorProvider>
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,188 @@
|
|||
import debounce from 'lodash/debounce';
|
||||
import React, { createContext, useContext, useState, useMemo } from 'react';
|
||||
import { isAgentsEndpoint, isAssistantsEndpoint } from 'librechat-data-provider';
|
||||
import type * as t from 'librechat-data-provider';
|
||||
import type { Endpoint, SelectedValues } from '~/common';
|
||||
import { useAgentsMapContext, useAssistantsMapContext, useChatContext } from '~/Providers';
|
||||
import { useEndpoints, useSelectorEffects, useKeyDialog } from '~/hooks';
|
||||
import useSelectMention from '~/hooks/Input/useSelectMention';
|
||||
import { useGetEndpointsQuery } from '~/data-provider';
|
||||
import { filterItems } from './utils';
|
||||
|
||||
type ModelSelectorContextType = {
|
||||
// State
|
||||
searchValue: string;
|
||||
selectedValues: SelectedValues;
|
||||
endpointSearchValues: Record<string, string>;
|
||||
searchResults: (t.TModelSpec | Endpoint)[] | null;
|
||||
// LibreChat
|
||||
modelSpecs: t.TModelSpec[];
|
||||
mappedEndpoints: Endpoint[];
|
||||
agentsMap: t.TAgentsMap | undefined;
|
||||
assistantsMap: t.TAssistantsMap | undefined;
|
||||
endpointsConfig: t.TEndpointsConfig;
|
||||
|
||||
// Functions
|
||||
endpointRequiresUserKey: (endpoint: string) => boolean;
|
||||
setSelectedValues: React.Dispatch<React.SetStateAction<SelectedValues>>;
|
||||
setSearchValue: (value: string) => void;
|
||||
setEndpointSearchValue: (endpoint: string, value: string) => void;
|
||||
handleSelectSpec: (spec: t.TModelSpec) => void;
|
||||
handleSelectEndpoint: (endpoint: Endpoint) => void;
|
||||
handleSelectModel: (endpoint: Endpoint, model: string) => void;
|
||||
} & ReturnType<typeof useKeyDialog>;
|
||||
|
||||
const ModelSelectorContext = createContext<ModelSelectorContextType | undefined>(undefined);
|
||||
|
||||
export function useModelSelectorContext() {
|
||||
const context = useContext(ModelSelectorContext);
|
||||
if (context === undefined) {
|
||||
throw new Error('useModelSelectorContext must be used within a ModelSelectorProvider');
|
||||
}
|
||||
return context;
|
||||
}
|
||||
|
||||
interface ModelSelectorProviderProps {
|
||||
children: React.ReactNode;
|
||||
startupConfig: t.TStartupConfig | undefined;
|
||||
}
|
||||
|
||||
export function ModelSelectorProvider({ children, startupConfig }: ModelSelectorProviderProps) {
|
||||
const agentsMap = useAgentsMapContext();
|
||||
const assistantsMap = useAssistantsMapContext();
|
||||
const { data: endpointsConfig } = useGetEndpointsQuery();
|
||||
const { conversation, newConversation } = useChatContext();
|
||||
const modelSpecs = useMemo(() => startupConfig?.modelSpecs?.list ?? [], [startupConfig]);
|
||||
const { mappedEndpoints, endpointRequiresUserKey } = useEndpoints({
|
||||
agentsMap,
|
||||
assistantsMap,
|
||||
startupConfig,
|
||||
endpointsConfig,
|
||||
});
|
||||
const { onSelectEndpoint, onSelectSpec } = useSelectMention({
|
||||
// presets,
|
||||
modelSpecs,
|
||||
assistantsMap,
|
||||
endpointsConfig,
|
||||
newConversation,
|
||||
returnHandlers: true,
|
||||
});
|
||||
|
||||
// State
|
||||
const [selectedValues, setSelectedValues] = useState<SelectedValues>({
|
||||
endpoint: conversation?.endpoint || '',
|
||||
model: conversation?.model || '',
|
||||
modelSpec: conversation?.spec || '',
|
||||
});
|
||||
useSelectorEffects({
|
||||
agentsMap,
|
||||
conversation,
|
||||
assistantsMap,
|
||||
setSelectedValues,
|
||||
});
|
||||
|
||||
const [searchValue, setSearchValueState] = useState('');
|
||||
const [endpointSearchValues, setEndpointSearchValues] = useState<Record<string, string>>({});
|
||||
|
||||
const keyProps = useKeyDialog();
|
||||
|
||||
// Memoized search results
|
||||
const searchResults = useMemo(() => {
|
||||
if (!searchValue) {
|
||||
return null;
|
||||
}
|
||||
const allItems = [...modelSpecs, ...mappedEndpoints];
|
||||
return filterItems(allItems, searchValue, agentsMap, assistantsMap || {});
|
||||
}, [searchValue, modelSpecs, mappedEndpoints, agentsMap, assistantsMap]);
|
||||
|
||||
// Functions
|
||||
const setDebouncedSearchValue = useMemo(
|
||||
() =>
|
||||
debounce((value: string) => {
|
||||
setSearchValueState(value);
|
||||
}, 200),
|
||||
[],
|
||||
);
|
||||
const setEndpointSearchValue = (endpoint: string, value: string) => {
|
||||
setEndpointSearchValues((prev) => ({
|
||||
...prev,
|
||||
[endpoint]: value,
|
||||
}));
|
||||
};
|
||||
|
||||
const handleSelectSpec = (spec: t.TModelSpec) => {
|
||||
let model = spec.preset.model ?? null;
|
||||
onSelectSpec?.(spec);
|
||||
if (isAgentsEndpoint(spec.preset.endpoint)) {
|
||||
model = spec.preset.agent_id ?? '';
|
||||
} else if (isAssistantsEndpoint(spec.preset.endpoint)) {
|
||||
model = spec.preset.assistant_id ?? '';
|
||||
}
|
||||
setSelectedValues({
|
||||
endpoint: spec.preset.endpoint,
|
||||
model,
|
||||
modelSpec: spec.name,
|
||||
});
|
||||
};
|
||||
|
||||
const handleSelectEndpoint = (endpoint: Endpoint) => {
|
||||
if (!endpoint.hasModels) {
|
||||
if (endpoint.value) {
|
||||
onSelectEndpoint?.(endpoint.value);
|
||||
}
|
||||
setSelectedValues({
|
||||
endpoint: endpoint.value,
|
||||
model: '',
|
||||
modelSpec: '',
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const handleSelectModel = (endpoint: Endpoint, model: string) => {
|
||||
if (isAgentsEndpoint(endpoint.value)) {
|
||||
onSelectEndpoint?.(endpoint.value, {
|
||||
agent_id: model,
|
||||
model: agentsMap?.[model]?.model ?? '',
|
||||
});
|
||||
} else if (isAssistantsEndpoint(endpoint.value)) {
|
||||
onSelectEndpoint?.(endpoint.value, {
|
||||
assistant_id: model,
|
||||
model: assistantsMap?.[endpoint.value]?.[model]?.model ?? '',
|
||||
});
|
||||
} else if (endpoint.value) {
|
||||
onSelectEndpoint?.(endpoint.value, { model });
|
||||
}
|
||||
setSelectedValues({
|
||||
endpoint: endpoint.value,
|
||||
model,
|
||||
modelSpec: '',
|
||||
});
|
||||
};
|
||||
|
||||
const value = {
|
||||
// State
|
||||
searchValue,
|
||||
searchResults,
|
||||
selectedValues,
|
||||
endpointSearchValues,
|
||||
// LibreChat
|
||||
agentsMap,
|
||||
modelSpecs,
|
||||
assistantsMap,
|
||||
mappedEndpoints,
|
||||
endpointsConfig,
|
||||
|
||||
// Functions
|
||||
handleSelectSpec,
|
||||
handleSelectModel,
|
||||
setSelectedValues,
|
||||
handleSelectEndpoint,
|
||||
setEndpointSearchValue,
|
||||
endpointRequiresUserKey,
|
||||
setSearchValue: setDebouncedSearchValue,
|
||||
// Dialog
|
||||
...keyProps,
|
||||
};
|
||||
|
||||
return <ModelSelectorContext.Provider value={value}>{children}</ModelSelectorContext.Provider>;
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue