🧩 feat: Support Alternate API Keys for Plugins (#1760)

* refactor(DALL-E): retrieve env variables at runtime and not from memory

* feat(plugins): add alternate env variable handling to allow setting one api key for multiple plugins

* docs: update docs
This commit is contained in:
Danny Avila 2024-02-09 10:38:50 -05:00 committed by GitHub
parent 927ce5395b
commit 39caeb2027
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 328 additions and 113 deletions

View file

@ -30,6 +30,14 @@ const getOpenAIKey = async (options, user) => {
return openAIApiKey || (await getUserPluginAuthValue(user, 'OPENAI_API_KEY'));
};
/**
* Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values.
* Tools without required authentication or with valid authentication are considered valid.
*
* @param {Object} user The user object for whom to validate tool access.
* @param {Array<string>} tools An array of tool identifiers to validate. Defaults to an empty array.
* @returns {Promise<Array<string>>} A promise that resolves to an array of valid tool identifiers.
*/
const validateTools = async (user, tools = []) => {
try {
const validToolsSet = new Set(tools);
@ -37,16 +45,34 @@ const validateTools = async (user, tools = []) => {
validToolsSet.has(tool.pluginKey),
);
/**
* Validates the credentials for a given auth field or set of alternate auth fields for a tool.
* If valid admin or user authentication is found, the function returns early. Otherwise, it removes the tool from the set of valid tools.
*
* @param {string} authField The authentication field or fields (separated by "||" for alternates) to validate.
* @param {string} toolName The identifier of the tool being validated.
*/
const validateCredentials = async (authField, toolName) => {
const adminAuth = process.env[authField];
if (adminAuth && adminAuth.length > 0) {
return;
const fields = authField.split('||');
for (const field of fields) {
const adminAuth = process.env[field];
if (adminAuth && adminAuth.length > 0) {
return;
}
let userAuth = null;
try {
userAuth = await getUserPluginAuthValue(user, field);
} catch (err) {
if (field === fields[fields.length - 1] && !userAuth) {
throw err;
}
}
if (userAuth && userAuth.length > 0) {
return;
}
}
const userAuth = await getUserPluginAuthValue(user, authField);
if (userAuth && userAuth.length > 0) {
return;
}
validToolsSet.delete(toolName);
};
@ -63,20 +89,55 @@ const validateTools = async (user, tools = []) => {
return Array.from(validToolsSet.values());
} catch (err) {
logger.error('[validateTools] There was a problem validating tools', err);
throw new Error(err);
throw new Error('There was a problem validating tools');
}
};
const loadToolWithAuth = async (userId, authFields, ToolConstructor, options = {}) => {
/**
* Initializes a tool with authentication values for the given user, supporting alternate authentication fields.
* Authentication fields can have alternates separated by "||", and the first defined variable will be used.
*
* @param {string} userId The user ID for which the tool is being loaded.
* @param {Array<string>} authFields Array of strings representing the authentication fields. Supports alternate fields delimited by "||".
* @param {typeof import('langchain/tools').Tool} ToolConstructor The constructor function for the tool to be initialized.
* @param {Object} options Optional parameters to be passed to the tool constructor alongside authentication values.
* @returns {Function} An Async function that, when called, asynchronously initializes and returns an instance of the tool with authentication.
*/
const loadToolWithAuth = (userId, authFields, ToolConstructor, options = {}) => {
return async function () {
let authValues = {};
for (const authField of authFields) {
let authValue = process.env[authField];
if (!authValue) {
authValue = await getUserPluginAuthValue(userId, authField);
/**
* Finds the first non-empty value for the given authentication field, supporting alternate fields.
* @param {string[]} fields Array of strings representing the authentication fields. Supports alternate fields delimited by "||".
* @returns {Promise<{ authField: string, authValue: string} | null>} An object containing the authentication field and value, or null if not found.
*/
const findAuthValue = async (fields) => {
for (const field of fields) {
let value = process.env[field];
if (value) {
return { authField: field, authValue: value };
}
try {
value = await getUserPluginAuthValue(userId, field);
} catch (err) {
if (field === fields[fields.length - 1] && !value) {
throw err;
}
}
if (value) {
return { authField: field, authValue: value };
}
}
return null;
};
for (let authField of authFields) {
const fields = authField.split('||');
const result = await findAuthValue(fields);
if (result) {
authValues[result.authField] = result.authValue;
}
authValues[authField] = authValue;
}
return new ToolConstructor({ ...options, ...authValues, userId });
@ -194,7 +255,7 @@ const loadTools = async ({
if (toolConstructors[tool]) {
const options = toolOptions[tool] || {};
const toolInstance = await loadToolWithAuth(
const toolInstance = loadToolWithAuth(
user,
toolAuthFields[tool],
toolConstructors[tool],
@ -250,6 +311,7 @@ const loadTools = async ({
};
module.exports = {
loadToolWithAuth,
validateTools,
loadTools,
};

View file

@ -4,26 +4,33 @@ const mockUser = {
findByIdAndDelete: jest.fn(),
};
var mockPluginService = {
const mockPluginService = {
updateUserPluginAuth: jest.fn(),
deleteUserPluginAuth: jest.fn(),
getUserPluginAuthValue: jest.fn(),
};
jest.mock('../../../../models/User', () => {
jest.mock('~/models/User', () => {
return function () {
return mockUser;
};
});
jest.mock('../../../../server/services/PluginService', () => mockPluginService);
jest.mock('~/server/services/PluginService', () => mockPluginService);
const User = require('../../../../models/User');
const { validateTools, loadTools } = require('./');
const PluginService = require('../../../../server/services/PluginService');
const { BaseChatModel } = require('langchain/chat_models/openai');
const { Calculator } = require('langchain/tools/calculator');
const { availableTools, OpenAICreateImage, GoogleSearchAPI, StructuredSD } = require('../');
const { BaseChatModel } = require('langchain/chat_models/openai');
const User = require('~/models/User');
const PluginService = require('~/server/services/PluginService');
const { validateTools, loadTools, loadToolWithAuth } = require('./handleTools');
const {
availableTools,
OpenAICreateImage,
GoogleSearchAPI,
StructuredSD,
WolframAlphaAPI,
} = require('../');
describe('Tool Handlers', () => {
let fakeUser;
@ -44,7 +51,10 @@ describe('Tool Handlers', () => {
});
mockPluginService.updateUserPluginAuth.mockImplementation(
(userId, authField, _pluginKey, credential) => {
userAuthValues[`${userId}-${authField}`] = credential;
const fields = authField.split('||');
fields.forEach((field) => {
userAuthValues[`${userId}-${field}`] = credential;
});
},
);
@ -134,6 +144,18 @@ describe('Tool Handlers', () => {
loadTool2 = toolFunctions[sampleTools[1]];
loadTool3 = toolFunctions[sampleTools[2]];
});
let originalEnv;
beforeEach(() => {
originalEnv = process.env;
process.env = { ...originalEnv };
});
afterEach(() => {
process.env = originalEnv;
});
it('returns the expected load functions for requested tools', async () => {
expect(loadTool1).toBeDefined();
expect(loadTool2).toBeDefined();
@ -150,6 +172,86 @@ describe('Tool Handlers', () => {
expect(authTool).toBeInstanceOf(ToolClass);
expect(tool).toBeInstanceOf(ToolClass2);
});
it('should initialize an authenticated tool with primary auth field', async () => {
process.env.DALLE2_API_KEY = 'mocked_api_key';
const initToolFunction = loadToolWithAuth(
'userId',
['DALLE2_API_KEY||DALLE_API_KEY'],
ToolClass,
);
const authTool = await initToolFunction();
expect(authTool).toBeInstanceOf(ToolClass);
expect(mockPluginService.getUserPluginAuthValue).not.toHaveBeenCalled();
});
it('should initialize an authenticated tool with alternate auth field when primary is missing', async () => {
delete process.env.DALLE2_API_KEY; // Ensure the primary key is not set
process.env.DALLE_API_KEY = 'mocked_alternate_api_key';
const initToolFunction = loadToolWithAuth(
'userId',
['DALLE2_API_KEY||DALLE_API_KEY'],
ToolClass,
);
const authTool = await initToolFunction();
expect(authTool).toBeInstanceOf(ToolClass);
expect(mockPluginService.getUserPluginAuthValue).toHaveBeenCalledTimes(1);
expect(mockPluginService.getUserPluginAuthValue).toHaveBeenCalledWith(
'userId',
'DALLE2_API_KEY',
);
});
it('should fallback to getUserPluginAuthValue when env vars are missing', async () => {
mockPluginService.updateUserPluginAuth('userId', 'DALLE_API_KEY', 'dalle', 'mocked_api_key');
const initToolFunction = loadToolWithAuth(
'userId',
['DALLE2_API_KEY||DALLE_API_KEY'],
ToolClass,
);
const authTool = await initToolFunction();
expect(authTool).toBeInstanceOf(ToolClass);
expect(mockPluginService.getUserPluginAuthValue).toHaveBeenCalledTimes(2);
});
it('should initialize an authenticated tool with singular auth field', async () => {
process.env.WOLFRAM_APP_ID = 'mocked_app_id';
const initToolFunction = loadToolWithAuth('userId', ['WOLFRAM_APP_ID'], WolframAlphaAPI);
const authTool = await initToolFunction();
expect(authTool).toBeInstanceOf(WolframAlphaAPI);
expect(mockPluginService.getUserPluginAuthValue).not.toHaveBeenCalled();
});
it('should initialize an authenticated tool when env var is set', async () => {
process.env.WOLFRAM_APP_ID = 'mocked_app_id';
const initToolFunction = loadToolWithAuth('userId', ['WOLFRAM_APP_ID'], WolframAlphaAPI);
const authTool = await initToolFunction();
expect(authTool).toBeInstanceOf(WolframAlphaAPI);
expect(mockPluginService.getUserPluginAuthValue).not.toHaveBeenCalledWith(
'userId',
'WOLFRAM_APP_ID',
);
});
it('should fallback to getUserPluginAuthValue when singular env var is missing', async () => {
delete process.env.WOLFRAM_APP_ID; // Ensure the environment variable is not set
mockPluginService.getUserPluginAuthValue.mockResolvedValue('mocked_user_auth_value');
const initToolFunction = loadToolWithAuth('userId', ['WOLFRAM_APP_ID'], WolframAlphaAPI);
const authTool = await initToolFunction();
expect(authTool).toBeInstanceOf(WolframAlphaAPI);
expect(mockPluginService.getUserPluginAuthValue).toHaveBeenCalledTimes(1);
expect(mockPluginService.getUserPluginAuthValue).toHaveBeenCalledWith(
'userId',
'WOLFRAM_APP_ID',
);
});
it('should throw an error for an unauthenticated tool', async () => {
try {
await loadTool2();

View file

@ -1,17 +1,48 @@
const { getUserPluginAuthValue } = require('../../../../server/services/PluginService');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { availableTools } = require('../');
const loadToolSuite = async ({ pluginKey, tools, user, options }) => {
/**
* Loads a suite of tools with authentication values for a given user, supporting alternate authentication fields.
* Authentication fields can have alternates separated by "||", and the first defined variable will be used.
*
* @param {Object} params Parameters for loading the tool suite.
* @param {string} params.pluginKey Key identifying the plugin whose tools are to be loaded.
* @param {Array<Function>} params.tools Array of tool constructor functions.
* @param {Object} params.user User object for whom the tools are being loaded.
* @param {Object} [params.options={}] Optional parameters to be passed to each tool constructor.
* @returns {Promise<Array>} A promise that resolves to an array of instantiated tools.
*/
const loadToolSuite = async ({ pluginKey, tools, user, options = {} }) => {
const authConfig = availableTools.find((tool) => tool.pluginKey === pluginKey).authConfig;
const suite = [];
const authValues = {};
for (const auth of authConfig) {
let authValue = process.env[auth.authField];
if (!authValue) {
authValue = await getUserPluginAuthValue(user, auth.authField);
const findAuthValue = async (authField) => {
const fields = authField.split('||');
for (const field of fields) {
let value = process.env[field];
if (value) {
return value;
}
try {
value = await getUserPluginAuthValue(user, field);
if (value) {
return value;
}
} catch (err) {
console.error(`Error fetching plugin auth value for ${field}: ${err.message}`);
}
}
return null;
};
for (const auth of authConfig) {
const authValue = await findAuthValue(auth.authField);
if (authValue !== null) {
authValues[auth.authField] = authValue;
} else {
console.warn(`No auth value found for ${auth.authField}`);
}
authValues[auth.authField] = authValue;
}
for (const tool of tools) {