refactor(tools): restructure tool dir

This commit is contained in:
Daniel Avila 2023-06-14 13:05:33 -04:00 committed by Danny Avila
parent 71d812403e
commit 1b3215c55d
5 changed files with 48 additions and 26 deletions

View file

@ -12,7 +12,8 @@ const { CallbackManager } = require('langchain/callbacks');
const { HumanChatMessage, AIChatMessage } = require('langchain/schema'); const { HumanChatMessage, AIChatMessage } = require('langchain/schema');
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents/'); const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents/');
const { getMessages, saveMessage, saveConvo } = require('../../models'); const { getMessages, saveMessage, saveConvo } = require('../../models');
const { loadTools, SelfReflectionTool } = require('./tools'); const { loadTools } = require('./tools/util');
const { SelfReflectionTool } = require('./tools/');
const { const {
instructions, instructions,
imageInstructions, imageInstructions,

View file

@ -1,10 +1,19 @@
const GoogleSearchAPI = require('./GoogleSearch');
const HttpRequestTool = require('./HttpRequestTool');
const AIPluginTool = require('./AIPluginTool');
const OpenAICreateImage = require('./DALL-E');
const StructuredSD = require('./structured/StableDiffusion');
const StableDiffusionAPI = require('./StableDiffusion');
const WolframAlphaAPI = require('./Wolfram');
const SelfReflectionTool = require('./SelfReflection'); const SelfReflectionTool = require('./SelfReflection');
const availableTools = require('./manifest.json');
const { validateTools, loadTools } = require('./handleTools');
module.exports = { module.exports = {
validateTools, GoogleSearchAPI,
loadTools, HttpRequestTool,
availableTools, AIPluginTool,
OpenAICreateImage,
StructuredSD,
StableDiffusionAPI,
WolframAlphaAPI,
SelfReflectionTool SelfReflectionTool
}; }

View file

@ -1,3 +1,4 @@
const { getUserPluginAuthValue } = require('../../../../server/services/PluginService');
const { OpenAIEmbeddings } = require('langchain/embeddings/openai'); const { OpenAIEmbeddings } = require('langchain/embeddings/openai');
const { ZapierToolKit } = require('langchain/agents'); const { ZapierToolKit } = require('langchain/agents');
const { const {
@ -7,14 +8,16 @@ const {
const { ChatOpenAI } = require('langchain/chat_models/openai'); const { ChatOpenAI } = require('langchain/chat_models/openai');
const { Calculator } = require('langchain/tools/calculator'); const { Calculator } = require('langchain/tools/calculator');
const { WebBrowser } = require('langchain/tools/webbrowser'); const { WebBrowser } = require('langchain/tools/webbrowser');
const GoogleSearchAPI = require('./GoogleSearch'); const {
const HttpRequestTool = require('./HttpRequestTool'); AIPluginTool,
const AIPluginTool = require('./AIPluginTool'); GoogleSearchAPI,
const OpenAICreateImage = require('./DALL-E'); WolframAlphaAPI,
const StableDiffusionAPI = require('./StableDiffusion'); HttpRequestTool,
const WolframAlphaAPI = require('./Wolfram'); OpenAICreateImage,
const availableTools = require('./manifest.json'); StableDiffusionAPI,
const { getUserPluginAuthValue } = require('../../../server/services/PluginService'); StructuredSD,
} = require('../');
const availableTools = require('../manifest.json');
const validateTools = async (user, tools = []) => { const validateTools = async (user, tools = []) => {
try { try {
@ -70,12 +73,13 @@ const loadToolWithAuth = async (user, authFields, ToolConstructor, options = {})
}; };
const loadTools = async ({ user, model, tools = [], options = {} }) => { const loadTools = async ({ user, model, tools = [], options = {} }) => {
const { functions } = options;
const toolConstructors = { const toolConstructors = {
calculator: Calculator, calculator: Calculator,
google: GoogleSearchAPI, google: GoogleSearchAPI,
wolfram: WolframAlphaAPI, wolfram: WolframAlphaAPI,
'dall-e': OpenAICreateImage, 'dall-e': OpenAICreateImage,
'stable-diffusion': StableDiffusionAPI 'stable-diffusion': functions ? StructuredSD : StableDiffusionAPI
}; };
const customConstructors = { const customConstructors = {
@ -109,9 +113,10 @@ const loadTools = async ({ user, model, tools = [], options = {} }) => {
return [ return [
new HttpRequestTool(), new HttpRequestTool(),
await AIPluginTool.fromPluginUrl( await AIPluginTool.fromPluginUrl(
"https://www.klarna.com/.well-known/ai-plugin.json", new ChatOpenAI({ openAIApiKey: options.openAIApiKey, temperature: 0 }) 'https://www.klarna.com/.well-known/ai-plugin.json',
), new ChatOpenAI({ openAIApiKey: options.openAIApiKey, temperature: 0 })
] )
];
} }
}; };

View file

@ -11,21 +11,20 @@ var mockPluginService = {
}; };
jest.mock('../../../models/User', () => { jest.mock('../../../../models/User', () => {
return function() { return function() {
return mockUser; return mockUser;
}; };
}); });
jest.mock('../../../server/services/PluginService', () => mockPluginService); jest.mock('../../../../server/services/PluginService', () => mockPluginService);
const User = require('../../../models/User'); const User = require('../../../../models/User');
const { validateTools, loadTools, availableTools } = require('./index'); const { validateTools, loadTools, availableTools } = require('./');
const PluginService = require('../../../server/services/PluginService'); const PluginService = require('../../../../server/services/PluginService');
const { BaseChatModel } = require('langchain/chat_models/openai'); const { BaseChatModel } = require('langchain/chat_models/openai');
const { Calculator } = require('langchain/tools/calculator'); const { Calculator } = require('langchain/tools/calculator');
const OpenAICreateImage = require('./DALL-E'); const { OpenAICreateImage, GoogleSearchAPI } = require('../');
const GoogleSearchAPI = require('./GoogleSearch');
describe('Tool Handlers', () => { describe('Tool Handlers', () => {
let fakeUser; let fakeUser;

View file

@ -0,0 +1,8 @@
const availableTools = require('../manifest.json');
const { validateTools, loadTools } = require('./handleTools');
module.exports = {
validateTools,
loadTools,
availableTools
};