mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-17 08:50:15 +01:00
🔎 feat: Add Prompt and Agent Permissions Migration Checks (#9063)
* chore: fix mock typing in packages/api tests * chore: improve imports, type handling and method signatures for MCPServersRegistry * chore: use enum in migration scripts * chore: ParsedServerConfig type to enhance server configuration handling * feat: Implement agent permissions migration check and logging * feat: Integrate migration checks into server initialization process * feat: Add prompt permissions migration check and logging to server initialization * chore: move prompt formatting functions to dedicated prompts dir
This commit is contained in:
parent
e8ddd279fd
commit
e4e25aaf2b
17 changed files with 636 additions and 96 deletions
|
|
@ -14,6 +14,7 @@ const { isEnabled, ErrorController } = require('@librechat/api');
|
||||||
const { connectDb, indexSync } = require('~/db');
|
const { connectDb, indexSync } = require('~/db');
|
||||||
const validateImageRequest = require('./middleware/validateImageRequest');
|
const validateImageRequest = require('./middleware/validateImageRequest');
|
||||||
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
||||||
|
const { checkMigrations } = require('./services/start/migration');
|
||||||
const initializeMCPs = require('./services/initializeMCPs');
|
const initializeMCPs = require('./services/initializeMCPs');
|
||||||
const configureSocialLogins = require('./socialLogins');
|
const configureSocialLogins = require('./socialLogins');
|
||||||
const AppService = require('./services/AppService');
|
const AppService = require('./services/AppService');
|
||||||
|
|
@ -145,7 +146,7 @@ const startServer = async () => {
|
||||||
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
|
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
initializeMCPs(app);
|
initializeMCPs(app).then(() => checkMigrations());
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,9 @@
|
||||||
const { loadMemoryConfig, agentsConfigSetup, loadWebSearchConfig } = require('@librechat/api');
|
const {
|
||||||
|
isEnabled,
|
||||||
|
loadMemoryConfig,
|
||||||
|
agentsConfigSetup,
|
||||||
|
loadWebSearchConfig,
|
||||||
|
} = require('@librechat/api');
|
||||||
const {
|
const {
|
||||||
FileSources,
|
FileSources,
|
||||||
loadOCRConfig,
|
loadOCRConfig,
|
||||||
|
|
@ -6,16 +11,16 @@ const {
|
||||||
getConfigDefaults,
|
getConfigDefaults,
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
|
checkWebSearchConfig,
|
||||||
|
checkAzureVariables,
|
||||||
|
checkVariables,
|
||||||
checkHealth,
|
checkHealth,
|
||||||
checkConfig,
|
checkConfig,
|
||||||
checkVariables,
|
|
||||||
checkAzureVariables,
|
|
||||||
checkWebSearchConfig,
|
|
||||||
} = require('./start/checks');
|
} = require('./start/checks');
|
||||||
|
const { ensureDefaultCategories, seedDefaultRoles, initializeRoles } = require('~/models');
|
||||||
const { azureAssistantsDefaults, assistantsConfigSetup } = require('./start/assistants');
|
const { azureAssistantsDefaults, assistantsConfigSetup } = require('./start/assistants');
|
||||||
const { initializeAzureBlobService } = require('./Files/Azure/initialize');
|
const { initializeAzureBlobService } = require('./Files/Azure/initialize');
|
||||||
const { initializeFirebase } = require('./Files/Firebase/initialize');
|
const { initializeFirebase } = require('./Files/Firebase/initialize');
|
||||||
const { seedDefaultRoles, initializeRoles, ensureDefaultCategories } = require('~/models');
|
|
||||||
const loadCustomConfig = require('./Config/loadCustomConfig');
|
const loadCustomConfig = require('./Config/loadCustomConfig');
|
||||||
const handleRateLimits = require('./Config/handleRateLimits');
|
const handleRateLimits = require('./Config/handleRateLimits');
|
||||||
const { loadDefaultInterface } = require('./start/interface');
|
const { loadDefaultInterface } = require('./start/interface');
|
||||||
|
|
@ -24,7 +29,6 @@ const { azureConfigSetup } = require('./start/azureOpenAI');
|
||||||
const { processModelSpecs } = require('./start/modelSpecs');
|
const { processModelSpecs } = require('./start/modelSpecs');
|
||||||
const { initializeS3 } = require('./Files/S3/initialize');
|
const { initializeS3 } = require('./Files/S3/initialize');
|
||||||
const { loadAndFormatTools } = require('./ToolService');
|
const { loadAndFormatTools } = require('./ToolService');
|
||||||
const { isEnabled } = require('~/server/utils');
|
|
||||||
const { setCachedTools } = require('./Config');
|
const { setCachedTools } = require('./Config');
|
||||||
const paths = require('~/config/paths');
|
const paths = require('~/config/paths');
|
||||||
|
|
||||||
|
|
|
||||||
45
api/server/services/start/migration.js
Normal file
45
api/server/services/start/migration.js
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const {
|
||||||
|
logAgentMigrationWarning,
|
||||||
|
logPromptMigrationWarning,
|
||||||
|
checkAgentPermissionsMigration,
|
||||||
|
checkPromptPermissionsMigration,
|
||||||
|
} = require('@librechat/api');
|
||||||
|
const { getProjectByName } = require('~/models/Project');
|
||||||
|
const { Agent, PromptGroup } = require('~/db/models');
|
||||||
|
const { findRoleByIdentifier } = require('~/models');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if permissions migrations are needed for shared resources
|
||||||
|
* This runs at the end to ensure all systems are initialized
|
||||||
|
*/
|
||||||
|
async function checkMigrations() {
|
||||||
|
try {
|
||||||
|
const agentMigrationResult = await checkAgentPermissionsMigration({
|
||||||
|
db: {
|
||||||
|
findRoleByIdentifier,
|
||||||
|
getProjectByName,
|
||||||
|
},
|
||||||
|
AgentModel: Agent,
|
||||||
|
});
|
||||||
|
logAgentMigrationWarning(agentMigrationResult);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to check agent permissions migration:', error);
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
const promptMigrationResult = await checkPromptPermissionsMigration({
|
||||||
|
db: {
|
||||||
|
findRoleByIdentifier,
|
||||||
|
getProjectByName,
|
||||||
|
},
|
||||||
|
PromptGroupModel: PromptGroup,
|
||||||
|
});
|
||||||
|
logPromptMigrationWarning(promptMigrationResult);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to check prompt permissions migration:', error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
checkMigrations,
|
||||||
|
};
|
||||||
|
|
@ -1,11 +1,9 @@
|
||||||
//TODO: needs testing and validation before running in production
|
|
||||||
console.log('needs testing and validation before running in production...');
|
|
||||||
const path = require('path');
|
const path = require('path');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
require('module-alias')({ base: path.resolve(__dirname, '..', 'api') });
|
|
||||||
|
|
||||||
const { AccessRoleIds, ResourceType, PrincipalType } = require('librechat-data-provider');
|
const { AccessRoleIds, ResourceType, PrincipalType } = require('librechat-data-provider');
|
||||||
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
|
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
|
||||||
|
|
||||||
|
require('module-alias')({ base: path.resolve(__dirname, '..', 'api') });
|
||||||
const connect = require('./connect');
|
const connect = require('./connect');
|
||||||
|
|
||||||
const { grantPermission } = require('~/server/services/PermissionService');
|
const { grantPermission } = require('~/server/services/PermissionService');
|
||||||
|
|
@ -19,9 +17,9 @@ async function migrateAgentPermissionsEnhanced({ dryRun = true, batchSize = 100
|
||||||
logger.info('Starting Enhanced Agent Permissions Migration', { dryRun, batchSize });
|
logger.info('Starting Enhanced Agent Permissions Migration', { dryRun, batchSize });
|
||||||
|
|
||||||
// Verify required roles exist
|
// Verify required roles exist
|
||||||
const ownerRole = await findRoleByIdentifier('agent_owner');
|
const ownerRole = await findRoleByIdentifier(AccessRoleIds.AGENT_OWNER);
|
||||||
const viewerRole = await findRoleByIdentifier('agent_viewer');
|
const viewerRole = await findRoleByIdentifier(AccessRoleIds.AGENT_VIEWER);
|
||||||
const editorRole = await findRoleByIdentifier('agent_editor');
|
const editorRole = await findRoleByIdentifier(AccessRoleIds.AGENT_EDITOR);
|
||||||
|
|
||||||
if (!ownerRole || !viewerRole || !editorRole) {
|
if (!ownerRole || !viewerRole || !editorRole) {
|
||||||
throw new Error('Required roles not found. Run role seeding first.');
|
throw new Error('Required roles not found. Run role seeding first.');
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
const path = require('path');
|
const path = require('path');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
require('module-alias')({ base: path.resolve(__dirname, '..', 'api') });
|
|
||||||
|
|
||||||
const { AccessRoleIds, ResourceType, PrincipalType } = require('librechat-data-provider');
|
const { AccessRoleIds, ResourceType, PrincipalType } = require('librechat-data-provider');
|
||||||
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
|
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
|
||||||
|
|
||||||
|
require('module-alias')({ base: path.resolve(__dirname, '..', 'api') });
|
||||||
const connect = require('./connect');
|
const connect = require('./connect');
|
||||||
|
|
||||||
const { grantPermission } = require('~/server/services/PermissionService');
|
const { grantPermission } = require('~/server/services/PermissionService');
|
||||||
|
|
@ -17,9 +17,9 @@ async function migrateToPromptGroupPermissions({ dryRun = true, batchSize = 100
|
||||||
logger.info('Starting PromptGroup Permissions Migration', { dryRun, batchSize });
|
logger.info('Starting PromptGroup Permissions Migration', { dryRun, batchSize });
|
||||||
|
|
||||||
// Verify required roles exist
|
// Verify required roles exist
|
||||||
const ownerRole = await findRoleByIdentifier('promptGroup_owner');
|
const ownerRole = await findRoleByIdentifier(AccessRoleIds.PROMPTGROUP_OWNER);
|
||||||
const viewerRole = await findRoleByIdentifier('promptGroup_viewer');
|
const viewerRole = await findRoleByIdentifier(AccessRoleIds.PROMPTGROUP_VIEWER);
|
||||||
const editorRole = await findRoleByIdentifier('promptGroup_editor');
|
const editorRole = await findRoleByIdentifier(AccessRoleIds.PROMPTGROUP_EDITOR);
|
||||||
|
|
||||||
if (!ownerRole || !viewerRole || !editorRole) {
|
if (!ownerRole || !viewerRole || !editorRole) {
|
||||||
throw new Error('Required promptGroup roles not found. Run role seeding first.');
|
throw new Error('Required promptGroup roles not found. Run role seeding first.');
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
export * from './config';
|
export * from './config';
|
||||||
export * from './memory';
|
export * from './memory';
|
||||||
|
export * from './migration';
|
||||||
export * from './resources';
|
export * from './resources';
|
||||||
export * from './run';
|
export * from './run';
|
||||||
export * from './validation';
|
export * from './validation';
|
||||||
|
|
|
||||||
236
packages/api/src/agents/migration.ts
Normal file
236
packages/api/src/agents/migration.ts
Normal file
|
|
@ -0,0 +1,236 @@
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
|
import { AccessRoleIds, ResourceType, PrincipalType, Constants } from 'librechat-data-provider';
|
||||||
|
import type { AccessRoleMethods, IAgent } from '@librechat/data-schemas';
|
||||||
|
import type { Model } from 'mongoose';
|
||||||
|
|
||||||
|
const { GLOBAL_PROJECT_NAME } = Constants;
|
||||||
|
|
||||||
|
export interface MigrationCheckDbMethods {
|
||||||
|
findRoleByIdentifier: AccessRoleMethods['findRoleByIdentifier'];
|
||||||
|
getProjectByName: (
|
||||||
|
projectName: string,
|
||||||
|
fieldsToSelect?: string[] | null,
|
||||||
|
) => Promise<{
|
||||||
|
agentIds?: string[];
|
||||||
|
[key: string]: unknown;
|
||||||
|
} | null>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface MigrationCheckParams {
|
||||||
|
db: MigrationCheckDbMethods;
|
||||||
|
AgentModel: Model<IAgent>;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface AgentMigrationData {
|
||||||
|
_id: string;
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
author: string;
|
||||||
|
isCollaborative: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface MigrationCheckResult {
|
||||||
|
totalToMigrate: number;
|
||||||
|
globalEditAccess: number;
|
||||||
|
globalViewAccess: number;
|
||||||
|
privateAgents: number;
|
||||||
|
details?: {
|
||||||
|
globalEditAccess: Array<{ name: string; id: string }>;
|
||||||
|
globalViewAccess: Array<{ name: string; id: string }>;
|
||||||
|
privateAgents: Array<{ name: string; id: string }>;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if agents need to be migrated to the new permission system
|
||||||
|
* This performs a dry-run check similar to the migration script
|
||||||
|
*/
|
||||||
|
export async function checkAgentPermissionsMigration({
|
||||||
|
db,
|
||||||
|
AgentModel,
|
||||||
|
}: MigrationCheckParams): Promise<MigrationCheckResult> {
|
||||||
|
logger.debug('Checking if agent permissions migration is needed');
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Verify required roles exist
|
||||||
|
const ownerRole = await db.findRoleByIdentifier(AccessRoleIds.AGENT_OWNER);
|
||||||
|
const viewerRole = await db.findRoleByIdentifier(AccessRoleIds.AGENT_VIEWER);
|
||||||
|
const editorRole = await db.findRoleByIdentifier(AccessRoleIds.AGENT_EDITOR);
|
||||||
|
|
||||||
|
if (!ownerRole || !viewerRole || !editorRole) {
|
||||||
|
logger.warn(
|
||||||
|
'Required agent roles not found. Permission system may not be fully initialized.',
|
||||||
|
);
|
||||||
|
return {
|
||||||
|
totalToMigrate: 0,
|
||||||
|
globalEditAccess: 0,
|
||||||
|
globalViewAccess: 0,
|
||||||
|
privateAgents: 0,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get global project agent IDs
|
||||||
|
const globalProject = await db.getProjectByName(GLOBAL_PROJECT_NAME, ['agentIds']);
|
||||||
|
const globalAgentIds = new Set(globalProject?.agentIds || []);
|
||||||
|
|
||||||
|
// Find agents without ACL entries (no batching for efficiency on startup)
|
||||||
|
const agentsToMigrate: AgentMigrationData[] = await AgentModel.aggregate([
|
||||||
|
{
|
||||||
|
$lookup: {
|
||||||
|
from: 'aclentries',
|
||||||
|
localField: '_id',
|
||||||
|
foreignField: 'resourceId',
|
||||||
|
as: 'aclEntries',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
$addFields: {
|
||||||
|
userAclEntries: {
|
||||||
|
$filter: {
|
||||||
|
input: '$aclEntries',
|
||||||
|
as: 'aclEntry',
|
||||||
|
cond: {
|
||||||
|
$and: [
|
||||||
|
{ $eq: ['$$aclEntry.resourceType', ResourceType.AGENT] },
|
||||||
|
{ $eq: ['$$aclEntry.principalType', PrincipalType.USER] },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
$match: {
|
||||||
|
author: { $exists: true, $ne: null },
|
||||||
|
userAclEntries: { $size: 0 },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
$project: {
|
||||||
|
_id: 1,
|
||||||
|
id: 1,
|
||||||
|
name: 1,
|
||||||
|
author: 1,
|
||||||
|
isCollaborative: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
const categories: {
|
||||||
|
globalEditAccess: AgentMigrationData[];
|
||||||
|
globalViewAccess: AgentMigrationData[];
|
||||||
|
privateAgents: AgentMigrationData[];
|
||||||
|
} = {
|
||||||
|
globalEditAccess: [],
|
||||||
|
globalViewAccess: [],
|
||||||
|
privateAgents: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
agentsToMigrate.forEach((agent) => {
|
||||||
|
const isGlobal = globalAgentIds.has(agent.id);
|
||||||
|
const isCollab = agent.isCollaborative;
|
||||||
|
|
||||||
|
if (isGlobal && isCollab) {
|
||||||
|
categories.globalEditAccess.push(agent);
|
||||||
|
} else if (isGlobal && !isCollab) {
|
||||||
|
categories.globalViewAccess.push(agent);
|
||||||
|
} else {
|
||||||
|
categories.privateAgents.push(agent);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
const result: MigrationCheckResult = {
|
||||||
|
totalToMigrate: agentsToMigrate.length,
|
||||||
|
globalEditAccess: categories.globalEditAccess.length,
|
||||||
|
globalViewAccess: categories.globalViewAccess.length,
|
||||||
|
privateAgents: categories.privateAgents.length,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add details for debugging
|
||||||
|
if (agentsToMigrate.length > 0) {
|
||||||
|
result.details = {
|
||||||
|
globalEditAccess: categories.globalEditAccess.map((a) => ({
|
||||||
|
name: a.name,
|
||||||
|
id: a.id,
|
||||||
|
})),
|
||||||
|
globalViewAccess: categories.globalViewAccess.map((a) => ({
|
||||||
|
name: a.name,
|
||||||
|
id: a.id,
|
||||||
|
})),
|
||||||
|
privateAgents: categories.privateAgents.map((a) => ({
|
||||||
|
name: a.name,
|
||||||
|
id: a.id,
|
||||||
|
})),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug('Agent migration check completed', {
|
||||||
|
totalToMigrate: result.totalToMigrate,
|
||||||
|
globalEditAccess: result.globalEditAccess,
|
||||||
|
globalViewAccess: result.globalViewAccess,
|
||||||
|
privateAgents: result.privateAgents,
|
||||||
|
});
|
||||||
|
|
||||||
|
return result;
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to check agent permissions migration', error);
|
||||||
|
// Return zero counts on error to avoid blocking startup
|
||||||
|
return {
|
||||||
|
totalToMigrate: 0,
|
||||||
|
globalEditAccess: 0,
|
||||||
|
globalViewAccess: 0,
|
||||||
|
privateAgents: 0,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Log migration warning to console if agents need migration
|
||||||
|
*/
|
||||||
|
export function logAgentMigrationWarning(result: MigrationCheckResult): void {
|
||||||
|
if (result.totalToMigrate === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a visible warning box
|
||||||
|
const border = '='.repeat(80);
|
||||||
|
const warning = [
|
||||||
|
'',
|
||||||
|
border,
|
||||||
|
' IMPORTANT: AGENT PERMISSIONS MIGRATION REQUIRED',
|
||||||
|
border,
|
||||||
|
'',
|
||||||
|
` Total agents to migrate: ${result.totalToMigrate}`,
|
||||||
|
` - Global Edit Access: ${result.globalEditAccess} agents`,
|
||||||
|
` - Global View Access: ${result.globalViewAccess} agents`,
|
||||||
|
` - Private Agents: ${result.privateAgents} agents`,
|
||||||
|
'',
|
||||||
|
' The new agent sharing system requires migrating existing agents.',
|
||||||
|
' Please run the following command to migrate your agents:',
|
||||||
|
'',
|
||||||
|
' npm run migrate:agent-permissions',
|
||||||
|
'',
|
||||||
|
' For a dry run (preview) of what will be migrated:',
|
||||||
|
'',
|
||||||
|
' npm run migrate:agent-permissions:dry-run',
|
||||||
|
'',
|
||||||
|
' This migration will:',
|
||||||
|
' 1. Grant owner permissions to agent authors',
|
||||||
|
' 2. Set appropriate public permissions based on global project status',
|
||||||
|
' 3. Preserve existing collaborative settings',
|
||||||
|
'',
|
||||||
|
border,
|
||||||
|
'',
|
||||||
|
];
|
||||||
|
|
||||||
|
// Use console methods directly for visibility
|
||||||
|
console.log('\n' + warning.join('\n') + '\n');
|
||||||
|
|
||||||
|
// Also log with logger for consistency
|
||||||
|
logger.warn('Agent permissions migration required', {
|
||||||
|
totalToMigrate: result.totalToMigrate,
|
||||||
|
globalEditAccess: result.globalEditAccess,
|
||||||
|
globalViewAccess: result.globalViewAccess,
|
||||||
|
privateAgents: result.privateAgents,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
@ -1,2 +1 @@
|
||||||
export * from './content';
|
export * from './content';
|
||||||
export * from './prompts';
|
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,8 @@ export * from './middleware';
|
||||||
export * from './memory';
|
export * from './memory';
|
||||||
/* Agents */
|
/* Agents */
|
||||||
export * from './agents';
|
export * from './agents';
|
||||||
|
/* Prompts */
|
||||||
|
export * from './prompts';
|
||||||
/* Endpoints */
|
/* Endpoints */
|
||||||
export * from './endpoints';
|
export * from './endpoints';
|
||||||
/* Files */
|
/* Files */
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,15 @@
|
||||||
import { logger } from '@librechat/data-schemas';
|
|
||||||
import mapValues from 'lodash/mapValues';
|
|
||||||
import pickBy from 'lodash/pickBy';
|
|
||||||
import pick from 'lodash/pick';
|
import pick from 'lodash/pick';
|
||||||
|
import pickBy from 'lodash/pickBy';
|
||||||
|
import mapValues from 'lodash/mapValues';
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
|
import type { MCPConnection } from '~/mcp/connection';
|
||||||
import type { JsonSchemaType } from '~/types';
|
import type { JsonSchemaType } from '~/types';
|
||||||
import type * as t from '~/mcp/types';
|
import type * as t from '~/mcp/types';
|
||||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||||
import { detectOAuthRequirement } from '~/mcp/oauth';
|
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||||
import { type MCPConnection } from './connection';
|
|
||||||
import { processMCPEnv } from '~/utils';
|
import { processMCPEnv } from '~/utils';
|
||||||
import { CONSTANTS } from '~/mcp/enum';
|
import { CONSTANTS } from '~/mcp/enum';
|
||||||
|
|
||||||
type ParsedServerConfig = t.MCPOptions & {
|
|
||||||
url?: string;
|
|
||||||
requiresOAuth?: boolean;
|
|
||||||
oauthMetadata?: Record<string, unknown> | null;
|
|
||||||
capabilities?: string;
|
|
||||||
tools?: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Manages MCP server configurations and metadata discovery.
|
* Manages MCP server configurations and metadata discovery.
|
||||||
* Fetches server capabilities, OAuth requirements, and tool definitions for registry.
|
* Fetches server capabilities, OAuth requirements, and tool definitions for registry.
|
||||||
|
|
@ -29,7 +21,7 @@ export class MCPServersRegistry {
|
||||||
private connections: ConnectionsRepository;
|
private connections: ConnectionsRepository;
|
||||||
|
|
||||||
public readonly rawConfigs: t.MCPServers;
|
public readonly rawConfigs: t.MCPServers;
|
||||||
public readonly parsedConfigs: Record<string, ParsedServerConfig>;
|
public readonly parsedConfigs: Record<string, t.ParsedServerConfig>;
|
||||||
|
|
||||||
public oauthServers: Set<string> | null = null;
|
public oauthServers: Set<string> | null = null;
|
||||||
public serverInstructions: Record<string, string> | null = null;
|
public serverInstructions: Record<string, string> | null = null;
|
||||||
|
|
@ -43,7 +35,7 @@ export class MCPServersRegistry {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Initializes all startup-enabled servers by gathering their metadata asynchronously */
|
/** Initializes all startup-enabled servers by gathering their metadata asynchronously */
|
||||||
public async initialize() {
|
public async initialize(): Promise<void> {
|
||||||
if (this.initialized) return;
|
if (this.initialized) return;
|
||||||
this.initialized = true;
|
this.initialized = true;
|
||||||
|
|
||||||
|
|
@ -59,8 +51,8 @@ export class MCPServersRegistry {
|
||||||
this.connections.disconnectAll();
|
this.connections.disconnectAll();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetches all metadata for a single server in parallel
|
/** Fetches all metadata for a single server in parallel */
|
||||||
private async gatherServerInfo(serverName: string) {
|
private async gatherServerInfo(serverName: string): Promise<void> {
|
||||||
try {
|
try {
|
||||||
await this.fetchOAuthRequirement(serverName);
|
await this.fetchOAuthRequirement(serverName);
|
||||||
const config = this.parsedConfigs[serverName];
|
const config = this.parsedConfigs[serverName];
|
||||||
|
|
@ -82,8 +74,8 @@ export class MCPServersRegistry {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sets app-level server configs (startup enabled, non-OAuth servers)
|
/** Sets app-level server configs (startup enabled, non-OAuth servers) */
|
||||||
private setAppServerConfigs() {
|
private setAppServerConfigs(): void {
|
||||||
const appServers = Object.keys(
|
const appServers = Object.keys(
|
||||||
pickBy(
|
pickBy(
|
||||||
this.parsedConfigs,
|
this.parsedConfigs,
|
||||||
|
|
@ -93,8 +85,8 @@ export class MCPServersRegistry {
|
||||||
this.appServerConfigs = pick(this.rawConfigs, appServers);
|
this.appServerConfigs = pick(this.rawConfigs, appServers);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates set of server names that require OAuth authentication
|
/** Creates set of server names that require OAuth authentication */
|
||||||
private setOAuthServers() {
|
private setOAuthServers(): Set<string> {
|
||||||
if (this.oauthServers) return this.oauthServers;
|
if (this.oauthServers) return this.oauthServers;
|
||||||
this.oauthServers = new Set(
|
this.oauthServers = new Set(
|
||||||
Object.keys(pickBy(this.parsedConfigs, (config) => config.requiresOAuth)),
|
Object.keys(pickBy(this.parsedConfigs, (config) => config.requiresOAuth)),
|
||||||
|
|
@ -102,16 +94,16 @@ export class MCPServersRegistry {
|
||||||
return this.oauthServers;
|
return this.oauthServers;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collects server instructions from all configured servers
|
/** Collects server instructions from all configured servers */
|
||||||
private setServerInstructions() {
|
private setServerInstructions(): void {
|
||||||
this.serverInstructions = mapValues(
|
this.serverInstructions = mapValues(
|
||||||
pickBy(this.parsedConfigs, (config) => config.serverInstructions),
|
pickBy(this.parsedConfigs, (config) => config.serverInstructions),
|
||||||
(config) => config.serverInstructions as string,
|
(config) => config.serverInstructions as string,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Builds registry of all available tool functions from loaded connections
|
/** Builds registry of all available tool functions from loaded connections */
|
||||||
private async setAppToolFunctions() {
|
private async setAppToolFunctions(): Promise<void> {
|
||||||
const connections = (await this.connections.getLoaded()).entries();
|
const connections = (await this.connections.getLoaded()).entries();
|
||||||
const allToolFunctions: t.LCAvailableTools = {};
|
const allToolFunctions: t.LCAvailableTools = {};
|
||||||
for (const [serverName, conn] of connections) {
|
for (const [serverName, conn] of connections) {
|
||||||
|
|
@ -125,12 +117,12 @@ export class MCPServersRegistry {
|
||||||
this.toolFunctions = allToolFunctions;
|
this.toolFunctions = allToolFunctions;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts server tools to LibreChat-compatible tool functions format
|
/** Converts server tools to LibreChat-compatible tool functions format */
|
||||||
private async getToolFunctions(
|
private async getToolFunctions(
|
||||||
serverName: string,
|
serverName: string,
|
||||||
conn: MCPConnection,
|
conn: MCPConnection,
|
||||||
): Promise<t.LCAvailableTools> {
|
): Promise<t.LCAvailableTools> {
|
||||||
const { tools } = await conn.client.listTools();
|
const { tools }: t.MCPToolListResponse = await conn.client.listTools();
|
||||||
|
|
||||||
const toolFunctions: t.LCAvailableTools = {};
|
const toolFunctions: t.LCAvailableTools = {};
|
||||||
tools.forEach((tool) => {
|
tools.forEach((tool) => {
|
||||||
|
|
@ -148,7 +140,7 @@ export class MCPServersRegistry {
|
||||||
return toolFunctions;
|
return toolFunctions;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determines if server requires OAuth if not already specified in the config
|
/** Determines if server requires OAuth if not already specified in the config */
|
||||||
private async fetchOAuthRequirement(serverName: string): Promise<boolean> {
|
private async fetchOAuthRequirement(serverName: string): Promise<boolean> {
|
||||||
const config = this.parsedConfigs[serverName];
|
const config = this.parsedConfigs[serverName];
|
||||||
if (config.requiresOAuth != null) return config.requiresOAuth;
|
if (config.requiresOAuth != null) return config.requiresOAuth;
|
||||||
|
|
@ -161,8 +153,8 @@ export class MCPServersRegistry {
|
||||||
return config.requiresOAuth;
|
return config.requiresOAuth;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retrieves server instructions from MCP server if enabled in the config
|
/** Retrieves server instructions from MCP server if enabled in the config */
|
||||||
private async fetchServerInstructions(serverName: string) {
|
private async fetchServerInstructions(serverName: string): Promise<void> {
|
||||||
const config = this.parsedConfigs[serverName];
|
const config = this.parsedConfigs[serverName];
|
||||||
if (!config.serverInstructions) return;
|
if (!config.serverInstructions) return;
|
||||||
if (typeof config.serverInstructions === 'string') return;
|
if (typeof config.serverInstructions === 'string') return;
|
||||||
|
|
@ -174,8 +166,8 @@ export class MCPServersRegistry {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetches server capabilities and available tools list
|
/** Fetches server capabilities and available tools list */
|
||||||
private async fetchServerCapabilities(serverName: string) {
|
private async fetchServerCapabilities(serverName: string): Promise<void> {
|
||||||
const config = this.parsedConfigs[serverName];
|
const config = this.parsedConfigs[serverName];
|
||||||
const conn = await this.connections.get(serverName);
|
const conn = await this.connections.get(serverName);
|
||||||
const capabilities = conn.client.getServerCapabilities();
|
const capabilities = conn.client.getServerCapabilities();
|
||||||
|
|
@ -187,7 +179,7 @@ export class MCPServersRegistry {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logs server configuration summary after initialization
|
// Logs server configuration summary after initialization
|
||||||
private logUpdatedConfig(serverName: string) {
|
private logUpdatedConfig(serverName: string): void {
|
||||||
const prefix = this.prefix(serverName);
|
const prefix = this.prefix(serverName);
|
||||||
const config = this.parsedConfigs[serverName];
|
const config = this.parsedConfigs[serverName];
|
||||||
logger.info(`${prefix} -------------------------------------------------┐`);
|
logger.info(`${prefix} -------------------------------------------------┐`);
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
import { readFileSync } from 'fs';
|
|
||||||
import { join } from 'path';
|
import { join } from 'path';
|
||||||
import { logger } from '@librechat/data-schemas';
|
import { readFileSync } from 'fs';
|
||||||
import { load as yamlLoad } from 'js-yaml';
|
import { load as yamlLoad } from 'js-yaml';
|
||||||
import { ConnectionsRepository } from '../ConnectionsRepository';
|
import { logger } from '@librechat/data-schemas';
|
||||||
import { MCPServersRegistry } from '../MCPServersRegistry';
|
import type { OAuthDetectionResult } from '~/mcp/oauth/detectOAuth';
|
||||||
|
import type * as t from '~/mcp/types';
|
||||||
|
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||||
|
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
|
||||||
import { detectOAuthRequirement } from '~/mcp/oauth';
|
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||||
import { MCPConnection } from '../connection';
|
import { MCPConnection } from '~/mcp/connection';
|
||||||
import type * as t from '../types';
|
|
||||||
|
|
||||||
// Mock external dependencies
|
// Mock external dependencies
|
||||||
jest.mock('../oauth/detectOAuth');
|
jest.mock('../oauth/detectOAuth');
|
||||||
|
|
@ -37,7 +38,7 @@ const mockLogger = logger as jest.Mocked<typeof logger>;
|
||||||
|
|
||||||
describe('MCPServersRegistry - Initialize Function', () => {
|
describe('MCPServersRegistry - Initialize Function', () => {
|
||||||
let rawConfigs: t.MCPServers;
|
let rawConfigs: t.MCPServers;
|
||||||
let expectedParsedConfigs: Record<string, any>;
|
let expectedParsedConfigs: Record<string, t.ParsedServerConfig>;
|
||||||
let mockConnectionsRepo: jest.Mocked<ConnectionsRepository>;
|
let mockConnectionsRepo: jest.Mocked<ConnectionsRepository>;
|
||||||
let mockConnections: Map<string, jest.Mocked<MCPConnection>>;
|
let mockConnections: Map<string, jest.Mocked<MCPConnection>>;
|
||||||
|
|
||||||
|
|
@ -49,7 +50,7 @@ describe('MCPServersRegistry - Initialize Function', () => {
|
||||||
rawConfigs = yamlLoad(readFileSync(rawConfigsPath, 'utf8')) as t.MCPServers;
|
rawConfigs = yamlLoad(readFileSync(rawConfigsPath, 'utf8')) as t.MCPServers;
|
||||||
expectedParsedConfigs = yamlLoad(readFileSync(parsedConfigsPath, 'utf8')) as Record<
|
expectedParsedConfigs = yamlLoad(readFileSync(parsedConfigsPath, 'utf8')) as Record<
|
||||||
string,
|
string,
|
||||||
any
|
t.ParsedServerConfig
|
||||||
>;
|
>;
|
||||||
|
|
||||||
// Setup mock connections
|
// Setup mock connections
|
||||||
|
|
@ -57,12 +58,13 @@ describe('MCPServersRegistry - Initialize Function', () => {
|
||||||
const serverNames = Object.keys(rawConfigs);
|
const serverNames = Object.keys(rawConfigs);
|
||||||
|
|
||||||
serverNames.forEach((serverName) => {
|
serverNames.forEach((serverName) => {
|
||||||
|
const mockClient = {
|
||||||
|
listTools: jest.fn(),
|
||||||
|
getInstructions: jest.fn(),
|
||||||
|
getServerCapabilities: jest.fn(),
|
||||||
|
};
|
||||||
const mockConnection = {
|
const mockConnection = {
|
||||||
client: {
|
client: mockClient,
|
||||||
listTools: jest.fn(),
|
|
||||||
getInstructions: jest.fn(),
|
|
||||||
getServerCapabilities: jest.fn(),
|
|
||||||
},
|
|
||||||
} as unknown as jest.Mocked<MCPConnection>;
|
} as unknown as jest.Mocked<MCPConnection>;
|
||||||
|
|
||||||
// Setup mock responses based on expected configs
|
// Setup mock responses based on expected configs
|
||||||
|
|
@ -75,30 +77,32 @@ describe('MCPServersRegistry - Initialize Function', () => {
|
||||||
name,
|
name,
|
||||||
description: `Description for ${name}`,
|
description: `Description for ${name}`,
|
||||||
inputSchema: {
|
inputSchema: {
|
||||||
type: 'object',
|
type: 'object' as const,
|
||||||
properties: {
|
properties: {
|
||||||
input: { type: 'string' },
|
input: { type: 'string' },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
mockConnection.client.listTools.mockResolvedValue({ tools });
|
(mockClient.listTools as jest.Mock).mockResolvedValue({ tools });
|
||||||
} else {
|
} else {
|
||||||
mockConnection.client.listTools.mockResolvedValue({ tools: [] });
|
(mockClient.listTools as jest.Mock).mockResolvedValue({ tools: [] });
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mock getInstructions response
|
// Mock getInstructions response
|
||||||
if (expectedConfig.serverInstructions) {
|
if (expectedConfig.serverInstructions) {
|
||||||
mockConnection.client.getInstructions.mockReturnValue(expectedConfig.serverInstructions);
|
(mockClient.getInstructions as jest.Mock).mockReturnValue(
|
||||||
|
expectedConfig.serverInstructions as string,
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
mockConnection.client.getInstructions.mockReturnValue(null);
|
(mockClient.getInstructions as jest.Mock).mockReturnValue(undefined);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mock getServerCapabilities response
|
// Mock getServerCapabilities response
|
||||||
if (expectedConfig.capabilities) {
|
if (expectedConfig.capabilities) {
|
||||||
const capabilities = JSON.parse(expectedConfig.capabilities);
|
const capabilities = JSON.parse(expectedConfig.capabilities) as Record<string, unknown>;
|
||||||
mockConnection.client.getServerCapabilities.mockReturnValue(capabilities);
|
(mockClient.getServerCapabilities as jest.Mock).mockReturnValue(capabilities);
|
||||||
} else {
|
} else {
|
||||||
mockConnection.client.getServerCapabilities.mockReturnValue(null);
|
(mockClient.getServerCapabilities as jest.Mock).mockReturnValue(undefined);
|
||||||
}
|
}
|
||||||
|
|
||||||
mockConnections.set(serverName, mockConnection);
|
mockConnections.set(serverName, mockConnection);
|
||||||
|
|
@ -111,9 +115,13 @@ describe('MCPServersRegistry - Initialize Function', () => {
|
||||||
disconnectAll: jest.fn(),
|
disconnectAll: jest.fn(),
|
||||||
} as unknown as jest.Mocked<ConnectionsRepository>;
|
} as unknown as jest.Mocked<ConnectionsRepository>;
|
||||||
|
|
||||||
mockConnectionsRepo.get.mockImplementation((serverName: string) =>
|
mockConnectionsRepo.get.mockImplementation((serverName: string) => {
|
||||||
Promise.resolve(mockConnections.get(serverName)!),
|
const connection = mockConnections.get(serverName);
|
||||||
);
|
if (!connection) {
|
||||||
|
throw new Error(`Connection not found for server: ${serverName}`);
|
||||||
|
}
|
||||||
|
return Promise.resolve(connection);
|
||||||
|
});
|
||||||
|
|
||||||
mockConnectionsRepo.getLoaded.mockResolvedValue(mockConnections);
|
mockConnectionsRepo.getLoaded.mockResolvedValue(mockConnections);
|
||||||
|
|
||||||
|
|
@ -121,9 +129,10 @@ describe('MCPServersRegistry - Initialize Function', () => {
|
||||||
|
|
||||||
// Setup OAuth detection mock with deterministic results
|
// Setup OAuth detection mock with deterministic results
|
||||||
mockDetectOAuthRequirement.mockImplementation((url: string) => {
|
mockDetectOAuthRequirement.mockImplementation((url: string) => {
|
||||||
const oauthResults: Record<string, any> = {
|
const oauthResults: Record<string, OAuthDetectionResult> = {
|
||||||
'https://api.github.com/mcp': {
|
'https://api.github.com/mcp': {
|
||||||
requiresOAuth: true,
|
requiresOAuth: true,
|
||||||
|
method: 'protected-resource-metadata',
|
||||||
metadata: {
|
metadata: {
|
||||||
authorization_url: 'https://github.com/login/oauth/authorize',
|
authorization_url: 'https://github.com/login/oauth/authorize',
|
||||||
token_url: 'https://github.com/login/oauth/access_token',
|
token_url: 'https://github.com/login/oauth/access_token',
|
||||||
|
|
@ -131,15 +140,19 @@ describe('MCPServersRegistry - Initialize Function', () => {
|
||||||
},
|
},
|
||||||
'https://api.disabled.com/mcp': {
|
'https://api.disabled.com/mcp': {
|
||||||
requiresOAuth: false,
|
requiresOAuth: false,
|
||||||
|
method: 'no-metadata-found',
|
||||||
metadata: null,
|
metadata: null,
|
||||||
},
|
},
|
||||||
'https://api.public.com/mcp': {
|
'https://api.public.com/mcp': {
|
||||||
requiresOAuth: false,
|
requiresOAuth: false,
|
||||||
|
method: 'no-metadata-found',
|
||||||
metadata: null,
|
metadata: null,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
return Promise.resolve(oauthResults[url] || { requiresOAuth: false, metadata: null });
|
return Promise.resolve(
|
||||||
|
oauthResults[url] || { requiresOAuth: false, method: 'no-metadata-found', metadata: null },
|
||||||
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
// Clear all mocks
|
// Clear all mocks
|
||||||
|
|
|
||||||
|
|
@ -105,3 +105,11 @@ export type FormattedToolResponse = [
|
||||||
string | FormattedContent[],
|
string | FormattedContent[],
|
||||||
{ content: FormattedContent[] } | undefined,
|
{ content: FormattedContent[] } | undefined,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
export type ParsedServerConfig = MCPOptions & {
|
||||||
|
url?: string;
|
||||||
|
requiresOAuth?: boolean;
|
||||||
|
oauthMetadata?: Record<string, unknown> | null;
|
||||||
|
capabilities?: string;
|
||||||
|
tools?: string;
|
||||||
|
};
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ jest.mock('@librechat/data-schemas', () => ({
|
||||||
describe('ErrorController', () => {
|
describe('ErrorController', () => {
|
||||||
let mockReq: Request;
|
let mockReq: Request;
|
||||||
let mockRes: Response;
|
let mockRes: Response;
|
||||||
|
let mockNext: jest.Mock;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
mockReq = {
|
mockReq = {
|
||||||
|
|
@ -25,6 +26,7 @@ describe('ErrorController', () => {
|
||||||
send: jest.fn(),
|
send: jest.fn(),
|
||||||
} as unknown as Response;
|
} as unknown as Response;
|
||||||
(logger.error as jest.Mock).mockClear();
|
(logger.error as jest.Mock).mockClear();
|
||||||
|
mockNext = jest.fn();
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('ValidationError handling', () => {
|
describe('ValidationError handling', () => {
|
||||||
|
|
@ -37,7 +39,7 @@ describe('ErrorController', () => {
|
||||||
},
|
},
|
||||||
} as ValidationError;
|
} as ValidationError;
|
||||||
|
|
||||||
ErrorController(validationError, mockReq, mockRes);
|
ErrorController(validationError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||||
expect(mockRes.send).toHaveBeenCalledWith({
|
expect(mockRes.send).toHaveBeenCalledWith({
|
||||||
|
|
@ -57,7 +59,7 @@ describe('ErrorController', () => {
|
||||||
},
|
},
|
||||||
} as ValidationError;
|
} as ValidationError;
|
||||||
|
|
||||||
ErrorController(validationError, mockReq, mockRes);
|
ErrorController(validationError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||||
expect(mockRes.send).toHaveBeenCalledWith({
|
expect(mockRes.send).toHaveBeenCalledWith({
|
||||||
|
|
@ -73,7 +75,7 @@ describe('ErrorController', () => {
|
||||||
errors: {},
|
errors: {},
|
||||||
} as ValidationError;
|
} as ValidationError;
|
||||||
|
|
||||||
ErrorController(validationError, mockReq, mockRes);
|
ErrorController(validationError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||||
expect(mockRes.send).toHaveBeenCalledWith({
|
expect(mockRes.send).toHaveBeenCalledWith({
|
||||||
|
|
@ -94,7 +96,7 @@ describe('ErrorController', () => {
|
||||||
'E11000 duplicate key error collection: test.users index: email_1 dup key: { email: "test@example.com" }',
|
'E11000 duplicate key error collection: test.users index: email_1 dup key: { email: "test@example.com" }',
|
||||||
} as MongoServerError;
|
} as MongoServerError;
|
||||||
|
|
||||||
ErrorController(duplicateKeyError, mockReq, mockRes);
|
ErrorController(duplicateKeyError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
expect(mockRes.status).toHaveBeenCalledWith(409);
|
expect(mockRes.status).toHaveBeenCalledWith(409);
|
||||||
expect(mockRes.send).toHaveBeenCalledWith({
|
expect(mockRes.send).toHaveBeenCalledWith({
|
||||||
|
|
@ -116,7 +118,7 @@ describe('ErrorController', () => {
|
||||||
'E11000 duplicate key error collection: test.users index: email_1 dup key: { email: "test@example.com" }',
|
'E11000 duplicate key error collection: test.users index: email_1 dup key: { email: "test@example.com" }',
|
||||||
} as MongoServerError;
|
} as MongoServerError;
|
||||||
|
|
||||||
ErrorController(duplicateKeyError, mockReq, mockRes);
|
ErrorController(duplicateKeyError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
expect(mockRes.status).toHaveBeenCalledWith(409);
|
expect(mockRes.status).toHaveBeenCalledWith(409);
|
||||||
expect(mockRes.send).toHaveBeenCalledWith({
|
expect(mockRes.send).toHaveBeenCalledWith({
|
||||||
|
|
@ -138,7 +140,7 @@ describe('ErrorController', () => {
|
||||||
'E11000 duplicate key error collection: test.users index: email_1 dup key: { email: "test@example.com" }',
|
'E11000 duplicate key error collection: test.users index: email_1 dup key: { email: "test@example.com" }',
|
||||||
} as MongoServerError;
|
} as MongoServerError;
|
||||||
|
|
||||||
ErrorController(duplicateKeyError, mockReq, mockRes);
|
ErrorController(duplicateKeyError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
expect(mockRes.status).toHaveBeenCalledWith(409);
|
expect(mockRes.status).toHaveBeenCalledWith(409);
|
||||||
expect(mockRes.send).toHaveBeenCalledWith({
|
expect(mockRes.send).toHaveBeenCalledWith({
|
||||||
|
|
@ -155,7 +157,7 @@ describe('ErrorController', () => {
|
||||||
body: 'Invalid JSON syntax',
|
body: 'Invalid JSON syntax',
|
||||||
} as CustomError;
|
} as CustomError;
|
||||||
|
|
||||||
ErrorController(syntaxError, mockReq, mockRes);
|
ErrorController(syntaxError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||||
expect(mockRes.send).toHaveBeenCalledWith('Invalid JSON syntax');
|
expect(mockRes.send).toHaveBeenCalledWith('Invalid JSON syntax');
|
||||||
|
|
@ -167,7 +169,7 @@ describe('ErrorController', () => {
|
||||||
body: { error: 'Unprocessable entity' },
|
body: { error: 'Unprocessable entity' },
|
||||||
} as CustomError;
|
} as CustomError;
|
||||||
|
|
||||||
ErrorController(customError, mockReq, mockRes);
|
ErrorController(customError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
expect(mockRes.status).toHaveBeenCalledWith(422);
|
expect(mockRes.status).toHaveBeenCalledWith(422);
|
||||||
expect(mockRes.send).toHaveBeenCalledWith({ error: 'Unprocessable entity' });
|
expect(mockRes.send).toHaveBeenCalledWith({ error: 'Unprocessable entity' });
|
||||||
|
|
@ -178,7 +180,7 @@ describe('ErrorController', () => {
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
} as CustomError;
|
} as CustomError;
|
||||||
|
|
||||||
ErrorController(partialError, mockReq, mockRes);
|
ErrorController(partialError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||||
|
|
@ -189,7 +191,7 @@ describe('ErrorController', () => {
|
||||||
body: 'Some error message',
|
body: 'Some error message',
|
||||||
} as CustomError;
|
} as CustomError;
|
||||||
|
|
||||||
ErrorController(partialError, mockReq, mockRes);
|
ErrorController(partialError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||||
|
|
@ -200,7 +202,7 @@ describe('ErrorController', () => {
|
||||||
it('should handle unknown errors', () => {
|
it('should handle unknown errors', () => {
|
||||||
const unknownError = new Error('Some unknown error');
|
const unknownError = new Error('Some unknown error');
|
||||||
|
|
||||||
ErrorController(unknownError, mockReq, mockRes);
|
ErrorController(unknownError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||||
|
|
@ -213,7 +215,7 @@ describe('ErrorController', () => {
|
||||||
message: 'Some MongoDB error',
|
message: 'Some MongoDB error',
|
||||||
} as MongoServerError;
|
} as MongoServerError;
|
||||||
|
|
||||||
ErrorController(mongoError, mockReq, mockRes);
|
ErrorController(mongoError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||||
|
|
@ -223,7 +225,7 @@ describe('ErrorController', () => {
|
||||||
it('should handle generic errors', () => {
|
it('should handle generic errors', () => {
|
||||||
const genericError = new Error('Test error');
|
const genericError = new Error('Test error');
|
||||||
|
|
||||||
ErrorController(genericError, mockReq, mockRes);
|
ErrorController(genericError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||||
|
|
@ -254,7 +256,7 @@ describe('ErrorController', () => {
|
||||||
|
|
||||||
const testError = new Error('Test error');
|
const testError = new Error('Test error');
|
||||||
|
|
||||||
ErrorController(testError, mockReq, freshMockRes);
|
ErrorController(testError, mockReq, freshMockRes, mockNext);
|
||||||
|
|
||||||
expect(freshMockRes.status).toHaveBeenCalledWith(500);
|
expect(freshMockRes.status).toHaveBeenCalledWith(500);
|
||||||
expect(freshMockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
|
expect(freshMockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
|
||||||
|
|
|
||||||
2
packages/api/src/prompts/index.ts
Normal file
2
packages/api/src/prompts/index.ts
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
export * from './format';
|
||||||
|
export * from './migration';
|
||||||
223
packages/api/src/prompts/migration.ts
Normal file
223
packages/api/src/prompts/migration.ts
Normal file
|
|
@ -0,0 +1,223 @@
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
|
import { AccessRoleIds, ResourceType, PrincipalType, Constants } from 'librechat-data-provider';
|
||||||
|
import type { AccessRoleMethods, IPromptGroupDocument } from '@librechat/data-schemas';
|
||||||
|
import type { Model } from 'mongoose';
|
||||||
|
|
||||||
|
const { GLOBAL_PROJECT_NAME } = Constants;
|
||||||
|
|
||||||
|
export interface PromptMigrationCheckDbMethods {
|
||||||
|
findRoleByIdentifier: AccessRoleMethods['findRoleByIdentifier'];
|
||||||
|
getProjectByName: (
|
||||||
|
projectName: string,
|
||||||
|
fieldsToSelect?: string[] | null,
|
||||||
|
) => Promise<{
|
||||||
|
promptGroupIds?: string[];
|
||||||
|
[key: string]: unknown;
|
||||||
|
} | null>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PromptMigrationCheckParams {
|
||||||
|
db: PromptMigrationCheckDbMethods;
|
||||||
|
PromptGroupModel: Model<IPromptGroupDocument>;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface PromptGroupMigrationData {
|
||||||
|
_id: string;
|
||||||
|
name: string;
|
||||||
|
author: string;
|
||||||
|
authorName?: string;
|
||||||
|
category?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PromptMigrationCheckResult {
|
||||||
|
totalToMigrate: number;
|
||||||
|
globalViewAccess: number;
|
||||||
|
privateGroups: number;
|
||||||
|
details?: {
|
||||||
|
globalViewAccess: Array<{ name: string; _id: string; category: string }>;
|
||||||
|
privateGroups: Array<{ name: string; _id: string; category: string }>;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if prompt groups need to be migrated to the new permission system
|
||||||
|
* This performs a dry-run check similar to the migration script
|
||||||
|
*/
|
||||||
|
export async function checkPromptPermissionsMigration({
|
||||||
|
db,
|
||||||
|
PromptGroupModel,
|
||||||
|
}: PromptMigrationCheckParams): Promise<PromptMigrationCheckResult> {
|
||||||
|
logger.debug('Checking if prompt permissions migration is needed');
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Verify required roles exist
|
||||||
|
const ownerRole = await db.findRoleByIdentifier(AccessRoleIds.PROMPTGROUP_OWNER);
|
||||||
|
const viewerRole = await db.findRoleByIdentifier(AccessRoleIds.PROMPTGROUP_VIEWER);
|
||||||
|
const editorRole = await db.findRoleByIdentifier(AccessRoleIds.PROMPTGROUP_EDITOR);
|
||||||
|
|
||||||
|
if (!ownerRole || !viewerRole || !editorRole) {
|
||||||
|
logger.warn(
|
||||||
|
'Required promptGroup roles not found. Permission system may not be fully initialized.',
|
||||||
|
);
|
||||||
|
return {
|
||||||
|
totalToMigrate: 0,
|
||||||
|
globalViewAccess: 0,
|
||||||
|
privateGroups: 0,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get global project prompt group IDs
|
||||||
|
const globalProject = await db.getProjectByName(GLOBAL_PROJECT_NAME, ['promptGroupIds']);
|
||||||
|
const globalPromptGroupIds = new Set(
|
||||||
|
(globalProject?.promptGroupIds || []).map((id) => id.toString()),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Find promptGroups without ACL entries (no batching for efficiency on startup)
|
||||||
|
const promptGroupsToMigrate: PromptGroupMigrationData[] = await PromptGroupModel.aggregate([
|
||||||
|
{
|
||||||
|
$lookup: {
|
||||||
|
from: 'aclentries',
|
||||||
|
localField: '_id',
|
||||||
|
foreignField: 'resourceId',
|
||||||
|
as: 'aclEntries',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
$addFields: {
|
||||||
|
promptGroupAclEntries: {
|
||||||
|
$filter: {
|
||||||
|
input: '$aclEntries',
|
||||||
|
as: 'aclEntry',
|
||||||
|
cond: {
|
||||||
|
$and: [
|
||||||
|
{ $eq: ['$$aclEntry.resourceType', ResourceType.PROMPTGROUP] },
|
||||||
|
{ $eq: ['$$aclEntry.principalType', PrincipalType.USER] },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
$match: {
|
||||||
|
author: { $exists: true, $ne: null },
|
||||||
|
promptGroupAclEntries: { $size: 0 },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
$project: {
|
||||||
|
_id: 1,
|
||||||
|
name: 1,
|
||||||
|
author: 1,
|
||||||
|
authorName: 1,
|
||||||
|
category: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
const categories: {
|
||||||
|
globalViewAccess: PromptGroupMigrationData[];
|
||||||
|
privateGroups: PromptGroupMigrationData[];
|
||||||
|
} = {
|
||||||
|
globalViewAccess: [],
|
||||||
|
privateGroups: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
promptGroupsToMigrate.forEach((group) => {
|
||||||
|
const isGlobalGroup = globalPromptGroupIds.has(group._id.toString());
|
||||||
|
|
||||||
|
if (isGlobalGroup) {
|
||||||
|
categories.globalViewAccess.push(group);
|
||||||
|
} else {
|
||||||
|
categories.privateGroups.push(group);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
const result: PromptMigrationCheckResult = {
|
||||||
|
totalToMigrate: promptGroupsToMigrate.length,
|
||||||
|
globalViewAccess: categories.globalViewAccess.length,
|
||||||
|
privateGroups: categories.privateGroups.length,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add details for debugging
|
||||||
|
if (promptGroupsToMigrate.length > 0) {
|
||||||
|
result.details = {
|
||||||
|
globalViewAccess: categories.globalViewAccess.map((g) => ({
|
||||||
|
name: g.name,
|
||||||
|
_id: g._id.toString(),
|
||||||
|
category: g.category || 'uncategorized',
|
||||||
|
})),
|
||||||
|
privateGroups: categories.privateGroups.map((g) => ({
|
||||||
|
name: g.name,
|
||||||
|
_id: g._id.toString(),
|
||||||
|
category: g.category || 'uncategorized',
|
||||||
|
})),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug('Prompt migration check completed', {
|
||||||
|
totalToMigrate: result.totalToMigrate,
|
||||||
|
globalViewAccess: result.globalViewAccess,
|
||||||
|
privateGroups: result.privateGroups,
|
||||||
|
});
|
||||||
|
|
||||||
|
return result;
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to check prompt permissions migration', error);
|
||||||
|
// Return zero counts on error to avoid blocking startup
|
||||||
|
return {
|
||||||
|
totalToMigrate: 0,
|
||||||
|
globalViewAccess: 0,
|
||||||
|
privateGroups: 0,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Log migration warning to console if prompt groups need migration
|
||||||
|
*/
|
||||||
|
export function logPromptMigrationWarning(result: PromptMigrationCheckResult): void {
|
||||||
|
if (result.totalToMigrate === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a visible warning box
|
||||||
|
const border = '='.repeat(80);
|
||||||
|
const warning = [
|
||||||
|
'',
|
||||||
|
border,
|
||||||
|
' IMPORTANT: PROMPT PERMISSIONS MIGRATION REQUIRED',
|
||||||
|
border,
|
||||||
|
'',
|
||||||
|
` Total prompt groups to migrate: ${result.totalToMigrate}`,
|
||||||
|
` - Global View Access: ${result.globalViewAccess} prompt groups`,
|
||||||
|
` - Private Prompt Groups: ${result.privateGroups} prompt groups`,
|
||||||
|
'',
|
||||||
|
' The new prompt sharing system requires migrating existing prompt groups.',
|
||||||
|
' Please run the following command to migrate your prompts:',
|
||||||
|
'',
|
||||||
|
' npm run migrate:prompt-permissions',
|
||||||
|
'',
|
||||||
|
' For a dry run (preview) of what will be migrated:',
|
||||||
|
'',
|
||||||
|
' npm run migrate:prompt-permissions:dry-run',
|
||||||
|
'',
|
||||||
|
' This migration will:',
|
||||||
|
' 1. Grant owner permissions to prompt authors',
|
||||||
|
' 2. Set public view permissions for prompts in the global project',
|
||||||
|
' 3. Keep private prompts accessible only to their authors',
|
||||||
|
'',
|
||||||
|
border,
|
||||||
|
'',
|
||||||
|
];
|
||||||
|
|
||||||
|
// Use console methods directly for visibility
|
||||||
|
console.log('\n' + warning.join('\n') + '\n');
|
||||||
|
|
||||||
|
// Also log with logger for consistency
|
||||||
|
logger.warn('Prompt permissions migration required', {
|
||||||
|
totalToMigrate: result.totalToMigrate,
|
||||||
|
globalViewAccess: result.globalViewAccess,
|
||||||
|
privateGroups: result.privateGroups,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
@ -1,18 +1,19 @@
|
||||||
import { createUserMethods, type UserMethods } from './user';
|
|
||||||
import { createSessionMethods, type SessionMethods } from './session';
|
import { createSessionMethods, type SessionMethods } from './session';
|
||||||
import { createTokenMethods, type TokenMethods } from './token';
|
import { createTokenMethods, type TokenMethods } from './token';
|
||||||
import { createRoleMethods, type RoleMethods } from './role';
|
import { createRoleMethods, type RoleMethods } from './role';
|
||||||
|
import { createUserMethods, type UserMethods } from './user';
|
||||||
/* Memories */
|
/* Memories */
|
||||||
import { createMemoryMethods, type MemoryMethods } from './memory';
|
import { createMemoryMethods, type MemoryMethods } from './memory';
|
||||||
/* Agent Categories */
|
/* Agent Categories */
|
||||||
import { createAgentCategoryMethods, type AgentCategoryMethods } from './agentCategory';
|
import { createAgentCategoryMethods, type AgentCategoryMethods } from './agentCategory';
|
||||||
|
/* Plugin Auth */
|
||||||
|
import { createPluginAuthMethods, type PluginAuthMethods } from './pluginAuth';
|
||||||
/* Permissions */
|
/* Permissions */
|
||||||
import { createAccessRoleMethods, type AccessRoleMethods } from './accessRole';
|
import { createAccessRoleMethods, type AccessRoleMethods } from './accessRole';
|
||||||
import { createUserGroupMethods, type UserGroupMethods } from './userGroup';
|
import { createUserGroupMethods, type UserGroupMethods } from './userGroup';
|
||||||
import { createAclEntryMethods, type AclEntryMethods } from './aclEntry';
|
import { createAclEntryMethods, type AclEntryMethods } from './aclEntry';
|
||||||
import { createGroupMethods, type GroupMethods } from './group';
|
import { createGroupMethods, type GroupMethods } from './group';
|
||||||
import { createShareMethods, type ShareMethods } from './share';
|
import { createShareMethods, type ShareMethods } from './share';
|
||||||
import { createPluginAuthMethods, type PluginAuthMethods } from './pluginAuth';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates all database methods for all collections
|
* Creates all database methods for all collections
|
||||||
|
|
@ -34,16 +35,29 @@ export function createMethods(mongoose: typeof import('mongoose')) {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
export type { MemoryMethods, ShareMethods, TokenMethods, PluginAuthMethods };
|
export type {
|
||||||
|
UserMethods,
|
||||||
|
SessionMethods,
|
||||||
|
TokenMethods,
|
||||||
|
RoleMethods,
|
||||||
|
MemoryMethods,
|
||||||
|
AgentCategoryMethods,
|
||||||
|
UserGroupMethods,
|
||||||
|
AclEntryMethods,
|
||||||
|
GroupMethods,
|
||||||
|
ShareMethods,
|
||||||
|
AccessRoleMethods,
|
||||||
|
PluginAuthMethods,
|
||||||
|
};
|
||||||
export type AllMethods = UserMethods &
|
export type AllMethods = UserMethods &
|
||||||
SessionMethods &
|
SessionMethods &
|
||||||
TokenMethods &
|
TokenMethods &
|
||||||
RoleMethods &
|
RoleMethods &
|
||||||
MemoryMethods &
|
MemoryMethods &
|
||||||
AgentCategoryMethods &
|
AgentCategoryMethods &
|
||||||
AccessRoleMethods &
|
|
||||||
UserGroupMethods &
|
UserGroupMethods &
|
||||||
AclEntryMethods &
|
AclEntryMethods &
|
||||||
GroupMethods &
|
GroupMethods &
|
||||||
ShareMethods &
|
ShareMethods &
|
||||||
|
AccessRoleMethods &
|
||||||
PluginAuthMethods;
|
PluginAuthMethods;
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue