🔐 feat: Implement Allowed Action Domains (#4964)

* chore: RequestExecutor typing

* feat: allowed action domains

* fix: rename TAgentsEndpoint to TAssistantEndpoint in typedefs

* chore: update librechat-data-provider version to 0.7.62
This commit is contained in:
Danny Avila 2024-12-12 12:52:42 -05:00 committed by GitHub
parent e82af236bc
commit 69bd8e3644
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 364 additions and 97 deletions

View file

@ -1,4 +1,4 @@
const { isDomainAllowed } = require('~/server/services/AuthService');
const { isEmailDomainAllowed } = require('~/server/services/domains');
const { logger } = require('~/config');
/**
@ -14,7 +14,7 @@ const { logger } = require('~/config');
*/
const checkDomainAllowed = async (req, res, next = () => {}) => {
const email = req?.user?.email;
if (email && !(await isDomainAllowed(email))) {
if (email && !(await isEmailDomainAllowed(email))) {
logger.error(`[Social Login] [Social Login not allowed] [Email: ${email}]`);
return res.redirect('/login');
} else {

View file

@ -3,6 +3,7 @@ const { nanoid } = require('nanoid');
const { actionDelimiter } = require('librechat-data-provider');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { getAgent, updateAgent } = require('~/models/Agent');
const { logger } = require('~/config');
@ -42,6 +43,10 @@ router.post('/:agent_id', async (req, res) => {
}
let metadata = await encryptMetadata(_metadata);
const isDomainAllowed = await isActionDomainAllowed(metadata.domain);
if (!isDomainAllowed) {
return res.status(400).json({ message: 'Domain not allowed' });
}
let { domain } = metadata;
domain = await domainParser(req, domain, true);

View file

@ -1,10 +1,11 @@
const express = require('express');
const { nanoid } = require('nanoid');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { actionDelimiter, EModelEndpoint } = require('librechat-data-provider');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { updateAssistantDoc, getAssistant } = require('~/models/Assistant');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { logger } = require('~/config');
const router = express.Router();
@ -29,6 +30,10 @@ router.post('/:assistant_id', async (req, res) => {
}
let metadata = await encryptMetadata(_metadata);
const isDomainAllowed = await isActionDomainAllowed(metadata.domain);
if (!isDomainAllowed) {
return res.status(400).json({ message: 'Domain not allowed' });
}
let { domain } = metadata;
domain = await domainParser(req, domain, true);

View file

@ -7,6 +7,7 @@ const {
actionDomainSeparator,
} = require('librechat-data-provider');
const { tool } = require('@langchain/core/tools');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
const { getActions, deleteActions } = require('~/models/Action');
const { deleteAssistant } = require('~/models/Assistant');
@ -122,6 +123,10 @@ async function loadActionSets(searchParams) {
*/
async function createActionTool({ action, requestBuilder, zodSchema, name, description }) {
action.metadata = await decryptMetadata(action.metadata);
const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
if (!isDomainAllowed) {
return null;
}
/** @type {(toolInput: Object | string) => Promise<unknown>} */
const _call = async (toolInput) => {
try {

View file

@ -2,6 +2,9 @@ const { Constants, EModelEndpoint, actionDomainSeparator } = require('librechat-
const { domainParser } = require('./ActionService');
jest.mock('keyv');
jest.mock('~/server/services/Config', () => ({
getCustomConfig: jest.fn(),
}));
const globalCache = {};
jest.mock('~/cache/getLogStores', () => {

View file

@ -12,9 +12,9 @@ const {
} = require('~/models/userMethods');
const { createToken, findToken, deleteTokens, Session } = require('~/models');
const { isEnabled, checkEmailConfig, sendEmail } = require('~/server/utils');
const { isEmailDomainAllowed } = require('~/server/services/domains');
const { registerSchema } = require('~/strategies/validators');
const { hashToken } = require('~/server/utils/crypto');
const isDomainAllowed = require('./isDomainAllowed');
const { logger } = require('~/config');
const domains = {
@ -165,7 +165,7 @@ const registerUser = async (user, additionalData = {}) => {
return { status: 200, message: genericVerificationMessage };
}
if (!(await isDomainAllowed(email))) {
if (!(await isEmailDomainAllowed(email))) {
const errorMessage =
'The email address provided cannot be used. Please use a different email address.';
logger.error(`[registerUser] [Registration not allowed] [Email: ${user.email}]`);
@ -422,7 +422,6 @@ module.exports = {
registerUser,
setAuthTokens,
resetPassword,
isDomainAllowed,
requestPasswordReset,
resendVerificationEmail,
};

View file

@ -5,6 +5,7 @@ const { tool: toolFn, Tool } = require('@langchain/core/tools');
const { Calculator } = require('@langchain/community/tools/calculator');
const {
Tools,
ErrorTypes,
ContentTypes,
imageGenTools,
actionDelimiter,
@ -327,6 +328,12 @@ async function processRequiredActions(client, requiredActions) {
}
tool = await createActionTool({ action: actionSet, requestBuilder });
if (!tool) {
logger.warn(
`Invalid action: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id} | toolName: ${currentAction.tool}`,
);
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
}
isActionTool = !!tool;
ActionToolMap[currentAction.tool] = tool;
}
@ -464,6 +471,12 @@ async function loadAgentTools({ req, agent_id, tools, tool_resources, openAIApiK
name: toolName,
description: functionSig.description,
});
if (!tool) {
logger.warn(
`Invalid action: user: ${req.user.id} | agent_id: ${agent_id} | toolName: ${toolName}`,
);
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
}
agentTools.push(tool);
ActionToolMap[toolName] = tool;
}

View file

@ -0,0 +1,109 @@
const { getCustomConfig } = require('~/server/services/Config');
/**
* @param {string} email
* @returns {Promise<boolean>}
*/
async function isEmailDomainAllowed(email) {
if (!email) {
return false;
}
const domain = email.split('@')[1];
if (!domain) {
return false;
}
const customConfig = await getCustomConfig();
if (!customConfig) {
return true;
} else if (!customConfig?.registration?.allowedDomains) {
return true;
}
return customConfig.registration.allowedDomains.includes(domain);
}
/**
* Normalizes a domain string
* @param {string} domain
* @returns {string|null}
*/
/**
* Normalizes a domain string. If the domain is invalid, returns null.
* Normalized === lowercase, trimmed, and protocol added if missing.
* @param {string} domain
* @returns {string|null}
*/
function normalizeDomain(domain) {
try {
let normalizedDomain = domain.toLowerCase().trim();
// Early return for obviously invalid formats
if (normalizedDomain === 'http://' || normalizedDomain === 'https://') {
return null;
}
// If it's not already a URL, make it one
if (!normalizedDomain.startsWith('http://') && !normalizedDomain.startsWith('https://')) {
normalizedDomain = `https://${normalizedDomain}`;
}
const url = new URL(normalizedDomain);
// Additional validation that hostname isn't just protocol
if (!url.hostname || url.hostname === 'http:' || url.hostname === 'https:') {
return null;
}
return url.hostname.replace(/^www\./i, '');
} catch {
return null;
}
}
/**
* Checks if the given domain is allowed. If no restrictions are set, allows all domains.
* @param {string} [domain]
* @returns {Promise<boolean>}
*/
async function isActionDomainAllowed(domain) {
if (!domain || typeof domain !== 'string') {
return false;
}
const customConfig = await getCustomConfig();
const allowedDomains = customConfig?.actions?.allowedDomains;
if (!Array.isArray(allowedDomains) || !allowedDomains.length) {
return true;
}
const normalizedInputDomain = normalizeDomain(domain);
if (!normalizedInputDomain) {
return false;
}
for (const allowedDomain of allowedDomains) {
const normalizedAllowedDomain = normalizeDomain(allowedDomain);
if (!normalizedAllowedDomain) {
continue;
}
if (normalizedAllowedDomain.startsWith('*.')) {
const baseDomain = normalizedAllowedDomain.slice(2);
if (
normalizedInputDomain === baseDomain ||
normalizedInputDomain.endsWith(`.${baseDomain}`)
) {
return true;
}
} else if (normalizedInputDomain === normalizedAllowedDomain) {
return true;
}
}
return false;
}
module.exports = { isEmailDomainAllowed, isActionDomainAllowed };

View file

@ -0,0 +1,193 @@
const { isEmailDomainAllowed, isActionDomainAllowed } = require('~/server/services/domains');
const { getCustomConfig } = require('~/server/services/Config');
jest.mock('~/server/services/Config', () => ({
getCustomConfig: jest.fn(),
}));
describe('isEmailDomainAllowed', () => {
afterEach(() => {
jest.clearAllMocks();
});
it('should return false if email is falsy', async () => {
const email = '';
const result = await isEmailDomainAllowed(email);
expect(result).toBe(false);
});
it('should return false if domain is not present in the email', async () => {
const email = 'test';
const result = await isEmailDomainAllowed(email);
expect(result).toBe(false);
});
it('should return true if customConfig is not available', async () => {
const email = 'test@domain1.com';
getCustomConfig.mockResolvedValue(null);
const result = await isEmailDomainAllowed(email);
expect(result).toBe(true);
});
it('should return true if allowedDomains is not defined in customConfig', async () => {
const email = 'test@domain1.com';
getCustomConfig.mockResolvedValue({});
const result = await isEmailDomainAllowed(email);
expect(result).toBe(true);
});
it('should return true if domain is included in the allowedDomains', async () => {
const email = 'user@domain1.com';
getCustomConfig.mockResolvedValue({
registration: {
allowedDomains: ['domain1.com', 'domain2.com'],
},
});
const result = await isEmailDomainAllowed(email);
expect(result).toBe(true);
});
it('should return false if domain is not included in the allowedDomains', async () => {
const email = 'user@domain3.com';
getCustomConfig.mockResolvedValue({
registration: {
allowedDomains: ['domain1.com', 'domain2.com'],
},
});
const result = await isEmailDomainAllowed(email);
expect(result).toBe(false);
});
});
describe('isActionDomainAllowed', () => {
afterEach(() => {
jest.clearAllMocks();
});
// Basic Input Validation Tests
describe('input validation', () => {
it('should return false for falsy values', async () => {
expect(await isActionDomainAllowed()).toBe(false);
expect(await isActionDomainAllowed(null)).toBe(false);
expect(await isActionDomainAllowed('')).toBe(false);
expect(await isActionDomainAllowed(undefined)).toBe(false);
});
it('should return false for non-string inputs', async () => {
expect(await isActionDomainAllowed(123)).toBe(false);
expect(await isActionDomainAllowed({})).toBe(false);
expect(await isActionDomainAllowed([])).toBe(false);
});
it('should return false for invalid domain formats', async () => {
getCustomConfig.mockResolvedValue({
actions: { allowedDomains: ['http://', 'https://'] },
});
expect(await isActionDomainAllowed('http://')).toBe(false);
expect(await isActionDomainAllowed('https://')).toBe(false);
});
});
// Configuration Tests
describe('configuration handling', () => {
it('should return true if customConfig is null', async () => {
getCustomConfig.mockResolvedValue(null);
expect(await isActionDomainAllowed('example.com')).toBe(true);
});
it('should return true if actions.allowedDomains is not defined', async () => {
getCustomConfig.mockResolvedValue({});
expect(await isActionDomainAllowed('example.com')).toBe(true);
});
it('should return true if allowedDomains is empty array', async () => {
getCustomConfig.mockResolvedValue({
actions: { allowedDomains: [] },
});
expect(await isActionDomainAllowed('example.com')).toBe(true);
});
});
// Domain Matching Tests
describe('domain matching', () => {
beforeEach(() => {
getCustomConfig.mockResolvedValue({
actions: {
allowedDomains: [
'example.com',
'*.subdomain.com',
'specific.domain.com',
'www.withprefix.com',
'swapi.dev',
],
},
});
});
it('should match exact domains', async () => {
expect(await isActionDomainAllowed('example.com')).toBe(true);
expect(await isActionDomainAllowed('other.com')).toBe(false);
expect(await isActionDomainAllowed('swapi.dev')).toBe(true);
});
it('should handle domains with www prefix', async () => {
expect(await isActionDomainAllowed('www.example.com')).toBe(true);
expect(await isActionDomainAllowed('www.withprefix.com')).toBe(true);
});
it('should handle full URLs', async () => {
expect(await isActionDomainAllowed('https://example.com')).toBe(true);
expect(await isActionDomainAllowed('http://example.com')).toBe(true);
expect(await isActionDomainAllowed('https://example.com/path')).toBe(true);
});
it('should handle wildcard subdomains', async () => {
expect(await isActionDomainAllowed('test.subdomain.com')).toBe(true);
expect(await isActionDomainAllowed('any.subdomain.com')).toBe(true);
expect(await isActionDomainAllowed('subdomain.com')).toBe(true);
});
it('should handle specific subdomains', async () => {
expect(await isActionDomainAllowed('specific.domain.com')).toBe(true);
expect(await isActionDomainAllowed('other.domain.com')).toBe(false);
});
});
// Edge Cases
describe('edge cases', () => {
beforeEach(() => {
getCustomConfig.mockResolvedValue({
actions: {
allowedDomains: ['example.com', '*.test.com'],
},
});
});
it('should handle domains with query parameters', async () => {
expect(await isActionDomainAllowed('example.com?param=value')).toBe(true);
});
it('should handle domains with ports', async () => {
expect(await isActionDomainAllowed('example.com:8080')).toBe(true);
});
it('should handle domains with trailing slashes', async () => {
expect(await isActionDomainAllowed('example.com/')).toBe(true);
});
it('should handle case insensitivity', async () => {
expect(await isActionDomainAllowed('EXAMPLE.COM')).toBe(true);
expect(await isActionDomainAllowed('Example.Com')).toBe(true);
});
it('should handle invalid entries in allowedDomains', async () => {
getCustomConfig.mockResolvedValue({
actions: {
allowedDomains: ['example.com', null, undefined, '', 'test.com'],
},
});
expect(await isActionDomainAllowed('example.com')).toBe(true);
expect(await isActionDomainAllowed('test.com')).toBe(true);
});
});
});

View file

@ -1,24 +0,0 @@
const { getCustomConfig } = require('~/server/services/Config');
async function isDomainAllowed(email) {
if (!email) {
return false;
}
const domain = email.split('@')[1];
if (!domain) {
return false;
}
const customConfig = await getCustomConfig();
if (!customConfig) {
return true;
} else if (!customConfig?.registration?.allowedDomains) {
return true;
}
return customConfig.registration.allowedDomains.includes(domain);
}
module.exports = isDomainAllowed;

View file

@ -1,60 +0,0 @@
const { getCustomConfig } = require('~/server/services/Config');
const isDomainAllowed = require('./isDomainAllowed');
jest.mock('~/server/services/Config', () => ({
getCustomConfig: jest.fn(),
}));
describe('isDomainAllowed', () => {
afterEach(() => {
jest.clearAllMocks();
});
it('should return false if email is falsy', async () => {
const email = '';
const result = await isDomainAllowed(email);
expect(result).toBe(false);
});
it('should return false if domain is not present in the email', async () => {
const email = 'test';
const result = await isDomainAllowed(email);
expect(result).toBe(false);
});
it('should return true if customConfig is not available', async () => {
const email = 'test@domain1.com';
getCustomConfig.mockResolvedValue(null);
const result = await isDomainAllowed(email);
expect(result).toBe(true);
});
it('should return true if allowedDomains is not defined in customConfig', async () => {
const email = 'test@domain1.com';
getCustomConfig.mockResolvedValue({});
const result = await isDomainAllowed(email);
expect(result).toBe(true);
});
it('should return true if domain is included in the allowedDomains', async () => {
const email = 'user@domain1.com';
getCustomConfig.mockResolvedValue({
registration: {
allowedDomains: ['domain1.com', 'domain2.com'],
},
});
const result = await isDomainAllowed(email);
expect(result).toBe(true);
});
it('should return false if domain is not included in the allowedDomains', async () => {
const email = 'user@domain3.com';
getCustomConfig.mockResolvedValue({
registration: {
allowedDomains: ['domain1.com', 'domain2.com'],
},
});
const result = await isDomainAllowed(email);
expect(result).toBe(false);
});
});

View file

@ -819,7 +819,7 @@
*/
/**
* @exports TAgentsEndpoint
* @exports TAssistantEndpoint
* @typedef {import('librechat-data-provider').TAssistantEndpoint} TAssistantEndpoint
* @memberof typedefs
*/

View file

@ -42,6 +42,7 @@ const errorMessages = {
[ErrorTypes.NO_USER_KEY]: 'com_error_no_user_key',
[ErrorTypes.INVALID_USER_KEY]: 'com_error_invalid_user_key',
[ErrorTypes.NO_BASE_URL]: 'com_error_no_base_url',
[ErrorTypes.INVALID_ACTION]: `com_error_${ErrorTypes.INVALID_ACTION}`,
[ErrorTypes.INVALID_REQUEST]: `com_error_${ErrorTypes.INVALID_REQUEST}`,
[ErrorTypes.NO_SYSTEM_MESSAGES]: `com_error_${ErrorTypes.NO_SYSTEM_MESSAGES}`,
[ErrorTypes.EXPIRED_USER_KEY]: (json: TExpiredKey, localize: LocalizeFunction) => {

View file

@ -30,6 +30,7 @@ export default {
'Resubmitting the AI message is not supported for this endpoint.',
com_error_invalid_request_error:
'The AI service rejected the request due to an error. This could be caused by an invalid API key or an improperly formatted request.',
com_error_invalid_action_error: 'Request denied: The specified action domain is not allowed.',
com_error_no_system_messages:
'The selected AI service or model does not support system messages. Try using prompts instead of custom instructions.',
com_error_invalid_user_key: 'Invalid key provided. Please provide a valid key and try again.',

2
package-lock.json generated
View file

@ -36153,7 +36153,7 @@
},
"packages/data-provider": {
"name": "librechat-data-provider",
"version": "0.7.61",
"version": "0.7.62",
"license": "ISC",
"dependencies": {
"@types/js-yaml": "^4.0.9",

View file

@ -1,6 +1,6 @@
{
"name": "librechat-data-provider",
"version": "0.7.61",
"version": "0.7.62",
"description": "data services for librechat apps",
"main": "dist/index.js",
"module": "dist/index.es.js",

View file

@ -201,15 +201,21 @@ class RequestExecutor {
oauth_client_secret,
} = metadata;
const isApiKey = api_key && type === AuthTypeEnum.ServiceHttp;
const isOAuth =
const isApiKey = api_key != null && api_key.length > 0 && type === AuthTypeEnum.ServiceHttp;
const isOAuth = !!(
oauth_client_id != null &&
oauth_client_id &&
oauth_client_secret != null &&
oauth_client_secret &&
type === AuthTypeEnum.OAuth &&
authorization_url != null &&
authorization_url &&
client_url != null &&
client_url &&
scope != null &&
scope &&
token_exchange_method;
token_exchange_method
);
if (isApiKey && authorization_type === AuthorizationTypeEnum.Basic) {
const basicToken = Buffer.from(api_key).toString('base64');
@ -219,11 +225,13 @@ class RequestExecutor {
} else if (
isApiKey &&
authorization_type === AuthorizationTypeEnum.Custom &&
custom_auth_header != null &&
custom_auth_header
) {
this.authHeaders[custom_auth_header] = api_key;
} else if (isOAuth) {
if (!this.authToken) {
const authToken = this.authToken ?? '';
if (!authToken) {
const tokenResponse = await axios.post(
client_url,
{

View file

@ -471,6 +471,11 @@ export const configSchema = z.object({
agents: true,
}),
fileStrategy: fileSourceSchema.default(FileSources.local),
actions: z
.object({
allowedDomains: z.array(z.string()).optional(),
})
.optional(),
registration: z
.object({
socialLogins: z.array(z.string()).optional(),
@ -962,6 +967,10 @@ export enum ErrorTypes {
* Invalid request error, API rejected request
*/
INVALID_REQUEST = 'invalid_request_error',
/**
* Invalid action request error, likely not on list of allowed domains
*/
INVALID_ACTION = 'invalid_action_error',
/**
* Invalid request error, API rejected request
*/