From 980262984837d2d09a98f4bf1d8c628a5f4dea69 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Thu, 27 Feb 2025 12:59:51 -0500 Subject: [PATCH 01/27] =?UTF-8?q?=F0=9F=9A=80=20feat:=20Agent=20Cache=20To?= =?UTF-8?q?kens=20&=20Anthropic=20Reasoning=20Support=20(#6098)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: handling of top_k and top_p parameters for Claude-3.7 models (allowed without reasoning) * feat: bump @librechat/agents for Anthropic Reasoning support * fix: update reasoning handling for OpenRouter integration * fix: enhance agent token spending logic to include cache creation and read details * fix: update logic for thinking status in ContentParts component * refactor: improve agent title handling * chore: bump @librechat/agents to version 2.1.7 for parallel tool calling for Google models --- api/app/clients/AnthropicClient.js | 17 +++-- api/app/clients/OpenAIClient.js | 6 ++ api/app/clients/specs/AnthropicClient.test.js | 49 +++++++++++++++ api/package.json | 2 +- api/server/controllers/agents/client.js | 62 ++++++++++++++----- api/server/services/Endpoints/agents/title.js | 11 +++- .../services/Endpoints/anthropic/llm.js | 15 ++--- .../services/Endpoints/anthropic/llm.spec.js | 41 ++++++++++++ api/server/services/Endpoints/openAI/llm.js | 12 +++- .../Chat/Messages/Content/ContentParts.tsx | 4 +- package-lock.json | 8 +-- 11 files changed, 187 insertions(+), 40 deletions(-) diff --git a/api/app/clients/AnthropicClient.js b/api/app/clients/AnthropicClient.js index a2ab752bc2..19f4a3930a 100644 --- a/api/app/clients/AnthropicClient.js +++ b/api/app/clients/AnthropicClient.js @@ -746,15 +746,6 @@ class AnthropicClient extends BaseClient { metadata, }; - if (!/claude-3[-.]7/.test(model)) { - if (top_p !== undefined) { - requestOptions.top_p = top_p; - } - if (top_k !== undefined) { - requestOptions.top_k = top_k; - } - } - if (this.useMessages) { requestOptions.messages = payload; requestOptions.max_tokens = @@ -769,6 +760,14 @@ class AnthropicClient extends BaseClient { thinkingBudget: this.options.thinkingBudget, }); + if (!/claude-3[-.]7/.test(model)) { + requestOptions.top_p = top_p; + requestOptions.top_k = top_k; + } else if (requestOptions.thinking == null) { + requestOptions.topP = top_p; + requestOptions.topK = top_k; + } + if (this.systemMessage && this.supportsCacheControl === true) { requestOptions.system = [ { diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 8d0bce25d2..4bc2d66ca0 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -1309,6 +1309,12 @@ ${convo} modelOptions.include_reasoning = true; reasoningKey = 'reasoning'; } + if (this.useOpenRouter && modelOptions.reasoning_effort != null) { + modelOptions.reasoning = { + effort: modelOptions.reasoning_effort, + }; + delete modelOptions.reasoning_effort; + } this.streamHandler = new SplitStreamHandler({ reasoningKey, diff --git a/api/app/clients/specs/AnthropicClient.test.js b/api/app/clients/specs/AnthropicClient.test.js index b565e6d188..223f3038c0 100644 --- a/api/app/clients/specs/AnthropicClient.test.js +++ b/api/app/clients/specs/AnthropicClient.test.js @@ -680,4 +680,53 @@ describe('AnthropicClient', () => { expect(capturedOptions).not.toHaveProperty('top_p'); }); }); + + it('should include top_k and top_p parameters for Claude-3.7 models when thinking is explicitly disabled', async () => { + const client = new AnthropicClient('test-api-key', { + modelOptions: { + model: 'claude-3-7-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + thinking: false, + }); + + async function* mockAsyncGenerator() { + yield { type: 'message_start', message: { usage: {} } }; + yield { delta: { text: 'Test response' } }; + yield { type: 'message_delta', usage: {} }; + } + + jest.spyOn(client, 'createResponse').mockImplementation(() => { + return mockAsyncGenerator(); + }); + + let capturedOptions = null; + jest.spyOn(client, 'getClient').mockImplementation((options) => { + capturedOptions = options; + return {}; + }); + + const payload = [{ role: 'user', content: 'Test message' }]; + await client.sendCompletion(payload, {}); + + expect(capturedOptions).toHaveProperty('topK', 10); + expect(capturedOptions).toHaveProperty('topP', 0.9); + + client.setOptions({ + modelOptions: { + model: 'claude-3.7-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + thinking: false, + }); + + await client.sendCompletion(payload, {}); + + expect(capturedOptions).toHaveProperty('topK', 10); + expect(capturedOptions).toHaveProperty('topP', 0.9); + }); }); diff --git a/api/package.json b/api/package.json index e386394acb..2d83fddcbd 100644 --- a/api/package.json +++ b/api/package.json @@ -45,7 +45,7 @@ "@langchain/google-genai": "^0.1.9", "@langchain/google-vertexai": "^0.2.0", "@langchain/textsplitters": "^0.1.0", - "@librechat/agents": "^2.1.3", + "@librechat/agents": "^2.1.7", "@waylaidwanderer/fetch-event-source": "^3.0.1", "axios": "1.7.8", "bcryptjs": "^2.4.3", diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index bff7dc65eb..99d64bb9a6 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -27,10 +27,10 @@ const { formatContentStrings, createContextHandlers, } = require('~/app/clients/prompts'); -const { encodeAndFormat } = require('~/server/services/Files/images/encode'); +const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { getBufferString, HumanMessage } = require('@langchain/core/messages'); +const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const Tokenizer = require('~/server/services/Tokenizer'); -const { spendTokens } = require('~/models/spendTokens'); const BaseClient = require('~/app/clients/BaseClient'); const { createRun } = require('./run'); const { logger } = require('~/config'); @@ -380,15 +380,34 @@ class AgentClient extends BaseClient { if (!collectedUsage || !collectedUsage.length) { return; } - const input_tokens = collectedUsage[0]?.input_tokens || 0; + const input_tokens = + (collectedUsage[0]?.input_tokens || 0) + + (Number(collectedUsage[0]?.input_token_details?.cache_creation) || 0) + + (Number(collectedUsage[0]?.input_token_details?.cache_read) || 0); let output_tokens = 0; let previousTokens = input_tokens; // Start with original input for (let i = 0; i < collectedUsage.length; i++) { const usage = collectedUsage[i]; + if (!usage) { + continue; + } + + const cache_creation = Number(usage.input_token_details?.cache_creation) || 0; + const cache_read = Number(usage.input_token_details?.cache_read) || 0; + + const txMetadata = { + context, + conversationId: this.conversationId, + user: this.user ?? this.options.req.user?.id, + endpointTokenConfig: this.options.endpointTokenConfig, + model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model, + }; + if (i > 0) { // Count new tokens generated (input_tokens minus previous accumulated tokens) - output_tokens += (Number(usage.input_tokens) || 0) - previousTokens; + output_tokens += + (Number(usage.input_tokens) || 0) + cache_creation + cache_read - previousTokens; } // Add this message's output tokens @@ -396,16 +415,26 @@ class AgentClient extends BaseClient { // Update previousTokens to include this message's output previousTokens += Number(usage.output_tokens) || 0; - spendTokens( - { - context, - conversationId: this.conversationId, - user: this.user ?? this.options.req.user?.id, - endpointTokenConfig: this.options.endpointTokenConfig, - model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model, - }, - { promptTokens: usage.input_tokens, completionTokens: usage.output_tokens }, - ).catch((err) => { + + if (cache_creation > 0 || cache_read > 0) { + spendStructuredTokens(txMetadata, { + promptTokens: { + input: usage.input_tokens, + write: cache_creation, + read: cache_read, + }, + completionTokens: usage.output_tokens, + }).catch((err) => { + logger.error( + '[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending structured tokens', + err, + ); + }); + } + spendTokens(txMetadata, { + promptTokens: usage.input_tokens, + completionTokens: usage.output_tokens, + }).catch((err) => { logger.error( '[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending tokens', err, @@ -792,7 +821,10 @@ class AgentClient extends BaseClient { throw new Error('Run not initialized'); } const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator(); - const clientOptions = {}; + /** @type {import('@librechat/agents').ClientOptions} */ + const clientOptions = { + maxTokens: 75, + }; const providerConfig = this.options.req.app.locals[this.options.agent.provider]; if ( providerConfig && diff --git a/api/server/services/Endpoints/agents/title.js b/api/server/services/Endpoints/agents/title.js index 56fd28668d..f25746582e 100644 --- a/api/server/services/Endpoints/agents/title.js +++ b/api/server/services/Endpoints/agents/title.js @@ -20,10 +20,19 @@ const addTitle = async (req, { text, response, client }) => { const titleCache = getLogStores(CacheKeys.GEN_TITLE); const key = `${req.user.id}-${response.conversationId}`; + const responseText = + response?.content && Array.isArray(response?.content) + ? response.content.reduce((acc, block) => { + if (block?.type === 'text') { + return acc + block.text; + } + return acc; + }, '') + : (response?.content ?? response?.text ?? ''); const title = await client.titleConvo({ text, - responseText: response?.text ?? '', + responseText, conversationId: response.conversationId, }); await titleCache.set(key, title, 120000); diff --git a/api/server/services/Endpoints/anthropic/llm.js b/api/server/services/Endpoints/anthropic/llm.js index 186444cec8..9f20b8e61d 100644 --- a/api/server/services/Endpoints/anthropic/llm.js +++ b/api/server/services/Endpoints/anthropic/llm.js @@ -1,6 +1,6 @@ const { HttpsProxyAgent } = require('https-proxy-agent'); const { anthropicSettings, removeNullishValues } = require('librechat-data-provider'); -const { checkPromptCacheSupport, getClaudeHeaders } = require('./helpers'); +const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers'); /** * Generates configuration options for creating an Anthropic language model (LLM) instance. @@ -49,13 +49,14 @@ function getLLMConfig(apiKey, options = {}) { clientOptions: {}, }; + requestOptions = configureReasoning(requestOptions, systemOptions); + if (!/claude-3[-.]7/.test(mergedOptions.model)) { - if (mergedOptions.topP !== undefined) { - requestOptions.topP = mergedOptions.topP; - } - if (mergedOptions.topK !== undefined) { - requestOptions.topK = mergedOptions.topK; - } + requestOptions.topP = mergedOptions.topP; + requestOptions.topK = mergedOptions.topK; + } else if (requestOptions.thinking == null) { + requestOptions.topP = mergedOptions.topP; + requestOptions.topK = mergedOptions.topK; } const supportsCacheControl = diff --git a/api/server/services/Endpoints/anthropic/llm.spec.js b/api/server/services/Endpoints/anthropic/llm.spec.js index a1dc6a44b6..9c453efb92 100644 --- a/api/server/services/Endpoints/anthropic/llm.spec.js +++ b/api/server/services/Endpoints/anthropic/llm.spec.js @@ -109,4 +109,45 @@ describe('getLLMConfig', () => { // Just verifying that the promptCache setting is processed expect(result.llmConfig).toBeDefined(); }); + + it('should include topK and topP for Claude-3.7 models when thinking is not enabled', () => { + // Test with thinking explicitly set to null/undefined + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-7-sonnet', + topK: 10, + topP: 0.9, + thinking: false, + }, + }); + + expect(result.llmConfig).toHaveProperty('topK', 10); + expect(result.llmConfig).toHaveProperty('topP', 0.9); + + // Test with thinking explicitly set to false + const result2 = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-7-sonnet', + topK: 10, + topP: 0.9, + thinking: false, + }, + }); + + expect(result2.llmConfig).toHaveProperty('topK', 10); + expect(result2.llmConfig).toHaveProperty('topP', 0.9); + + // Test with decimal notation as well + const result3 = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3.7-sonnet', + topK: 10, + topP: 0.9, + thinking: false, + }, + }); + + expect(result3.llmConfig).toHaveProperty('topK', 10); + expect(result3.llmConfig).toHaveProperty('topP', 0.9); + }); }); diff --git a/api/server/services/Endpoints/openAI/llm.js b/api/server/services/Endpoints/openAI/llm.js index c12f835f2f..0fa899b4a3 100644 --- a/api/server/services/Endpoints/openAI/llm.js +++ b/api/server/services/Endpoints/openAI/llm.js @@ -29,7 +29,6 @@ function getLLMConfig(apiKey, options = {}) { const { modelOptions = {}, reverseProxyUrl, - useOpenRouter, defaultQuery, headers, proxy, @@ -56,9 +55,11 @@ function getLLMConfig(apiKey, options = {}) { }); } + let useOpenRouter; /** @type {OpenAIClientOptions['configuration']} */ const configOptions = {}; - if (useOpenRouter || (reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter))) { + if (reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter)) { + useOpenRouter = true; llmConfig.include_reasoning = true; configOptions.baseURL = reverseProxyUrl; configOptions.defaultHeaders = Object.assign( @@ -118,6 +119,13 @@ function getLLMConfig(apiKey, options = {}) { llmConfig.organization = process.env.OPENAI_ORGANIZATION; } + if (useOpenRouter && llmConfig.reasoning_effort != null) { + llmConfig.reasoning = { + effort: llmConfig.reasoning_effort, + }; + delete llmConfig.reasoning_effort; + } + return { /** @type {OpenAIClientOptions} */ llmConfig, diff --git a/client/src/components/Chat/Messages/Content/ContentParts.tsx b/client/src/components/Chat/Messages/Content/ContentParts.tsx index b997060c61..ddf08976eb 100644 --- a/client/src/components/Chat/Messages/Content/ContentParts.tsx +++ b/client/src/components/Chat/Messages/Content/ContentParts.tsx @@ -109,7 +109,9 @@ const ContentParts = memo( return val; }) } - label={isSubmitting ? localize('com_ui_thinking') : localize('com_ui_thoughts')} + label={ + isSubmitting && isLast ? localize('com_ui_thinking') : localize('com_ui_thoughts') + } /> )} diff --git a/package-lock.json b/package-lock.json index c513481bf4..e4e9d06680 100644 --- a/package-lock.json +++ b/package-lock.json @@ -61,7 +61,7 @@ "@langchain/google-genai": "^0.1.9", "@langchain/google-vertexai": "^0.2.0", "@langchain/textsplitters": "^0.1.0", - "@librechat/agents": "^2.1.3", + "@librechat/agents": "^2.1.7", "@waylaidwanderer/fetch-event-source": "^3.0.1", "axios": "1.7.8", "bcryptjs": "^2.4.3", @@ -15984,9 +15984,9 @@ } }, "node_modules/@librechat/agents": { - "version": "2.1.3", - "resolved": "https://registry.npmjs.org/@librechat/agents/-/agents-2.1.3.tgz", - "integrity": "sha512-4pPkLpjhA3DDiZQOULcrpbdQaOBC4JuUMdcVTUyYBHcA63SJT3olstmRQkGKNvoXLFLeQyJ0jkOqkEpzLJzk/g==", + "version": "2.1.7", + "resolved": "https://registry.npmjs.org/@librechat/agents/-/agents-2.1.7.tgz", + "integrity": "sha512-/+AvxH75K0dSSUeHqT8jPZCcqcQUWdB56g9ls7ho0Nw9vdxfezBhF/hXnOk5oORHeEXlGEKNE6YPyjAhCmNIOg==", "dependencies": { "@aws-crypto/sha256-js": "^5.2.0", "@aws-sdk/credential-provider-node": "^3.613.0", From 2293cd667e3052b44cc27c3efc18f951859d4a9a Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Fri, 28 Feb 2025 12:19:21 -0500 Subject: [PATCH 02/27] =?UTF-8?q?=F0=9F=9A=80=20feat:=20GPT-4.5,=20Anthrop?= =?UTF-8?q?ic=20Tool=20Header,=20and=20OpenAPI=20Ref=20Resolution=20(#6118?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔧 refactor: Update settings to use 'as const' for improved type safety and make gpt-4o-mini default model (cheapest) * 📖 docs: Update README to reflect support for GPT-4.5 in image analysis feature * 🔧 refactor: Update model handling to use default settings and improve encoding logic * 🔧 refactor: Enhance model version extraction logic for improved compatibility with future GPT and omni models * feat: GPT-4.5 tx/token update, vision support * fix: $ref resolution logic in OpenAPI handling * feat: add new 'anthropic-beta' header for Claude 3.7 to include token-efficient tools; ref: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/token-efficient-tool-use --- .env.example | 2 +- README.md | 2 +- api/app/clients/OpenAIClient.js | 10 +- api/models/tx.js | 3 + api/models/tx.spec.js | 10 ++ .../services/Endpoints/anthropic/helpers.js | 3 +- api/utils/tokens.js | 1 + api/utils/tokens.spec.js | 10 ++ package-lock.json | 2 +- packages/data-provider/package.json | 2 +- packages/data-provider/specs/actions.spec.ts | 135 +++++++++++++- packages/data-provider/src/actions.ts | 77 +++++--- packages/data-provider/src/config.ts | 8 +- packages/data-provider/src/parsers.ts | 50 ++++-- packages/data-provider/src/schemas.ts | 170 +++++++++--------- 15 files changed, 337 insertions(+), 148 deletions(-) diff --git a/.env.example b/.env.example index 94a6d80d88..a1ab8e8485 100644 --- a/.env.example +++ b/.env.example @@ -175,7 +175,7 @@ GOOGLE_KEY=user_provided #============# OPENAI_API_KEY=user_provided -# OPENAI_MODELS=o1,o1-mini,o1-preview,gpt-4o,chatgpt-4o-latest,gpt-4o-mini,gpt-3.5-turbo-0125,gpt-3.5-turbo-0301,gpt-3.5-turbo,gpt-4,gpt-4-0613,gpt-4-vision-preview,gpt-3.5-turbo-0613,gpt-3.5-turbo-16k-0613,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-instruct-0914,gpt-3.5-turbo-16k +# OPENAI_MODELS=o1,o1-mini,o1-preview,gpt-4o,gpt-4.5-preview,chatgpt-4o-latest,gpt-4o-mini,gpt-3.5-turbo-0125,gpt-3.5-turbo-0301,gpt-3.5-turbo,gpt-4,gpt-4-0613,gpt-4-vision-preview,gpt-3.5-turbo-0613,gpt-3.5-turbo-16k-0613,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-instruct-0914,gpt-3.5-turbo-16k DEBUG_OPENAI=false diff --git a/README.md b/README.md index 2e662ac262..f58b1999e5 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ - [Fork Messages & Conversations](https://www.librechat.ai/docs/features/fork) for Advanced Context control - 💬 **Multimodal & File Interactions**: - - Upload and analyze images with Claude 3, GPT-4o, o1, Llama-Vision, and Gemini 📸 + - Upload and analyze images with Claude 3, GPT-4.5, GPT-4o, o1, Llama-Vision, and Gemini 📸 - Chat with Files using Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, & Google 🗃️ - 🌎 **Multilingual UI**: diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 4bc2d66ca0..ab851e254c 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -298,7 +298,9 @@ class OpenAIClient extends BaseClient { } getEncoding() { - return this.model?.includes('gpt-4o') ? 'o200k_base' : 'cl100k_base'; + return this.modelOptions?.model && /gpt-4[^-\s]/.test(this.modelOptions.model) + ? 'o200k_base' + : 'cl100k_base'; } /** @@ -605,7 +607,7 @@ class OpenAIClient extends BaseClient { } initializeLLM({ - model = 'gpt-4o-mini', + model = openAISettings.model.default, modelName, temperature = 0.2, max_tokens, @@ -706,7 +708,7 @@ class OpenAIClient extends BaseClient { const { OPENAI_TITLE_MODEL } = process.env ?? {}; - let model = this.options.titleModel ?? OPENAI_TITLE_MODEL ?? 'gpt-4o-mini'; + let model = this.options.titleModel ?? OPENAI_TITLE_MODEL ?? openAISettings.model.default; if (model === Constants.CURRENT_MODEL) { model = this.modelOptions.model; } @@ -899,7 +901,7 @@ ${convo} let prompt; // TODO: remove the gpt fallback and make it specific to endpoint - const { OPENAI_SUMMARY_MODEL = 'gpt-4o-mini' } = process.env ?? {}; + const { OPENAI_SUMMARY_MODEL = openAISettings.model.default } = process.env ?? {}; let model = this.options.summaryModel ?? OPENAI_SUMMARY_MODEL; if (model === Constants.CURRENT_MODEL) { model = this.modelOptions.model; diff --git a/api/models/tx.js b/api/models/tx.js index 82ae9fb034..b534e7edc9 100644 --- a/api/models/tx.js +++ b/api/models/tx.js @@ -79,6 +79,7 @@ const tokenValues = Object.assign( 'o1-mini': { prompt: 1.1, completion: 4.4 }, 'o1-preview': { prompt: 15, completion: 60 }, o1: { prompt: 15, completion: 60 }, + 'gpt-4.5': { prompt: 75, completion: 150 }, 'gpt-4o-mini': { prompt: 0.15, completion: 0.6 }, 'gpt-4o': { prompt: 2.5, completion: 10 }, 'gpt-4o-2024-05-13': { prompt: 5, completion: 15 }, @@ -167,6 +168,8 @@ const getValueKey = (model, endpoint) => { return 'o1-mini'; } else if (modelName.includes('o1')) { return 'o1'; + } else if (modelName.includes('gpt-4.5')) { + return 'gpt-4.5'; } else if (modelName.includes('gpt-4o-2024-05-13')) { return 'gpt-4o-2024-05-13'; } else if (modelName.includes('gpt-4o-mini')) { diff --git a/api/models/tx.spec.js b/api/models/tx.spec.js index 9cec82165f..b04eacc9f3 100644 --- a/api/models/tx.spec.js +++ b/api/models/tx.spec.js @@ -50,6 +50,16 @@ describe('getValueKey', () => { expect(getValueKey('gpt-4-0125')).toBe('gpt-4-1106'); }); + it('should return "gpt-4.5" for model type of "gpt-4.5"', () => { + expect(getValueKey('gpt-4.5-preview')).toBe('gpt-4.5'); + expect(getValueKey('gpt-4.5-2024-08-06')).toBe('gpt-4.5'); + expect(getValueKey('gpt-4.5-2024-08-06-0718')).toBe('gpt-4.5'); + expect(getValueKey('openai/gpt-4.5')).toBe('gpt-4.5'); + expect(getValueKey('openai/gpt-4.5-2024-08-06')).toBe('gpt-4.5'); + expect(getValueKey('gpt-4.5-turbo')).toBe('gpt-4.5'); + expect(getValueKey('gpt-4.5-0125')).toBe('gpt-4.5'); + }); + it('should return "gpt-4o" for model type of "gpt-4o"', () => { expect(getValueKey('gpt-4o-2024-08-06')).toBe('gpt-4o'); expect(getValueKey('gpt-4o-2024-08-06-0718')).toBe('gpt-4o'); diff --git a/api/server/services/Endpoints/anthropic/helpers.js b/api/server/services/Endpoints/anthropic/helpers.js index c7425f6ff1..04e4efc61c 100644 --- a/api/server/services/Endpoints/anthropic/helpers.js +++ b/api/server/services/Endpoints/anthropic/helpers.js @@ -48,7 +48,8 @@ function getClaudeHeaders(model, supportsCacheControl) { }; } else if (/claude-3[-.]7/.test(model)) { return { - 'anthropic-beta': 'output-128k-2025-02-19,prompt-caching-2024-07-31', + 'anthropic-beta': + 'token-efficient-tools-2025-02-19,output-128k-2025-02-19,prompt-caching-2024-07-31', }; } else { return { diff --git a/api/utils/tokens.js b/api/utils/tokens.js index 34c6df4cf4..8edfb0a31c 100644 --- a/api/utils/tokens.js +++ b/api/utils/tokens.js @@ -13,6 +13,7 @@ const openAIModels = { 'gpt-4-32k-0613': 32758, // -10 from max 'gpt-4-1106': 127500, // -500 from max 'gpt-4-0125': 127500, // -500 from max + 'gpt-4.5': 127500, // -500 from max 'gpt-4o': 127500, // -500 from max 'gpt-4o-mini': 127500, // -500 from max 'gpt-4o-2024-05-13': 127500, // -500 from max diff --git a/api/utils/tokens.spec.js b/api/utils/tokens.spec.js index 5a963c385f..d4dbb30498 100644 --- a/api/utils/tokens.spec.js +++ b/api/utils/tokens.spec.js @@ -103,6 +103,16 @@ describe('getModelMaxTokens', () => { ); }); + test('should return correct tokens for gpt-4.5 matches', () => { + expect(getModelMaxTokens('gpt-4.5')).toBe(maxTokensMap[EModelEndpoint.openAI]['gpt-4.5']); + expect(getModelMaxTokens('gpt-4.5-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4.5'], + ); + expect(getModelMaxTokens('openai/gpt-4.5-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4.5'], + ); + }); + test('should return correct tokens for Anthropic models', () => { const models = [ 'claude-2.1', diff --git a/package-lock.json b/package-lock.json index e4e9d06680..ec7025ac7d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -41798,7 +41798,7 @@ }, "packages/data-provider": { "name": "librechat-data-provider", - "version": "0.7.6993", + "version": "0.7.6994", "license": "ISC", "dependencies": { "axios": "^1.7.7", diff --git a/packages/data-provider/package.json b/packages/data-provider/package.json index da2859e9c8..27ea28e435 100644 --- a/packages/data-provider/package.json +++ b/packages/data-provider/package.json @@ -1,6 +1,6 @@ { "name": "librechat-data-provider", - "version": "0.7.6993", + "version": "0.7.6994", "description": "data services for librechat apps", "main": "dist/index.js", "module": "dist/index.es.js", diff --git a/packages/data-provider/specs/actions.spec.ts b/packages/data-provider/specs/actions.spec.ts index 1bd2b5494a..10bf95a23e 100644 --- a/packages/data-provider/specs/actions.spec.ts +++ b/packages/data-provider/specs/actions.spec.ts @@ -585,21 +585,99 @@ describe('resolveRef', () => { openapiSpec.paths['/ai.chatgpt.render-flowchart']?.post ?.requestBody as OpenAPIV3.RequestBodyObject ).content['application/json'].schema; - expect(flowchartRequestRef).toBeDefined(); - const resolvedFlowchartRequest = resolveRef( - flowchartRequestRef as OpenAPIV3.RequestBodyObject, - openapiSpec.components, - ); - expect(resolvedFlowchartRequest).toBeDefined(); - expect(resolvedFlowchartRequest.type).toBe('object'); - const properties = resolvedFlowchartRequest.properties as FlowchartSchema; - expect(properties).toBeDefined(); + expect(flowchartRequestRef).toBeDefined(); + + const resolvedSchemaObject = resolveRef( + flowchartRequestRef as OpenAPIV3.ReferenceObject, + openapiSpec.components, + ) as OpenAPIV3.SchemaObject; + + expect(resolvedSchemaObject).toBeDefined(); + expect(resolvedSchemaObject.type).toBe('object'); + expect(resolvedSchemaObject.properties).toBeDefined(); + + const properties = resolvedSchemaObject.properties as FlowchartSchema; expect(properties.mermaid).toBeDefined(); expect(properties.mermaid.type).toBe('string'); }); }); +describe('resolveRef general cases', () => { + const spec = { + openapi: '3.0.0', + info: { title: 'TestSpec', version: '1.0.0' }, + paths: {}, + components: { + schemas: { + TestSchema: { type: 'string' }, + }, + parameters: { + TestParam: { + name: 'myParam', + in: 'query', + required: false, + schema: { $ref: '#/components/schemas/TestSchema' }, + }, + }, + requestBodies: { + TestRequestBody: { + content: { + 'application/json': { + schema: { $ref: '#/components/schemas/TestSchema' }, + }, + }, + }, + }, + }, + } satisfies OpenAPIV3.Document; + + it('resolves schema refs correctly', () => { + const schemaRef: OpenAPIV3.ReferenceObject = { $ref: '#/components/schemas/TestSchema' }; + const resolvedSchema = resolveRef( + schemaRef, + spec.components, + ); + expect(resolvedSchema.type).toEqual('string'); + }); + + it('resolves parameter refs correctly, then schema within parameter', () => { + const paramRef: OpenAPIV3.ReferenceObject = { $ref: '#/components/parameters/TestParam' }; + const resolvedParam = resolveRef( + paramRef, + spec.components, + ); + expect(resolvedParam.name).toEqual('myParam'); + expect(resolvedParam.in).toEqual('query'); + expect(resolvedParam.required).toBe(false); + + const paramSchema = resolveRef( + resolvedParam.schema as OpenAPIV3.ReferenceObject, + spec.components, + ); + expect(paramSchema.type).toEqual('string'); + }); + + it('resolves requestBody refs correctly, then schema within requestBody', () => { + const requestBodyRef: OpenAPIV3.ReferenceObject = { + $ref: '#/components/requestBodies/TestRequestBody', + }; + const resolvedRequestBody = resolveRef( + requestBodyRef, + spec.components, + ); + + expect(resolvedRequestBody.content['application/json']).toBeDefined(); + + const schemaInRequestBody = resolveRef( + resolvedRequestBody.content['application/json'].schema as OpenAPIV3.ReferenceObject, + spec.components, + ); + + expect(schemaInRequestBody.type).toEqual('string'); + }); +}); + describe('openapiToFunction', () => { it('converts OpenAPI spec to function signatures and request builders', () => { const { functionSignatures, requestBuilders } = openapiToFunction(getWeatherOpenapiSpec); @@ -1095,4 +1173,43 @@ describe('createURL', () => { }); }); }); + + describe('openapiToFunction parameter refs resolution', () => { + const weatherSpec = { + openapi: '3.0.0', + info: { title: 'Weather', version: '1.0.0' }, + servers: [{ url: 'https://api.weather.gov' }], + paths: { + '/points/{point}': { + get: { + operationId: 'getPoint', + parameters: [{ $ref: '#/components/parameters/PathPoint' }], + responses: { '200': { description: 'ok' } }, + }, + }, + }, + components: { + parameters: { + PathPoint: { + name: 'point', + in: 'path', + required: true, + schema: { type: 'string', pattern: '^(-?\\d+(?:\\.\\d+)?),(-?\\d+(?:\\.\\d+)?)$' }, + }, + }, + }, + } satisfies OpenAPIV3.Document; + + it('correctly resolves $ref for parameters', () => { + const { functionSignatures } = openapiToFunction(weatherSpec, true); + const func = functionSignatures.find((sig) => sig.name === 'getPoint'); + expect(func).toBeDefined(); + expect(func?.parameters.properties).toHaveProperty('point'); + expect(func?.parameters.required).toContain('point'); + + const paramSchema = func?.parameters.properties['point'] as OpenAPIV3.SchemaObject; + expect(paramSchema.type).toEqual('string'); + expect(paramSchema.pattern).toEqual('^(-?\\d+(?:\\.\\d+)?),(-?\\d+(?:\\.\\d+)?)$'); + }); + }); }); diff --git a/packages/data-provider/src/actions.ts b/packages/data-provider/src/actions.ts index 5533e6832c..8f8d5f603d 100644 --- a/packages/data-provider/src/actions.ts +++ b/packages/data-provider/src/actions.ts @@ -22,8 +22,8 @@ export type ParametersSchema = { export type OpenAPISchema = OpenAPIV3.SchemaObject & ParametersSchema & { - items?: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject; -}; + items?: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject; + }; export type ApiKeyCredentials = { api_key: string; @@ -43,8 +43,8 @@ export type Credentials = ApiKeyCredentials | OAuthCredentials; type MediaTypeObject = | undefined | { - [media: string]: OpenAPIV3.MediaTypeObject | undefined; -}; + [media: string]: OpenAPIV3.MediaTypeObject | undefined; + }; type RequestBodyObject = Omit & { content: MediaTypeObject; @@ -358,19 +358,29 @@ export class ActionRequest { } } -export function resolveRef( - schema: OpenAPIV3.SchemaObject | OpenAPIV3.ReferenceObject | RequestBodyObject, - components?: OpenAPIV3.ComponentsObject, -): OpenAPIV3.SchemaObject { - if ('$ref' in schema && components) { - const refPath = schema.$ref.replace(/^#\/components\/schemas\//, ''); - const resolvedSchema = components.schemas?.[refPath]; - if (!resolvedSchema) { - throw new Error(`Reference ${schema.$ref} not found`); +export function resolveRef< + T extends + | OpenAPIV3.ReferenceObject + | OpenAPIV3.SchemaObject + | OpenAPIV3.ParameterObject + | OpenAPIV3.RequestBodyObject, +>(obj: T, components?: OpenAPIV3.ComponentsObject): Exclude { + if ('$ref' in obj && components) { + const refPath = obj.$ref.replace(/^#\/components\//, '').split('/'); + + let resolved: unknown = components as Record; + for (const segment of refPath) { + if (typeof resolved === 'object' && resolved !== null && segment in resolved) { + resolved = (resolved as Record)[segment]; + } else { + throw new Error(`Could not resolve reference: ${obj.$ref}`); + } } - return resolveRef(resolvedSchema, components); + + return resolveRef(resolved as typeof obj, components) as Exclude; } - return schema as OpenAPIV3.SchemaObject; + + return obj as Exclude; } function sanitizeOperationId(input: string) { @@ -399,7 +409,7 @@ export function openapiToFunction( const operationObj = operation as OpenAPIV3.OperationObject & { 'x-openai-isConsequential'?: boolean; } & { - 'x-strict'?: boolean + 'x-strict'?: boolean; }; // Operation ID is used as the function name @@ -415,15 +425,25 @@ export function openapiToFunction( }; if (operationObj.parameters) { - for (const param of operationObj.parameters) { - const paramObj = param as OpenAPIV3.ParameterObject; - const resolvedSchema = resolveRef( - { ...paramObj.schema } as OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject, + for (const param of operationObj.parameters ?? []) { + const resolvedParam = resolveRef( + param, openapiSpec.components, - ); - parametersSchema.properties[paramObj.name] = resolvedSchema; - if (paramObj.required === true) { - parametersSchema.required.push(paramObj.name); + ) as OpenAPIV3.ParameterObject; + + const paramName = resolvedParam.name; + if (!paramName || !resolvedParam.schema) { + continue; + } + + const paramSchema = resolveRef( + resolvedParam.schema, + openapiSpec.components, + ) as OpenAPIV3.SchemaObject; + + parametersSchema.properties[paramName] = paramSchema; + if (resolvedParam.required) { + parametersSchema.required.push(paramName); } } } @@ -446,7 +466,12 @@ export function openapiToFunction( } } - const functionSignature = new FunctionSignature(operationId, description, parametersSchema, isStrict); + const functionSignature = new FunctionSignature( + operationId, + description, + parametersSchema, + isStrict, + ); functionSignatures.push(functionSignature); const actionRequest = new ActionRequest( @@ -544,4 +569,4 @@ export function validateAndParseOpenAPISpec(specString: string): ValidationResul console.error(error); return { status: false, message: 'Error parsing OpenAPI spec.' }; } -} \ No newline at end of file +} diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index d5923645e0..5ce56b6d73 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -15,6 +15,7 @@ export const defaultRetrievalModels = [ 'o1-preview', 'o1-mini-2024-09-12', 'o1-mini', + 'o3-mini', 'chatgpt-4o-latest', 'gpt-4o-2024-05-13', 'gpt-4o-2024-08-06', @@ -651,6 +652,8 @@ export const alternateName = { const sharedOpenAIModels = [ 'gpt-4o-mini', 'gpt-4o', + 'gpt-4.5-preview', + 'gpt-4.5-preview-2025-02-27', 'gpt-3.5-turbo', 'gpt-3.5-turbo-0125', 'gpt-4-turbo', @@ -723,7 +726,7 @@ export const bedrockModels = [ export const defaultModels = { [EModelEndpoint.azureAssistants]: sharedOpenAIModels, - [EModelEndpoint.assistants]: ['chatgpt-4o-latest', ...sharedOpenAIModels], + [EModelEndpoint.assistants]: [...sharedOpenAIModels, 'chatgpt-4o-latest'], [EModelEndpoint.agents]: sharedOpenAIModels, // TODO: Add agent models (agentsModels) [EModelEndpoint.google]: [ // Shared Google Models between Vertex AI & Gen AI @@ -742,8 +745,8 @@ export const defaultModels = { ], [EModelEndpoint.anthropic]: sharedAnthropicModels, [EModelEndpoint.openAI]: [ - 'chatgpt-4o-latest', ...sharedOpenAIModels, + 'chatgpt-4o-latest', 'gpt-4-vision-preview', 'gpt-3.5-turbo-instruct-0914', 'gpt-3.5-turbo-instruct', @@ -808,6 +811,7 @@ export const supportsBalanceCheck = { }; export const visionModels = [ + 'gpt-4.5', 'gpt-4o', 'gpt-4o-mini', 'o1', diff --git a/packages/data-provider/src/parsers.ts b/packages/data-provider/src/parsers.ts index 8ec18d5617..58d6fa3712 100644 --- a/packages/data-provider/src/parsers.ts +++ b/packages/data-provider/src/parsers.ts @@ -128,7 +128,6 @@ export const envVarRegex = /^\${(.+)}$/; export function extractEnvVariable(value: string) { const envVarMatch = value.match(envVarRegex); if (envVarMatch) { - // eslint-disable-next-line @typescript-eslint/strict-boolean-expressions return process.env[envVarMatch[1]] || value; } return value; @@ -211,6 +210,29 @@ export const parseConvo = ({ return convo; }; +/** Match GPT followed by digit, optional decimal, and optional suffix + * + * Examples: gpt-4, gpt-4o, gpt-4.5, gpt-5a, etc. */ +const extractGPTVersion = (modelStr: string): string => { + const gptMatch = modelStr.match(/gpt-(\d+(?:\.\d+)?)([a-z])?/i); + if (gptMatch) { + const version = gptMatch[1]; + const suffix = gptMatch[2] || ''; + return `GPT-${version}${suffix}`; + } + return ''; +}; + +/** Match omni models (o1, o3, etc.), "o" followed by a digit, possibly with decimal */ +const extractOmniVersion = (modelStr: string): string => { + const omniMatch = modelStr.match(/\bo(\d+(?:\.\d+)?)\b/i); + if (omniMatch) { + const version = omniMatch[1]; + return `o${version}`; + } + return ''; +}; + export const getResponseSender = (endpointOption: t.TEndpointOption): string => { const { model: _m, @@ -238,18 +260,13 @@ export const getResponseSender = (endpointOption: t.TEndpointOption): string => return chatGptLabel; } else if (modelLabel) { return modelLabel; - } else if (model && /\bo1\b/i.test(model)) { - return 'o1'; - } else if (model && /\bo3\b/i.test(model)) { - return 'o3'; - } else if (model && model.includes('gpt-3')) { - return 'GPT-3.5'; - } else if (model && model.includes('gpt-4o')) { - return 'GPT-4o'; - } else if (model && model.includes('gpt-4')) { - return 'GPT-4'; + } else if (model && extractOmniVersion(model)) { + return extractOmniVersion(model); } else if (model && model.includes('mistral')) { return 'Mistral'; + } else if (model && model.includes('gpt-')) { + const gptVersion = extractGPTVersion(model); + return gptVersion || 'GPT'; } return (alternateName[endpoint] as string | undefined) ?? 'ChatGPT'; } @@ -279,14 +296,13 @@ export const getResponseSender = (endpointOption: t.TEndpointOption): string => return modelLabel; } else if (chatGptLabel) { return chatGptLabel; + } else if (model && extractOmniVersion(model)) { + return extractOmniVersion(model); } else if (model && model.includes('mistral')) { return 'Mistral'; - } else if (model && model.includes('gpt-3')) { - return 'GPT-3.5'; - } else if (model && model.includes('gpt-4o')) { - return 'GPT-4o'; - } else if (model && model.includes('gpt-4')) { - return 'GPT-4'; + } else if (model && model.includes('gpt-')) { + const gptVersion = extractGPTVersion(model); + return gptVersion || 'GPT'; } else if (modelDisplayLabel) { return modelDisplayLabel; } diff --git a/packages/data-provider/src/schemas.ts b/packages/data-provider/src/schemas.ts index cee0230386..533d6ffc37 100644 --- a/packages/data-provider/src/schemas.ts +++ b/packages/data-provider/src/schemas.ts @@ -179,34 +179,34 @@ export const isImageVisionTool = (tool: FunctionTool | FunctionToolCall) => export const openAISettings = { model: { - default: 'gpt-4o', + default: 'gpt-4o-mini' as const, }, temperature: { - min: 0, - max: 2, - step: 0.01, - default: 1, + min: 0 as const, + max: 2 as const, + step: 0.01 as const, + default: 1 as const, }, top_p: { - min: 0, - max: 1, - step: 0.01, - default: 1, + min: 0 as const, + max: 1 as const, + step: 0.01 as const, + default: 1 as const, }, presence_penalty: { - min: 0, - max: 2, - step: 0.01, - default: 0, + min: 0 as const, + max: 2 as const, + step: 0.01 as const, + default: 0 as const, }, frequency_penalty: { - min: 0, - max: 2, - step: 0.01, - default: 0, + min: 0 as const, + max: 2 as const, + step: 0.01 as const, + default: 0 as const, }, resendFiles: { - default: true, + default: true as const, }, maxContextTokens: { default: undefined, @@ -215,72 +215,72 @@ export const openAISettings = { default: undefined, }, imageDetail: { - default: ImageDetail.auto, - min: 0, - max: 2, - step: 1, + default: ImageDetail.auto as const, + min: 0 as const, + max: 2 as const, + step: 1 as const, }, }; export const googleSettings = { model: { - default: 'gemini-1.5-flash-latest', + default: 'gemini-1.5-flash-latest' as const, }, maxOutputTokens: { - min: 1, - max: 8192, - step: 1, - default: 8192, + min: 1 as const, + max: 8192 as const, + step: 1 as const, + default: 8192 as const, }, temperature: { - min: 0, - max: 2, - step: 0.01, - default: 1, + min: 0 as const, + max: 2 as const, + step: 0.01 as const, + default: 1 as const, }, topP: { - min: 0, - max: 1, - step: 0.01, - default: 0.95, + min: 0 as const, + max: 1 as const, + step: 0.01 as const, + default: 0.95 as const, }, topK: { - min: 1, - max: 40, - step: 1, - default: 40, + min: 1 as const, + max: 40 as const, + step: 1 as const, + default: 40 as const, }, }; -const ANTHROPIC_MAX_OUTPUT = 128000; -const DEFAULT_MAX_OUTPUT = 8192; -const LEGACY_ANTHROPIC_MAX_OUTPUT = 4096; +const ANTHROPIC_MAX_OUTPUT = 128000 as const; +const DEFAULT_MAX_OUTPUT = 8192 as const; +const LEGACY_ANTHROPIC_MAX_OUTPUT = 4096 as const; export const anthropicSettings = { model: { - default: 'claude-3-5-sonnet-latest', + default: 'claude-3-5-sonnet-latest' as const, }, temperature: { - min: 0, - max: 1, - step: 0.01, - default: 1, + min: 0 as const, + max: 1 as const, + step: 0.01 as const, + default: 1 as const, }, promptCache: { - default: true, + default: true as const, }, thinking: { - default: true, + default: true as const, }, thinkingBudget: { - min: 1024, - step: 100, - max: 200000, - default: 2000, + min: 1024 as const, + step: 100 as const, + max: 200000 as const, + default: 2000 as const, }, maxOutputTokens: { - min: 1, + min: 1 as const, max: ANTHROPIC_MAX_OUTPUT, - step: 1, + step: 1 as const, default: DEFAULT_MAX_OUTPUT, reset: (modelName: string) => { if (/claude-3[-.]5-sonnet/.test(modelName) || /claude-3[-.]7/.test(modelName)) { @@ -301,28 +301,28 @@ export const anthropicSettings = { }, }, topP: { - min: 0, - max: 1, - step: 0.01, - default: 0.7, + min: 0 as const, + max: 1 as const, + step: 0.01 as const, + default: 0.7 as const, }, topK: { - min: 1, - max: 40, - step: 1, - default: 5, + min: 1 as const, + max: 40 as const, + step: 1 as const, + default: 5 as const, }, resendFiles: { - default: true, + default: true as const, }, maxContextTokens: { default: undefined, }, legacy: { maxOutputTokens: { - min: 1, + min: 1 as const, max: LEGACY_ANTHROPIC_MAX_OUTPUT, - step: 1, + step: 1 as const, default: LEGACY_ANTHROPIC_MAX_OUTPUT, }, }, @@ -330,34 +330,34 @@ export const anthropicSettings = { export const agentsSettings = { model: { - default: 'gpt-3.5-turbo-test', + default: 'gpt-3.5-turbo-test' as const, }, temperature: { - min: 0, - max: 1, - step: 0.01, - default: 1, + min: 0 as const, + max: 1 as const, + step: 0.01 as const, + default: 1 as const, }, top_p: { - min: 0, - max: 1, - step: 0.01, - default: 1, + min: 0 as const, + max: 1 as const, + step: 0.01 as const, + default: 1 as const, }, presence_penalty: { - min: 0, - max: 2, - step: 0.01, - default: 0, + min: 0 as const, + max: 2 as const, + step: 0.01 as const, + default: 0 as const, }, frequency_penalty: { - min: 0, - max: 2, - step: 0.01, - default: 0, + min: 0 as const, + max: 2 as const, + step: 0.01 as const, + default: 0 as const, }, resendFiles: { - default: true, + default: true as const, }, maxContextTokens: { default: undefined, @@ -366,7 +366,7 @@ export const agentsSettings = { default: undefined, }, imageDetail: { - default: ImageDetail.auto, + default: ImageDetail.auto as const, }, }; From 7f6b32ff04d67ce0ad6f78c891a06f9e5a90cd96 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 1 Mar 2025 07:51:12 -0500 Subject: [PATCH 03/27] =?UTF-8?q?=F0=9F=96=BC=EF=B8=8F=20refactor:=20Enhan?= =?UTF-8?q?ce=20Env=20Extraction=20&=20Agent=20Image=20Handling=20(#6131)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: use new image output format for agents using DALL-E tools * refactor: Enhance image fetching with proxy support and adjust logging placement in DALL-E 3 integration * refactor: Enhance StableDiffusionAPI to support agent-specific return values and display message for generated images * refactor: Add unit test execution for librechat-mcp in backend review workflow * refactor: Update environment variable extraction logic, export from serpate module to avoid circular refs, and remove deprecated tests * refactor: Add unit tests for environment variable extraction and enhance StdioOptionsSchema to process env variables --- .github/workflows/backend-review.yml | 5 +- api/app/clients/tools/structured/DALLE3.js | 36 ++++- .../tools/structured/StableDiffusion.js | 38 +++++- api/server/controllers/agents/callbacks.js | 31 +---- package-lock.json | 2 +- packages/data-provider/package.json | 2 +- packages/data-provider/specs/mcp.spec.ts | 52 +++++++ packages/data-provider/specs/parsers.spec.ts | 48 ------- packages/data-provider/specs/utils.spec.ts | 129 ++++++++++++++++++ packages/data-provider/src/azure.ts | 3 +- packages/data-provider/src/index.ts | 1 + packages/data-provider/src/mcp.ts | 17 ++- packages/data-provider/src/parsers.ts | 12 +- packages/data-provider/src/utils.ts | 44 ++++++ 14 files changed, 321 insertions(+), 99 deletions(-) create mode 100644 packages/data-provider/specs/mcp.spec.ts delete mode 100644 packages/data-provider/specs/parsers.spec.ts create mode 100644 packages/data-provider/specs/utils.spec.ts create mode 100644 packages/data-provider/src/utils.ts diff --git a/.github/workflows/backend-review.yml b/.github/workflows/backend-review.yml index 5bc3d3b2db..8469fc366d 100644 --- a/.github/workflows/backend-review.yml +++ b/.github/workflows/backend-review.yml @@ -61,4 +61,7 @@ jobs: run: cd api && npm run test:ci - name: Run librechat-data-provider unit tests - run: cd packages/data-provider && npm run test:ci \ No newline at end of file + run: cd packages/data-provider && npm run test:ci + + - name: Run librechat-mcp unit tests + run: cd packages/mcp && npm run test:ci \ No newline at end of file diff --git a/api/app/clients/tools/structured/DALLE3.js b/api/app/clients/tools/structured/DALLE3.js index b604ad4ea4..81200e3a61 100644 --- a/api/app/clients/tools/structured/DALLE3.js +++ b/api/app/clients/tools/structured/DALLE3.js @@ -1,14 +1,17 @@ const { z } = require('zod'); const path = require('path'); const OpenAI = require('openai'); +const fetch = require('node-fetch'); const { v4: uuidv4 } = require('uuid'); const { Tool } = require('@langchain/core/tools'); const { HttpsProxyAgent } = require('https-proxy-agent'); -const { FileContext } = require('librechat-data-provider'); +const { FileContext, ContentTypes } = require('librechat-data-provider'); const { getImageBasename } = require('~/server/services/Files/images'); const extractBaseURL = require('~/utils/extractBaseURL'); const { logger } = require('~/config'); +const displayMessage = + 'DALL-E displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; class DALLE3 extends Tool { constructor(fields = {}) { super(); @@ -114,10 +117,7 @@ class DALLE3 extends Tool { if (this.isAgent === true && typeof value === 'string') { return [value, {}]; } else if (this.isAgent === true && typeof value === 'object') { - return [ - 'DALL-E displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.', - value, - ]; + return [displayMessage, value]; } return value; @@ -160,6 +160,32 @@ Error Message: ${error.message}`); ); } + if (this.isAgent) { + let fetchOptions = {}; + if (process.env.PROXY) { + fetchOptions.agent = new HttpsProxyAgent(process.env.PROXY); + } + const imageResponse = await fetch(theImageUrl, fetchOptions); + const arrayBuffer = await imageResponse.arrayBuffer(); + const base64 = Buffer.from(arrayBuffer).toString('base64'); + const content = [ + { + type: ContentTypes.IMAGE_URL, + image_url: { + url: `data:image/jpeg;base64,${base64}`, + }, + }, + ]; + + const response = [ + { + type: ContentTypes.TEXT, + text: displayMessage, + }, + ]; + return [response, { content }]; + } + const imageBasename = getImageBasename(theImageUrl); const imageExt = path.extname(imageBasename); diff --git a/api/app/clients/tools/structured/StableDiffusion.js b/api/app/clients/tools/structured/StableDiffusion.js index 6309da35d8..25a9e0abd3 100644 --- a/api/app/clients/tools/structured/StableDiffusion.js +++ b/api/app/clients/tools/structured/StableDiffusion.js @@ -6,10 +6,13 @@ const axios = require('axios'); const sharp = require('sharp'); const { v4: uuidv4 } = require('uuid'); const { Tool } = require('@langchain/core/tools'); -const { FileContext } = require('librechat-data-provider'); +const { FileContext, ContentTypes } = require('librechat-data-provider'); const paths = require('~/config/paths'); const { logger } = require('~/config'); +const displayMessage = + 'Stable Diffusion displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; + class StableDiffusionAPI extends Tool { constructor(fields) { super(); @@ -21,6 +24,8 @@ class StableDiffusionAPI extends Tool { this.override = fields.override ?? false; /** @type {boolean} Necessary for output to contain all image metadata. */ this.returnMetadata = fields.returnMetadata ?? false; + /** @type {boolean} */ + this.isAgent = fields.isAgent; if (fields.uploadImageBuffer) { /** @type {uploadImageBuffer} Necessary for output to contain all image metadata. */ this.uploadImageBuffer = fields.uploadImageBuffer.bind(this); @@ -66,6 +71,16 @@ class StableDiffusionAPI extends Tool { return `![generated image](/${imageUrl})`; } + returnValue(value) { + if (this.isAgent === true && typeof value === 'string') { + return [value, {}]; + } else if (this.isAgent === true && typeof value === 'object') { + return [displayMessage, value]; + } + + return value; + } + getServerURL() { const url = process.env.SD_WEBUI_URL || ''; if (!url && !this.override) { @@ -113,6 +128,25 @@ class StableDiffusionAPI extends Tool { } try { + if (this.isAgent) { + const content = [ + { + type: ContentTypes.IMAGE_URL, + image_url: { + url: `data:image/png;base64,${image}`, + }, + }, + ]; + + const response = [ + { + type: ContentTypes.TEXT, + text: displayMessage, + }, + ]; + return [response, { content }]; + } + const buffer = Buffer.from(image.split(',', 1)[0], 'base64'); if (this.returnMetadata && this.uploadImageBuffer && this.req) { const file = await this.uploadImageBuffer({ @@ -154,7 +188,7 @@ class StableDiffusionAPI extends Tool { logger.error('[StableDiffusion] Error while saving the image:', error); } - return this.result; + return this.returnValue(this.result); } } diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js index f43c9db5ba..45beefe7e6 100644 --- a/api/server/controllers/agents/callbacks.js +++ b/api/server/controllers/agents/callbacks.js @@ -1,4 +1,5 @@ -const { Tools, StepTypes, imageGenTools, FileContext } = require('librechat-data-provider'); +const { nanoid } = require('nanoid'); +const { Tools, StepTypes, FileContext } = require('librechat-data-provider'); const { EnvVar, Providers, @@ -242,32 +243,6 @@ function createToolEndCallback({ req, res, artifactPromises }) { return; } - if (imageGenTools.has(output.name)) { - artifactPromises.push( - (async () => { - const fileMetadata = Object.assign(output.artifact, { - messageId: metadata.run_id, - toolCallId: output.tool_call_id, - conversationId: metadata.thread_id, - }); - if (!res.headersSent) { - return fileMetadata; - } - - if (!fileMetadata) { - return null; - } - - res.write(`event: attachment\ndata: ${JSON.stringify(fileMetadata)}\n\n`); - return fileMetadata; - })().catch((error) => { - logger.error('Error processing code output:', error); - return null; - }), - ); - return; - } - if (output.artifact.content) { /** @type {FormattedContent[]} */ const content = output.artifact.content; @@ -278,7 +253,7 @@ function createToolEndCallback({ req, res, artifactPromises }) { const { url } = part.image_url; artifactPromises.push( (async () => { - const filename = `${output.tool_call_id}-image-${new Date().getTime()}`; + const filename = `${output.name}_${output.tool_call_id}_img_${nanoid()}`; const file = await saveBase64Image(url, { req, filename, diff --git a/package-lock.json b/package-lock.json index ec7025ac7d..533add4d7d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -41798,7 +41798,7 @@ }, "packages/data-provider": { "name": "librechat-data-provider", - "version": "0.7.6994", + "version": "0.7.6995", "license": "ISC", "dependencies": { "axios": "^1.7.7", diff --git a/packages/data-provider/package.json b/packages/data-provider/package.json index 27ea28e435..542d6cd74c 100644 --- a/packages/data-provider/package.json +++ b/packages/data-provider/package.json @@ -1,6 +1,6 @@ { "name": "librechat-data-provider", - "version": "0.7.6994", + "version": "0.7.6995", "description": "data services for librechat apps", "main": "dist/index.js", "module": "dist/index.es.js", diff --git a/packages/data-provider/specs/mcp.spec.ts b/packages/data-provider/specs/mcp.spec.ts new file mode 100644 index 0000000000..b72df6d4c2 --- /dev/null +++ b/packages/data-provider/specs/mcp.spec.ts @@ -0,0 +1,52 @@ +import { StdioOptionsSchema } from '../src/mcp'; + +describe('Environment Variable Extraction (MCP)', () => { + const originalEnv = process.env; + + beforeEach(() => { + process.env = { + ...originalEnv, + TEST_API_KEY: 'test-api-key-value', + ANOTHER_SECRET: 'another-secret-value', + }; + }); + + afterEach(() => { + process.env = originalEnv; + }); + + describe('StdioOptionsSchema', () => { + it('should transform environment variables in the env field', () => { + const options = { + command: 'node', + args: ['server.js'], + env: { + API_KEY: '${TEST_API_KEY}', + ANOTHER_KEY: '${ANOTHER_SECRET}', + PLAIN_VALUE: 'plain-value', + NON_EXISTENT: '${NON_EXISTENT_VAR}', + }, + }; + + const result = StdioOptionsSchema.parse(options); + + expect(result.env).toEqual({ + API_KEY: 'test-api-key-value', + ANOTHER_KEY: 'another-secret-value', + PLAIN_VALUE: 'plain-value', + NON_EXISTENT: '${NON_EXISTENT_VAR}', + }); + }); + + it('should handle undefined env field', () => { + const options = { + command: 'node', + args: ['server.js'], + }; + + const result = StdioOptionsSchema.parse(options); + + expect(result.env).toBeUndefined(); + }); + }); +}); diff --git a/packages/data-provider/specs/parsers.spec.ts b/packages/data-provider/specs/parsers.spec.ts deleted file mode 100644 index e9ec9b20a4..0000000000 --- a/packages/data-provider/specs/parsers.spec.ts +++ /dev/null @@ -1,48 +0,0 @@ -import { extractEnvVariable } from '../src/parsers'; - -describe('extractEnvVariable', () => { - const originalEnv = process.env; - - beforeEach(() => { - jest.resetModules(); - process.env = { ...originalEnv }; - }); - - afterAll(() => { - process.env = originalEnv; - }); - - test('should return the value of the environment variable', () => { - process.env.TEST_VAR = 'test_value'; - expect(extractEnvVariable('${TEST_VAR}')).toBe('test_value'); - }); - - test('should return the original string if the envrionment variable is not defined correctly', () => { - process.env.TEST_VAR = 'test_value'; - expect(extractEnvVariable('${ TEST_VAR }')).toBe('${ TEST_VAR }'); - }); - - test('should return the original string if environment variable is not set', () => { - expect(extractEnvVariable('${NON_EXISTENT_VAR}')).toBe('${NON_EXISTENT_VAR}'); - }); - - test('should return the original string if it does not contain an environment variable', () => { - expect(extractEnvVariable('some_string')).toBe('some_string'); - }); - - test('should handle empty strings', () => { - expect(extractEnvVariable('')).toBe(''); - }); - - test('should handle strings without variable format', () => { - expect(extractEnvVariable('no_var_here')).toBe('no_var_here'); - }); - - test('should not process multiple variable formats', () => { - process.env.FIRST_VAR = 'first'; - process.env.SECOND_VAR = 'second'; - expect(extractEnvVariable('${FIRST_VAR} and ${SECOND_VAR}')).toBe( - '${FIRST_VAR} and ${SECOND_VAR}', - ); - }); -}); diff --git a/packages/data-provider/specs/utils.spec.ts b/packages/data-provider/specs/utils.spec.ts new file mode 100644 index 0000000000..01c403f4e8 --- /dev/null +++ b/packages/data-provider/specs/utils.spec.ts @@ -0,0 +1,129 @@ +import { extractEnvVariable } from '../src/utils'; + +describe('Environment Variable Extraction', () => { + const originalEnv = process.env; + + beforeEach(() => { + process.env = { + ...originalEnv, + TEST_API_KEY: 'test-api-key-value', + ANOTHER_SECRET: 'another-secret-value', + }; + }); + + afterEach(() => { + process.env = originalEnv; + }); + + describe('extractEnvVariable (original tests)', () => { + test('should return the value of the environment variable', () => { + process.env.TEST_VAR = 'test_value'; + expect(extractEnvVariable('${TEST_VAR}')).toBe('test_value'); + }); + + test('should return the original string if the envrionment variable is not defined correctly', () => { + process.env.TEST_VAR = 'test_value'; + expect(extractEnvVariable('${ TEST_VAR }')).toBe('${ TEST_VAR }'); + }); + + test('should return the original string if environment variable is not set', () => { + expect(extractEnvVariable('${NON_EXISTENT_VAR}')).toBe('${NON_EXISTENT_VAR}'); + }); + + test('should return the original string if it does not contain an environment variable', () => { + expect(extractEnvVariable('some_string')).toBe('some_string'); + }); + + test('should handle empty strings', () => { + expect(extractEnvVariable('')).toBe(''); + }); + + test('should handle strings without variable format', () => { + expect(extractEnvVariable('no_var_here')).toBe('no_var_here'); + }); + + /** No longer the expected behavior; keeping for reference */ + test.skip('should not process multiple variable formats', () => { + process.env.FIRST_VAR = 'first'; + process.env.SECOND_VAR = 'second'; + expect(extractEnvVariable('${FIRST_VAR} and ${SECOND_VAR}')).toBe( + '${FIRST_VAR} and ${SECOND_VAR}', + ); + }); + }); + + describe('extractEnvVariable function', () => { + it('should extract environment variables from exact matches', () => { + expect(extractEnvVariable('${TEST_API_KEY}')).toBe('test-api-key-value'); + expect(extractEnvVariable('${ANOTHER_SECRET}')).toBe('another-secret-value'); + }); + + it('should extract environment variables from strings with prefixes', () => { + expect(extractEnvVariable('prefix-${TEST_API_KEY}')).toBe('prefix-test-api-key-value'); + }); + + it('should extract environment variables from strings with suffixes', () => { + expect(extractEnvVariable('${TEST_API_KEY}-suffix')).toBe('test-api-key-value-suffix'); + }); + + it('should extract environment variables from strings with both prefixes and suffixes', () => { + expect(extractEnvVariable('prefix-${TEST_API_KEY}-suffix')).toBe( + 'prefix-test-api-key-value-suffix', + ); + }); + + it('should not match invalid patterns', () => { + expect(extractEnvVariable('$TEST_API_KEY')).toBe('$TEST_API_KEY'); + expect(extractEnvVariable('{TEST_API_KEY}')).toBe('{TEST_API_KEY}'); + expect(extractEnvVariable('TEST_API_KEY')).toBe('TEST_API_KEY'); + }); + }); + + describe('extractEnvVariable', () => { + it('should extract environment variable values', () => { + expect(extractEnvVariable('${TEST_API_KEY}')).toBe('test-api-key-value'); + expect(extractEnvVariable('${ANOTHER_SECRET}')).toBe('another-secret-value'); + }); + + it('should return the original string if environment variable is not found', () => { + expect(extractEnvVariable('${NON_EXISTENT_VAR}')).toBe('${NON_EXISTENT_VAR}'); + }); + + it('should return the original string if no environment variable pattern is found', () => { + expect(extractEnvVariable('plain-string')).toBe('plain-string'); + }); + }); + + describe('extractEnvVariable space trimming', () => { + beforeEach(() => { + process.env.HELLO = 'world'; + process.env.USER = 'testuser'; + }); + + it('should extract the value when string contains only an environment variable with surrounding whitespace', () => { + expect(extractEnvVariable(' ${HELLO} ')).toBe('world'); + expect(extractEnvVariable(' ${HELLO} ')).toBe('world'); + expect(extractEnvVariable('\t${HELLO}\n')).toBe('world'); + }); + + it('should preserve content when variable is part of a larger string', () => { + expect(extractEnvVariable('Hello ${USER}!')).toBe('Hello testuser!'); + expect(extractEnvVariable(' Hello ${USER}! ')).toBe('Hello testuser!'); + }); + + it('should not handle multiple variables', () => { + expect(extractEnvVariable('${HELLO} ${USER}')).toBe('${HELLO} ${USER}'); + expect(extractEnvVariable(' ${HELLO} ${USER} ')).toBe('${HELLO} ${USER}'); + }); + + it('should handle undefined variables', () => { + expect(extractEnvVariable(' ${UNDEFINED_VAR} ')).toBe('${UNDEFINED_VAR}'); + }); + + it('should handle mixed content correctly', () => { + expect(extractEnvVariable('Welcome, ${USER}!\nYour message: ${HELLO}')).toBe( + 'Welcome, testuser!\nYour message: world', + ); + }); + }); +}); diff --git a/packages/data-provider/src/azure.ts b/packages/data-provider/src/azure.ts index f5948820be..17188ec551 100644 --- a/packages/data-provider/src/azure.ts +++ b/packages/data-provider/src/azure.ts @@ -6,8 +6,9 @@ import type { TValidatedAzureConfig, TAzureConfigValidationResult, } from '../src/config'; -import { errorsToString, extractEnvVariable, envVarRegex } from '../src/parsers'; +import { extractEnvVariable, envVarRegex } from '../src/utils'; import { azureGroupConfigsSchema } from '../src/config'; +import { errorsToString } from '../src/parsers'; export const deprecatedAzureVariables = [ /* "related to" precedes description text */ diff --git a/packages/data-provider/src/index.ts b/packages/data-provider/src/index.ts index 739ece7330..90b396001b 100644 --- a/packages/data-provider/src/index.ts +++ b/packages/data-provider/src/index.ts @@ -31,5 +31,6 @@ export { default as request } from './request'; export { dataService }; import * as dataService from './data-service'; /* general helpers */ +export * from './utils'; export * from './actions'; export { default as createPayload } from './createPayload'; diff --git a/packages/data-provider/src/mcp.ts b/packages/data-provider/src/mcp.ts index bb8a55f161..2328a0071e 100644 --- a/packages/data-provider/src/mcp.ts +++ b/packages/data-provider/src/mcp.ts @@ -1,4 +1,5 @@ import { z } from 'zod'; +import { extractEnvVariable } from './utils'; const BaseOptionsSchema = z.object({ iconPath: z.string().optional(), @@ -18,8 +19,22 @@ export const StdioOptionsSchema = BaseOptionsSchema.extend({ * The environment to use when spawning the process. * * If not specified, the result of getDefaultEnvironment() will be used. + * Environment variables can be referenced using ${VAR_NAME} syntax. */ - env: z.record(z.string(), z.string()).optional(), + env: z + .record(z.string(), z.string()) + .optional() + .transform((env) => { + if (!env) { + return env; + } + + const processedEnv: Record = {}; + for (const [key, value] of Object.entries(env)) { + processedEnv[key] = extractEnvVariable(value); + } + return processedEnv; + }), /** * How to handle stderr of the child process. This matches the semantics of Node's `child_process.spawn`. * diff --git a/packages/data-provider/src/parsers.ts b/packages/data-provider/src/parsers.ts index 58d6fa3712..10a23a542b 100644 --- a/packages/data-provider/src/parsers.ts +++ b/packages/data-provider/src/parsers.ts @@ -19,6 +19,7 @@ import { compactAssistantSchema, } from './schemas'; import { bedrockInputSchema } from './bedrock'; +import { extractEnvVariable } from './utils'; import { alternateName } from './config'; type EndpointSchema = @@ -122,17 +123,6 @@ export function errorsToString(errors: ZodIssue[]) { .join(' '); } -export const envVarRegex = /^\${(.+)}$/; - -/** Extracts the value of an environment variable from a string. */ -export function extractEnvVariable(value: string) { - const envVarMatch = value.match(envVarRegex); - if (envVarMatch) { - return process.env[envVarMatch[1]] || value; - } - return value; -} - /** Resolves header values to env variables if detected */ export function resolveHeaders(headers: Record | undefined) { const resolvedHeaders = { ...(headers ?? {}) }; diff --git a/packages/data-provider/src/utils.ts b/packages/data-provider/src/utils.ts new file mode 100644 index 0000000000..de41a93dc6 --- /dev/null +++ b/packages/data-provider/src/utils.ts @@ -0,0 +1,44 @@ +export const envVarRegex = /^\${(.+)}$/; + +/** Extracts the value of an environment variable from a string. */ +export function extractEnvVariable(value: string) { + if (!value) { + return value; + } + + // Trim the input + const trimmed = value.trim(); + + // Special case: if it's just a single environment variable + const singleMatch = trimmed.match(envVarRegex); + if (singleMatch) { + const varName = singleMatch[1]; + return process.env[varName] || trimmed; + } + + // For multiple variables, process them using a regex loop + const regex = /\${([^}]+)}/g; + let result = trimmed; + + // First collect all matches and their positions + const matches = []; + let match; + while ((match = regex.exec(trimmed)) !== null) { + matches.push({ + fullMatch: match[0], + varName: match[1], + index: match.index, + }); + } + + // Process matches in reverse order to avoid position shifts + for (let i = matches.length - 1; i >= 0; i--) { + const { fullMatch, varName, index } = matches[i]; + const envValue = process.env[varName] || fullMatch; + + // Replace at exact position + result = result.substring(0, index) + envValue + result.substring(index + fullMatch.length); + } + + return result; +} From 2e63e32382f5b9b257cc1d79742c04046c012d5f Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 2 Mar 2025 13:19:53 -0500 Subject: [PATCH 04/27] =?UTF-8?q?=F0=9F=90=BC=20feat:=20Add=20Flux=20Image?= =?UTF-8?q?=20Generation=20Tool=20(#6147)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔧 fix: Log warning for aborted operations in AgentClient * ci: Remove unused saveMessageToDatabase mock in FakeClient initialization * ci: test actual implementation of saveMessageToDatabase * refactor: Change log level from warning to error for aborted operations in AgentClient * refactor: Add className prop to Image component for customizable styling, use theme selectors * feat: FLUX Image Generation tool --- .env.example | 7 + api/app/clients/specs/BaseClient.test.js | 158 ++++- api/app/clients/specs/FakeClient.js | 2 - api/app/clients/tools/index.js | 4 +- api/app/clients/tools/manifest.json | 14 + api/app/clients/tools/structured/FluxAPI.js | 554 ++++++++++++++++++ api/app/clients/tools/util/handleTools.js | 5 +- api/server/controllers/agents/client.js | 9 +- .../Chat/Messages/Content/Image.tsx | 9 +- .../Messages/Content/Parts/Attachment.tsx | 8 +- package-lock.json | 2 +- packages/data-provider/package.json | 2 +- packages/data-provider/src/config.ts | 2 +- 13 files changed, 760 insertions(+), 16 deletions(-) create mode 100644 api/app/clients/tools/structured/FluxAPI.js diff --git a/.env.example b/.env.example index a1ab8e8485..e235b6cbb9 100644 --- a/.env.example +++ b/.env.example @@ -248,6 +248,13 @@ AZURE_AI_SEARCH_SEARCH_OPTION_SELECT= # DALLE3_AZURE_API_VERSION= # DALLE2_AZURE_API_VERSION= +# Flux +#----------------- +FLUX_API_BASE_URL=https://api.us1.bfl.ai +# FLUX_API_BASE_URL = 'https://api.bfl.ml'; + +# Get your API key at https://api.us1.bfl.ai/auth/profile +# FLUX_API_KEY= # Google #----------------- diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index e899449fb9..0dae5b14d3 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -30,6 +30,8 @@ jest.mock('~/models', () => ({ updateFileUsage: jest.fn(), })); +const { getConvo, saveConvo } = require('~/models'); + jest.mock('@langchain/openai', () => { return { ChatOpenAI: jest.fn().mockImplementation(() => { @@ -540,10 +542,11 @@ describe('BaseClient', () => { test('saveMessageToDatabase is called with the correct arguments', async () => { const saveOptions = TestClient.getSaveOptions(); - const user = {}; // Mock user + const user = {}; const opts = { user }; + const saveSpy = jest.spyOn(TestClient, 'saveMessageToDatabase'); await TestClient.sendMessage('Hello, world!', opts); - expect(TestClient.saveMessageToDatabase).toHaveBeenCalledWith( + expect(saveSpy).toHaveBeenCalledWith( expect.objectContaining({ sender: expect.any(String), text: expect.any(String), @@ -557,6 +560,157 @@ describe('BaseClient', () => { ); }); + test('should handle existing conversation when getConvo retrieves one', async () => { + const existingConvo = { + conversationId: 'existing-convo-id', + endpoint: 'openai', + endpointType: 'openai', + model: 'gpt-3.5-turbo', + messages: [ + { role: 'user', content: 'Existing message 1' }, + { role: 'assistant', content: 'Existing response 1' }, + ], + temperature: 1, + }; + + const { temperature: _temp, ...newConvo } = existingConvo; + + const user = { + id: 'user-id', + }; + + getConvo.mockResolvedValue(existingConvo); + saveConvo.mockResolvedValue(newConvo); + + TestClient = initializeFakeClient( + apiKey, + { + ...options, + req: { + user, + }, + }, + [], + ); + + const saveSpy = jest.spyOn(TestClient, 'saveMessageToDatabase'); + + const newMessage = 'New message in existing conversation'; + const response = await TestClient.sendMessage(newMessage, { + user, + conversationId: existingConvo.conversationId, + }); + + expect(getConvo).toHaveBeenCalledWith(user.id, existingConvo.conversationId); + expect(TestClient.conversationId).toBe(existingConvo.conversationId); + expect(response.conversationId).toBe(existingConvo.conversationId); + expect(TestClient.fetchedConvo).toBe(true); + + expect(saveSpy).toHaveBeenCalledWith( + expect.objectContaining({ + conversationId: existingConvo.conversationId, + text: newMessage, + }), + expect.any(Object), + expect.any(Object), + ); + + expect(saveConvo).toHaveBeenCalledTimes(2); + expect(saveConvo).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + conversationId: existingConvo.conversationId, + }), + expect.objectContaining({ + context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo', + unsetFields: { + temperature: 1, + }, + }), + ); + + await TestClient.sendMessage('Another message', { + conversationId: existingConvo.conversationId, + }); + expect(getConvo).toHaveBeenCalledTimes(1); + }); + + test('should correctly handle existing conversation and unset fields appropriately', async () => { + const existingConvo = { + conversationId: 'existing-convo-id', + endpoint: 'openai', + endpointType: 'openai', + model: 'gpt-3.5-turbo', + messages: [ + { role: 'user', content: 'Existing message 1' }, + { role: 'assistant', content: 'Existing response 1' }, + ], + title: 'Existing Conversation', + someExistingField: 'existingValue', + anotherExistingField: 'anotherValue', + temperature: 0.7, + modelLabel: 'GPT-3.5', + }; + + getConvo.mockResolvedValue(existingConvo); + saveConvo.mockResolvedValue(existingConvo); + + TestClient = initializeFakeClient( + apiKey, + { + ...options, + modelOptions: { + model: 'gpt-4', + temperature: 0.5, + }, + }, + [], + ); + + const newMessage = 'New message in existing conversation'; + await TestClient.sendMessage(newMessage, { + conversationId: existingConvo.conversationId, + }); + + expect(saveConvo).toHaveBeenCalledTimes(2); + + const saveConvoCall = saveConvo.mock.calls[0]; + const [, savedFields, saveOptions] = saveConvoCall; + + // Instead of checking all excludedKeys, we'll just check specific fields + // that we know should be excluded + expect(savedFields).not.toHaveProperty('messages'); + expect(savedFields).not.toHaveProperty('title'); + + // Only check that someExistingField is in unsetFields + expect(saveOptions.unsetFields).toHaveProperty('someExistingField', 1); + + // Mock saveConvo to return the expected fields + saveConvo.mockImplementation((req, fields) => { + return Promise.resolve({ + ...fields, + endpoint: 'openai', + endpointType: 'openai', + model: 'gpt-4', + temperature: 0.5, + }); + }); + + // Only check the conversationId since that's the only field we can be sure about + expect(savedFields).toHaveProperty('conversationId', 'existing-convo-id'); + + expect(TestClient.fetchedConvo).toBe(true); + + await TestClient.sendMessage('Another message', { + conversationId: existingConvo.conversationId, + }); + + expect(getConvo).toHaveBeenCalledTimes(1); + + const secondSaveConvoCall = saveConvo.mock.calls[1]; + expect(secondSaveConvoCall[2]).toHaveProperty('unsetFields', {}); + }); + test('sendCompletion is called with the correct arguments', async () => { const payload = {}; // Mock payload TestClient.buildMessages.mockReturnValue({ prompt: payload, tokenCountMap: null }); diff --git a/api/app/clients/specs/FakeClient.js b/api/app/clients/specs/FakeClient.js index 7f4b75e1db..a466bb97f9 100644 --- a/api/app/clients/specs/FakeClient.js +++ b/api/app/clients/specs/FakeClient.js @@ -56,7 +56,6 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => { let TestClient = new FakeClient(apiKey); TestClient.options = options; TestClient.abortController = { abort: jest.fn() }; - TestClient.saveMessageToDatabase = jest.fn(); TestClient.loadHistory = jest .fn() .mockImplementation((conversationId, parentMessageId = null) => { @@ -86,7 +85,6 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => { return 'Mock response text'; }); - // eslint-disable-next-line no-unused-vars TestClient.getCompletion = jest.fn().mockImplementation(async (..._args) => { return { choices: [ diff --git a/api/app/clients/tools/index.js b/api/app/clients/tools/index.js index b8df50c77d..df436fb089 100644 --- a/api/app/clients/tools/index.js +++ b/api/app/clients/tools/index.js @@ -2,9 +2,10 @@ const availableTools = require('./manifest.json'); // Structured Tools const DALLE3 = require('./structured/DALLE3'); +const FluxAPI = require('./structured/FluxAPI'); const OpenWeather = require('./structured/OpenWeather'); -const createYouTubeTools = require('./structured/YouTube'); const StructuredWolfram = require('./structured/Wolfram'); +const createYouTubeTools = require('./structured/YouTube'); const StructuredACS = require('./structured/AzureAISearch'); const StructuredSD = require('./structured/StableDiffusion'); const GoogleSearchAPI = require('./structured/GoogleSearch'); @@ -30,6 +31,7 @@ module.exports = { manifestToolMap, // Structured Tools DALLE3, + FluxAPI, OpenWeather, StructuredSD, StructuredACS, diff --git a/api/app/clients/tools/manifest.json b/api/app/clients/tools/manifest.json index 7cb92b8d87..43be7a4e6c 100644 --- a/api/app/clients/tools/manifest.json +++ b/api/app/clients/tools/manifest.json @@ -164,5 +164,19 @@ "description": "Sign up at OpenWeather, then get your key at API keys." } ] + }, + { + "name": "Flux", + "pluginKey": "flux", + "description": "Generate images using text with the Flux API.", + "icon": "https://blackforestlabs.ai/wp-content/uploads/2024/07/bfl_logo_retraced_blk.png", + "isAuthRequired": "true", + "authConfig": [ + { + "authField": "FLUX_API_KEY", + "label": "Your Flux API Key", + "description": "Provide your Flux API key from your user profile." + } + ] } ] diff --git a/api/app/clients/tools/structured/FluxAPI.js b/api/app/clients/tools/structured/FluxAPI.js new file mode 100644 index 0000000000..80f9772200 --- /dev/null +++ b/api/app/clients/tools/structured/FluxAPI.js @@ -0,0 +1,554 @@ +const { z } = require('zod'); +const axios = require('axios'); +const fetch = require('node-fetch'); +const { v4: uuidv4 } = require('uuid'); +const { Tool } = require('@langchain/core/tools'); +const { HttpsProxyAgent } = require('https-proxy-agent'); +const { FileContext, ContentTypes } = require('librechat-data-provider'); +const { logger } = require('~/config'); + +const displayMessage = + 'Flux displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; + +/** + * FluxAPI - A tool for generating high-quality images from text prompts using the Flux API. + * Each call generates one image. If multiple images are needed, make multiple consecutive calls with the same or varied prompts. + */ +class FluxAPI extends Tool { + // Pricing constants in USD per image + static PRICING = { + FLUX_PRO_1_1_ULTRA: -0.06, // /v1/flux-pro-1.1-ultra + FLUX_PRO_1_1: -0.04, // /v1/flux-pro-1.1 + FLUX_PRO: -0.05, // /v1/flux-pro + FLUX_DEV: -0.025, // /v1/flux-dev + FLUX_PRO_FINETUNED: -0.06, // /v1/flux-pro-finetuned + FLUX_PRO_1_1_ULTRA_FINETUNED: -0.07, // /v1/flux-pro-1.1-ultra-finetuned + }; + + constructor(fields = {}) { + super(); + + /** @type {boolean} Used to initialize the Tool without necessary variables. */ + this.override = fields.override ?? false; + + this.userId = fields.userId; + this.fileStrategy = fields.fileStrategy; + + /** @type {boolean} **/ + this.isAgent = fields.isAgent; + this.returnMetadata = fields.returnMetadata ?? false; + + if (fields.processFileURL) { + /** @type {processFileURL} Necessary for output to contain all image metadata. */ + this.processFileURL = fields.processFileURL.bind(this); + } + + this.apiKey = fields.FLUX_API_KEY || this.getApiKey(); + + this.name = 'flux'; + this.description = + 'Use Flux to generate images from text descriptions. This tool can generate images and list available finetunes. Each generate call creates one image. For multiple images, make multiple consecutive calls.'; + + this.description_for_model = `// Transform any image description into a detailed, high-quality prompt. Never submit a prompt under 3 sentences. Follow these core rules: + // 1. ALWAYS enhance basic prompts into 5-10 detailed sentences (e.g., "a cat" becomes: "A close-up photo of a sleek Siamese cat with piercing blue eyes. The cat sits elegantly on a vintage leather armchair, its tail curled gracefully around its paws. Warm afternoon sunlight streams through a nearby window, casting gentle shadows across its face and highlighting the subtle variations in its cream and chocolate-point fur. The background is softly blurred, creating a shallow depth of field that draws attention to the cat's expressive features. The overall composition has a peaceful, contemplative mood with a professional photography style.") + // 2. Each prompt MUST be 3-6 descriptive sentences minimum, focusing on visual elements: lighting, composition, mood, and style + // Use action: 'list_finetunes' to see available custom models. When using finetunes, use endpoint: '/v1/flux-pro-finetuned' (default) or '/v1/flux-pro-1.1-ultra-finetuned' for higher quality and aspect ratio.`; + + // Add base URL from environment variable with fallback + this.baseUrl = process.env.FLUX_API_BASE_URL || 'https://api.us1.bfl.ai'; + + // Define the schema for structured input + this.schema = z.object({ + action: z + .enum(['generate', 'list_finetunes', 'generate_finetuned']) + .default('generate') + .describe( + 'Action to perform: "generate" for image generation, "generate_finetuned" for finetuned model generation, "list_finetunes" to get available custom models', + ), + prompt: z + .string() + .optional() + .describe( + 'Text prompt for image generation. Required when action is "generate". Not used for list_finetunes.', + ), + width: z + .number() + .optional() + .describe( + 'Width of the generated image in pixels. Must be a multiple of 32. Default is 1024.', + ), + height: z + .number() + .optional() + .describe( + 'Height of the generated image in pixels. Must be a multiple of 32. Default is 768.', + ), + prompt_upsampling: z + .boolean() + .optional() + .default(false) + .describe('Whether to perform upsampling on the prompt.'), + steps: z + .number() + .int() + .optional() + .describe('Number of steps to run the model for, a number from 1 to 50. Default is 40.'), + seed: z.number().optional().describe('Optional seed for reproducibility.'), + safety_tolerance: z + .number() + .optional() + .default(6) + .describe( + 'Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', + ), + endpoint: z + .enum([ + '/v1/flux-pro-1.1', + '/v1/flux-pro', + '/v1/flux-dev', + '/v1/flux-pro-1.1-ultra', + '/v1/flux-pro-finetuned', + '/v1/flux-pro-1.1-ultra-finetuned', + ]) + .optional() + .default('/v1/flux-pro-1.1') + .describe('Endpoint to use for image generation.'), + raw: z + .boolean() + .optional() + .default(false) + .describe( + 'Generate less processed, more natural-looking images. Only works for /v1/flux-pro-1.1-ultra.', + ), + finetune_id: z.string().optional().describe('ID of the finetuned model to use'), + finetune_strength: z + .number() + .optional() + .default(1.1) + .describe('Strength of the finetuning effect (typically between 0.1 and 1.2)'), + guidance: z.number().optional().default(2.5).describe('Guidance scale for finetuned models'), + aspect_ratio: z + .string() + .optional() + .default('16:9') + .describe('Aspect ratio for ultra models (e.g., "16:9")'), + }); + } + + getAxiosConfig() { + const config = {}; + if (process.env.PROXY) { + config.httpsAgent = new HttpsProxyAgent(process.env.PROXY); + } + return config; + } + + /** @param {Object|string} value */ + getDetails(value) { + if (typeof value === 'string') { + return value; + } + return JSON.stringify(value, null, 2); + } + + getApiKey() { + const apiKey = process.env.FLUX_API_KEY || ''; + if (!apiKey && !this.override) { + throw new Error('Missing FLUX_API_KEY environment variable.'); + } + return apiKey; + } + + wrapInMarkdown(imageUrl) { + const serverDomain = process.env.DOMAIN_SERVER || 'http://localhost:3080'; + return `![generated image](${serverDomain}${imageUrl})`; + } + + returnValue(value) { + if (this.isAgent === true && typeof value === 'string') { + return [value, {}]; + } else if (this.isAgent === true && typeof value === 'object') { + if (Array.isArray(value)) { + return value; + } + return [displayMessage, value]; + } + return value; + } + + async _call(data) { + const { action = 'generate', ...imageData } = data; + + // Use provided API key for this request if available, otherwise use default + const requestApiKey = this.apiKey || this.getApiKey(); + + // Handle list_finetunes action + if (action === 'list_finetunes') { + return this.getMyFinetunes(requestApiKey); + } + + // Handle finetuned generation + if (action === 'generate_finetuned') { + return this.generateFinetunedImage(imageData, requestApiKey); + } + + // For generate action, ensure prompt is provided + if (!imageData.prompt) { + throw new Error('Missing required field: prompt'); + } + + let payload = { + prompt: imageData.prompt, + prompt_upsampling: imageData.prompt_upsampling || false, + safety_tolerance: imageData.safety_tolerance || 6, + output_format: imageData.output_format || 'png', + }; + + // Add optional parameters if provided + if (imageData.width) { + payload.width = imageData.width; + } + if (imageData.height) { + payload.height = imageData.height; + } + if (imageData.steps) { + payload.steps = imageData.steps; + } + if (imageData.seed !== undefined) { + payload.seed = imageData.seed; + } + if (imageData.raw) { + payload.raw = imageData.raw; + } + + const generateUrl = `${this.baseUrl}${imageData.endpoint || '/v1/flux-pro'}`; + const resultUrl = `${this.baseUrl}/v1/get_result`; + + logger.debug('[FluxAPI] Generating image with payload:', payload); + logger.debug('[FluxAPI] Using endpoint:', generateUrl); + + let taskResponse; + try { + taskResponse = await axios.post(generateUrl, payload, { + headers: { + 'x-key': requestApiKey, + 'Content-Type': 'application/json', + Accept: 'application/json', + }, + ...this.getAxiosConfig(), + }); + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while submitting task:', details); + + return this.returnValue( + `Something went wrong when trying to generate the image. The Flux API may be unavailable: + Error Message: ${details}`, + ); + } + + const taskId = taskResponse.data.id; + + // Polling for the result + let status = 'Pending'; + let resultData = null; + while (status !== 'Ready' && status !== 'Error') { + try { + // Wait 2 seconds between polls + await new Promise((resolve) => setTimeout(resolve, 2000)); + const resultResponse = await axios.get(resultUrl, { + headers: { + 'x-key': requestApiKey, + Accept: 'application/json', + }, + params: { id: taskId }, + ...this.getAxiosConfig(), + }); + status = resultResponse.data.status; + + if (status === 'Ready') { + resultData = resultResponse.data.result; + break; + } else if (status === 'Error') { + logger.error('[FluxAPI] Error in task:', resultResponse.data); + return this.returnValue('An error occurred during image generation.'); + } + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while getting result:', details); + return this.returnValue('An error occurred while retrieving the image.'); + } + } + + // If no result data + if (!resultData || !resultData.sample) { + logger.error('[FluxAPI] No image data received from API. Response:', resultData); + return this.returnValue('No image data received from Flux API.'); + } + + // Try saving the image locally + const imageUrl = resultData.sample; + const imageName = `img-${uuidv4()}.png`; + + if (this.isAgent) { + try { + // Fetch the image and convert to base64 + const fetchOptions = {}; + if (process.env.PROXY) { + fetchOptions.agent = new HttpsProxyAgent(process.env.PROXY); + } + const imageResponse = await fetch(imageUrl, fetchOptions); + const arrayBuffer = await imageResponse.arrayBuffer(); + const base64 = Buffer.from(arrayBuffer).toString('base64'); + const content = [ + { + type: ContentTypes.IMAGE_URL, + image_url: { + url: `data:image/png;base64,${base64}`, + }, + }, + ]; + + const response = [ + { + type: ContentTypes.TEXT, + text: displayMessage, + }, + ]; + return [response, { content }]; + } catch (error) { + logger.error('Error processing image for agent:', error); + return this.returnValue(`Failed to process the image. ${error.message}`); + } + } + + try { + logger.debug('[FluxAPI] Saving image:', imageUrl); + const result = await this.processFileURL({ + fileStrategy: this.fileStrategy, + userId: this.userId, + URL: imageUrl, + fileName: imageName, + basePath: 'images', + context: FileContext.image_generation, + }); + + logger.debug('[FluxAPI] Image saved to path:', result.filepath); + + // Calculate cost based on endpoint + /** + * TODO: Cost handling + const endpoint = imageData.endpoint || '/v1/flux-pro'; + const endpointKey = Object.entries(FluxAPI.PRICING).find(([key, _]) => + endpoint.includes(key.toLowerCase().replace(/_/g, '-')), + )?.[0]; + const cost = FluxAPI.PRICING[endpointKey] || 0; + */ + this.result = this.returnMetadata ? result : this.wrapInMarkdown(result.filepath); + return this.returnValue(this.result); + } catch (error) { + const details = this.getDetails(error?.message ?? 'No additional error details.'); + logger.error('Error while saving the image:', details); + return this.returnValue(`Failed to save the image locally. ${details}`); + } + } + + async getMyFinetunes(apiKey = null) { + const finetunesUrl = `${this.baseUrl}/v1/my_finetunes`; + const detailsUrl = `${this.baseUrl}/v1/finetune_details`; + + try { + const headers = { + 'x-key': apiKey || this.getApiKey(), + 'Content-Type': 'application/json', + Accept: 'application/json', + }; + + // Get list of finetunes + const response = await axios.get(finetunesUrl, { + headers, + ...this.getAxiosConfig(), + }); + const finetunes = response.data.finetunes; + + // Fetch details for each finetune + const finetuneDetails = await Promise.all( + finetunes.map(async (finetuneId) => { + try { + const detailResponse = await axios.get(`${detailsUrl}?finetune_id=${finetuneId}`, { + headers, + ...this.getAxiosConfig(), + }); + return { + id: finetuneId, + ...detailResponse.data, + }; + } catch (error) { + logger.error(`[FluxAPI] Error fetching details for finetune ${finetuneId}:`, error); + return { + id: finetuneId, + error: 'Failed to fetch details', + }; + } + }), + ); + + if (this.isAgent) { + const formattedDetails = JSON.stringify(finetuneDetails, null, 2); + return [`Here are the available finetunes:\n${formattedDetails}`, null]; + } + return JSON.stringify(finetuneDetails); + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while getting finetunes:', details); + const errorMsg = `Failed to get finetunes: ${details}`; + return this.isAgent ? this.returnValue([errorMsg, {}]) : new Error(errorMsg); + } + } + + async generateFinetunedImage(imageData, requestApiKey) { + if (!imageData.prompt) { + throw new Error('Missing required field: prompt'); + } + + if (!imageData.finetune_id) { + throw new Error( + 'Missing required field: finetune_id for finetuned generation. Please supply a finetune_id!', + ); + } + + // Validate endpoint is appropriate for finetuned generation + const validFinetunedEndpoints = ['/v1/flux-pro-finetuned', '/v1/flux-pro-1.1-ultra-finetuned']; + const endpoint = imageData.endpoint || '/v1/flux-pro-finetuned'; + + if (!validFinetunedEndpoints.includes(endpoint)) { + throw new Error( + `Invalid endpoint for finetuned generation. Must be one of: ${validFinetunedEndpoints.join(', ')}`, + ); + } + + let payload = { + prompt: imageData.prompt, + prompt_upsampling: imageData.prompt_upsampling || false, + safety_tolerance: imageData.safety_tolerance || 6, + output_format: imageData.output_format || 'png', + finetune_id: imageData.finetune_id, + finetune_strength: imageData.finetune_strength || 1.0, + guidance: imageData.guidance || 2.5, + }; + + // Add optional parameters if provided + if (imageData.width) { + payload.width = imageData.width; + } + if (imageData.height) { + payload.height = imageData.height; + } + if (imageData.steps) { + payload.steps = imageData.steps; + } + if (imageData.seed !== undefined) { + payload.seed = imageData.seed; + } + if (imageData.raw) { + payload.raw = imageData.raw; + } + + const generateUrl = `${this.baseUrl}${endpoint}`; + const resultUrl = `${this.baseUrl}/v1/get_result`; + + logger.debug('[FluxAPI] Generating finetuned image with payload:', payload); + logger.debug('[FluxAPI] Using endpoint:', generateUrl); + + let taskResponse; + try { + taskResponse = await axios.post(generateUrl, payload, { + headers: { + 'x-key': requestApiKey, + 'Content-Type': 'application/json', + Accept: 'application/json', + }, + ...this.getAxiosConfig(), + }); + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while submitting finetuned task:', details); + return this.returnValue( + `Something went wrong when trying to generate the finetuned image. The Flux API may be unavailable: + Error Message: ${details}`, + ); + } + + const taskId = taskResponse.data.id; + + // Polling for the result + let status = 'Pending'; + let resultData = null; + while (status !== 'Ready' && status !== 'Error') { + try { + // Wait 2 seconds between polls + await new Promise((resolve) => setTimeout(resolve, 2000)); + const resultResponse = await axios.get(resultUrl, { + headers: { + 'x-key': requestApiKey, + Accept: 'application/json', + }, + params: { id: taskId }, + ...this.getAxiosConfig(), + }); + status = resultResponse.data.status; + + if (status === 'Ready') { + resultData = resultResponse.data.result; + break; + } else if (status === 'Error') { + logger.error('[FluxAPI] Error in finetuned task:', resultResponse.data); + return this.returnValue('An error occurred during finetuned image generation.'); + } + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while getting finetuned result:', details); + return this.returnValue('An error occurred while retrieving the finetuned image.'); + } + } + + // If no result data + if (!resultData || !resultData.sample) { + logger.error('[FluxAPI] No image data received from API. Response:', resultData); + return this.returnValue('No image data received from Flux API.'); + } + + // Try saving the image locally + const imageUrl = resultData.sample; + const imageName = `img-${uuidv4()}.png`; + + try { + logger.debug('[FluxAPI] Saving finetuned image:', imageUrl); + const result = await this.processFileURL({ + fileStrategy: this.fileStrategy, + userId: this.userId, + URL: imageUrl, + fileName: imageName, + basePath: 'images', + context: FileContext.image_generation, + }); + + logger.debug('[FluxAPI] Finetuned image saved to path:', result.filepath); + + // Calculate cost based on endpoint + const endpointKey = endpoint.includes('ultra') + ? 'FLUX_PRO_1_1_ULTRA_FINETUNED' + : 'FLUX_PRO_FINETUNED'; + const cost = FluxAPI.PRICING[endpointKey] || 0; + // Return the result based on returnMetadata flag + this.result = this.returnMetadata ? result : this.wrapInMarkdown(result.filepath); + return this.returnValue(this.result); + } catch (error) { + const details = this.getDetails(error?.message ?? 'No additional error details.'); + logger.error('Error while saving the finetuned image:', details); + return this.returnValue(`Failed to save the finetuned image locally. ${details}`); + } + } +} + +module.exports = FluxAPI; diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index f1dfa24a49..ae19a158ee 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -10,6 +10,7 @@ const { GoogleSearchAPI, // Structured Tools DALLE3, + FluxAPI, OpenWeather, StructuredSD, StructuredACS, @@ -182,6 +183,7 @@ const loadTools = async ({ returnMap = false, }) => { const toolConstructors = { + flux: FluxAPI, calculator: Calculator, google: GoogleSearchAPI, open_weather: OpenWeather, @@ -230,9 +232,10 @@ const loadTools = async ({ }; const toolOptions = { - serpapi: { location: 'Austin,Texas,United States', hl: 'en', gl: 'us' }, + flux: imageGenOptions, dalle: imageGenOptions, 'stable-diffusion': imageGenOptions, + serpapi: { location: 'Austin,Texas,United States', hl: 'en', gl: 'us' }, }; const toolContextMap = {}; diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 99d64bb9a6..b50314901f 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -795,6 +795,10 @@ class AgentClient extends BaseClient { ); } } catch (err) { + logger.error( + '[api/server/controllers/agents/client.js #sendCompletion] Operation aborted', + err, + ); if (!abortController.signal.aborted) { logger.error( '[api/server/controllers/agents/client.js #sendCompletion] Unhandled error type', @@ -802,11 +806,6 @@ class AgentClient extends BaseClient { ); throw err; } - - logger.warn( - '[api/server/controllers/agents/client.js #sendCompletion] Operation aborted', - err, - ); } } diff --git a/client/src/components/Chat/Messages/Content/Image.tsx b/client/src/components/Chat/Messages/Content/Image.tsx index 28910d0315..41ee52453f 100644 --- a/client/src/components/Chat/Messages/Content/Image.tsx +++ b/client/src/components/Chat/Messages/Content/Image.tsx @@ -29,6 +29,7 @@ const Image = ({ height, width, placeholderDimensions, + className, }: { imagePath: string; altText: string; @@ -38,6 +39,7 @@ const Image = ({ height?: string; width?: string; }; + className?: string; }) => { const [isLoaded, setIsLoaded] = useState(false); const containerRef = useRef(null); @@ -57,7 +59,12 @@ const Image = ({ return (
-
+