From 0352067da2b076d3f18b570f7f692e8dd9b1a1ae Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 20 Sep 2025 08:19:44 -0400 Subject: [PATCH 1/4] =?UTF-8?q?=F0=9F=8E=A8=20refactor:=20Improve=20Mermai?= =?UTF-8?q?d=20Artifacts=20Styling=20(#9742)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🎨 refactor: Improve Mermaid Artifacts Styling * refactor: Replace ArtifactMarkdown with MermaidMarkdown --- .../components/Artifacts/ArtifactPreview.tsx | 14 +++++++++- .../src/components/Artifacts/ArtifactTabs.tsx | 27 ++++++++++--------- client/src/components/Artifacts/Artifacts.tsx | 17 +++++++----- .../components/Artifacts/MermaidMarkdown.tsx | 11 ++++++++ 4 files changed, 49 insertions(+), 20 deletions(-) create mode 100644 client/src/components/Artifacts/MermaidMarkdown.tsx diff --git a/client/src/components/Artifacts/ArtifactPreview.tsx b/client/src/components/Artifacts/ArtifactPreview.tsx index 5c9920c13..d5114ceaf 100644 --- a/client/src/components/Artifacts/ArtifactPreview.tsx +++ b/client/src/components/Artifacts/ArtifactPreview.tsx @@ -4,7 +4,7 @@ import { SandpackProvider, SandpackProviderProps, } from '@codesandbox/sandpack-react/unstyled'; -import type { SandpackPreviewRef } from '@codesandbox/sandpack-react/unstyled'; +import type { SandpackPreviewRef, PreviewProps } from '@codesandbox/sandpack-react/unstyled'; import type { TStartupConfig } from 'librechat-data-provider'; import type { ArtifactFiles } from '~/common'; import { sharedFiles, sharedOptions } from '~/utils/artifacts'; @@ -13,6 +13,7 @@ export const ArtifactPreview = memo(function ({ files, fileKey, template, + isMermaid, sharedProps, previewRef, currentCode, @@ -20,6 +21,7 @@ export const ArtifactPreview = memo(function ({ }: { files: ArtifactFiles; fileKey: string; + isMermaid: boolean; template: SandpackProviderProps['template']; sharedProps: Partial; previewRef: React.MutableRefObject; @@ -54,6 +56,15 @@ export const ArtifactPreview = memo(function ({ return _options; }, [startupConfig, template]); + const style: PreviewProps['style'] | undefined = useMemo(() => { + if (isMermaid) { + return { + backgroundColor: '#282C34', + }; + } + return; + }, [isMermaid]); + if (Object.keys(artifactFiles).length === 0) { return null; } @@ -73,6 +84,7 @@ export const ArtifactPreview = memo(function ({ showRefreshButton={false} tabIndex={0} ref={previewRef} + style={style} /> ); diff --git a/client/src/components/Artifacts/ArtifactTabs.tsx b/client/src/components/Artifacts/ArtifactTabs.tsx index a8792410d..cd8c441ad 100644 --- a/client/src/components/Artifacts/ArtifactTabs.tsx +++ b/client/src/components/Artifacts/ArtifactTabs.tsx @@ -8,6 +8,7 @@ import { useAutoScroll } from '~/hooks/Artifacts/useAutoScroll'; import { ArtifactCodeEditor } from './ArtifactCodeEditor'; import { useGetStartupConfig } from '~/data-provider'; import { ArtifactPreview } from './ArtifactPreview'; +import { MermaidMarkdown } from './MermaidMarkdown'; import { cn } from '~/utils'; export default function ArtifactTabs({ @@ -44,23 +45,25 @@ export default function ArtifactTabs({ id="artifacts-code" className={cn('flex-grow overflow-auto')} > - + {isMermaid ? ( + + ) : ( + + )} - + {/* Main Container */}
{/* Header */}
@@ -74,16 +76,17 @@ export default function Artifacts() { {/* Refresh button */} {activeTab === 'preview' && ( )} diff --git a/client/src/components/Artifacts/MermaidMarkdown.tsx b/client/src/components/Artifacts/MermaidMarkdown.tsx new file mode 100644 index 000000000..780b0d74d --- /dev/null +++ b/client/src/components/Artifacts/MermaidMarkdown.tsx @@ -0,0 +1,11 @@ +import { CodeMarkdown } from './Code'; + +export function MermaidMarkdown({ + content, + isSubmitting, +}: { + content: string; + isSubmitting: boolean; +}) { + return ; +} From 2489670f54e74ec411281fc6460188050c526759 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 20 Sep 2025 10:17:24 -0400 Subject: [PATCH 2/4] =?UTF-8?q?=F0=9F=93=82=20refactor:=20File=20Read=20Op?= =?UTF-8?q?erations=20(#9747)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: axios response logging for text parsing, remove console logging, remove jsdoc * refactor: error logging in logAxiosError function to handle various error types with type guards * refactor: enhance text parsing with improved error handling and async file reading * refactor: replace synchronous file reading with asynchronous methods for improved performance and memory management * ci: update tests --- api/server/utils/sendEmail.js | 6 +- packages/api/src/files/mistral/crud.spec.ts | 21 +- packages/api/src/files/mistral/crud.ts | 9 +- packages/api/src/files/text.spec.ts | 91 ++-- packages/api/src/files/text.ts | 65 +-- .../api/src/utils/__tests__/files.test.ts | 414 ++++++++++++++++++ packages/api/src/utils/axios.ts | 27 +- packages/api/src/utils/files.ts | 121 +++++ packages/api/src/utils/key.test.ts | 17 +- packages/api/src/utils/key.ts | 4 +- 10 files changed, 692 insertions(+), 83 deletions(-) create mode 100644 packages/api/src/utils/__tests__/files.test.ts diff --git a/api/server/utils/sendEmail.js b/api/server/utils/sendEmail.js index ee64b209f..432a571ff 100644 --- a/api/server/utils/sendEmail.js +++ b/api/server/utils/sendEmail.js @@ -1,11 +1,10 @@ -const fs = require('fs'); const path = require('path'); const axios = require('axios'); const FormData = require('form-data'); const nodemailer = require('nodemailer'); const handlebars = require('handlebars'); const { logger } = require('@librechat/data-schemas'); -const { logAxiosError, isEnabled } = require('@librechat/api'); +const { logAxiosError, isEnabled, readFileAsString } = require('@librechat/api'); /** * Sends an email using Mailgun API. @@ -93,8 +92,7 @@ const sendEmailViaSMTP = async ({ transporterOptions, mailOptions }) => { */ const sendEmail = async ({ email, subject, payload, template, throwError = true }) => { try { - // Read and compile the email template - const source = fs.readFileSync(path.join(__dirname, 'emails', template), 'utf8'); + const { content: source } = await readFileAsString(path.join(__dirname, 'emails', template)); const compiledTemplate = handlebars.compile(source); const html = compiledTemplate(payload); diff --git a/packages/api/src/files/mistral/crud.spec.ts b/packages/api/src/files/mistral/crud.spec.ts index 688a553ff..955678106 100644 --- a/packages/api/src/files/mistral/crud.spec.ts +++ b/packages/api/src/files/mistral/crud.spec.ts @@ -45,6 +45,10 @@ jest.mock('~/utils/axios', () => ({ logAxiosError: jest.fn(({ message }) => message || 'Error'), })); +jest.mock('~/utils/files', () => ({ + readFileAsBuffer: jest.fn(), +})); + import * as fs from 'fs'; import axios from 'axios'; import { HttpsProxyAgent } from 'https-proxy-agent'; @@ -56,6 +60,7 @@ import type { OCRResult, } from '~/types'; import { logger as mockLogger } from '@librechat/data-schemas'; +import { readFileAsBuffer } from '~/utils/files'; import { uploadDocumentToMistral, uploadAzureMistralOCR, @@ -1978,9 +1983,10 @@ describe('MistralOCR Service', () => { describe('Azure Mistral OCR with proxy', () => { beforeEach(() => { - (jest.mocked(fs).readFileSync as jest.Mock).mockReturnValue( - Buffer.from('mock-file-content'), - ); + (readFileAsBuffer as jest.Mock).mockResolvedValue({ + content: Buffer.from('mock-file-content'), + bytes: Buffer.from('mock-file-content').length, + }); }); it('should use proxy for Azure Mistral OCR requests', async () => { @@ -2098,7 +2104,10 @@ describe('MistralOCR Service', () => { describe('uploadAzureMistralOCR', () => { beforeEach(() => { - (jest.mocked(fs).readFileSync as jest.Mock).mockReturnValue(Buffer.from('mock-file-content')); + (readFileAsBuffer as jest.Mock).mockResolvedValue({ + content: Buffer.from('mock-file-content'), + bytes: Buffer.from('mock-file-content').length, + }); // Reset the HttpsProxyAgent mock to its default implementation for Azure tests (HttpsProxyAgent as unknown as jest.Mock).mockImplementation((url) => ({ proxyUrl: url })); // Clean up any PROXY env var from previous tests @@ -2172,7 +2181,9 @@ describe('MistralOCR Service', () => { loadAuthValues: mockLoadAuthValues, }); - expect(jest.mocked(fs).readFileSync).toHaveBeenCalledWith('/tmp/upload/azure-file.pdf'); + expect(readFileAsBuffer).toHaveBeenCalledWith('/tmp/upload/azure-file.pdf', { + fileSize: undefined, + }); // Verify OCR was called with base64 data URL expect(mockAxios.post).toHaveBeenCalledWith( diff --git a/packages/api/src/files/mistral/crud.ts b/packages/api/src/files/mistral/crud.ts index dda29216f..e0ef69ab3 100644 --- a/packages/api/src/files/mistral/crud.ts +++ b/packages/api/src/files/mistral/crud.ts @@ -22,6 +22,7 @@ import type { OCRImage, } from '~/types'; import { logAxiosError, createAxiosInstance } from '~/utils/axios'; +import { readFileAsBuffer } from '~/utils/files'; import { loadServiceKey } from '~/utils/key'; const axios = createAxiosInstance(); @@ -464,7 +465,9 @@ export const uploadAzureMistralOCR = async ( const { apiKey, baseURL } = await loadAuthConfig(context); const model = getModelConfig(context.req.config?.ocr); - const buffer = fs.readFileSync(context.file.path); + const { content: buffer } = await readFileAsBuffer(context.file.path, { + fileSize: context.file.size, + }); const base64 = buffer.toString('base64'); /** Uses actual mimetype of the file, 'image/jpeg' as fallback since it seems to be accepted regardless of mismatch */ const base64Prefix = `data:${context.file.mimetype || 'image/jpeg'};base64,`; @@ -691,7 +694,9 @@ export const uploadGoogleVertexMistralOCR = async ( const { serviceAccount, accessToken } = await loadGoogleAuthConfig(); const model = getModelConfig(context.req.config?.ocr); - const buffer = fs.readFileSync(context.file.path); + const { content: buffer } = await readFileAsBuffer(context.file.path, { + fileSize: context.file.size, + }); const base64 = buffer.toString('base64'); const base64Prefix = `data:${context.file.mimetype || 'application/pdf'};base64,`; diff --git a/packages/api/src/files/text.spec.ts b/packages/api/src/files/text.spec.ts index d1a1dad89..1de553d8b 100644 --- a/packages/api/src/files/text.spec.ts +++ b/packages/api/src/files/text.spec.ts @@ -9,8 +9,6 @@ jest.mock('@librechat/data-schemas', () => ({ }, })); -import { parseTextNative, parseText } from './text'; - jest.mock('fs', () => ({ readFileSync: jest.fn(), createReadStream: jest.fn(), @@ -36,10 +34,24 @@ jest.mock('form-data', () => { })); }); +// Mock the utils module to avoid AWS SDK issues +jest.mock('../utils', () => ({ + logAxiosError: jest.fn((args) => { + if (typeof args === 'object' && args.message) { + return args.message; + } + return 'Error'; + }), + readFileAsString: jest.fn(), +})); + +// Now import everything after mocks are in place +import { parseTextNative, parseText } from './text'; import fs, { ReadStream } from 'fs'; import axios from 'axios'; import FormData from 'form-data'; import { generateShortLivedToken } from '../crypto/jwt'; +import { readFileAsString } from '../utils'; const mockedFs = fs as jest.Mocked; const mockedAxios = axios as jest.Mocked; @@ -47,6 +59,7 @@ const mockedFormData = FormData as jest.MockedClass; const mockedGenerateShortLivedToken = generateShortLivedToken as jest.MockedFunction< typeof generateShortLivedToken >; +const mockedReadFileAsString = readFileAsString as jest.MockedFunction; describe('text', () => { const mockFile: Express.Multer.File = { @@ -74,29 +87,32 @@ describe('text', () => { }); describe('parseTextNative', () => { - it('should successfully parse a text file', () => { + it('should successfully parse a text file', async () => { const mockText = 'Hello, world!'; - mockedFs.readFileSync.mockReturnValue(mockText); + const mockBytes = Buffer.byteLength(mockText, 'utf8'); - const result = parseTextNative(mockFile); + mockedReadFileAsString.mockResolvedValue({ + content: mockText, + bytes: mockBytes, + }); - expect(mockedFs.readFileSync).toHaveBeenCalledWith('/tmp/test.txt', 'utf8'); + const result = await parseTextNative(mockFile); + + expect(mockedReadFileAsString).toHaveBeenCalledWith('/tmp/test.txt', { + fileSize: 100, + }); expect(result).toEqual({ text: mockText, - bytes: Buffer.byteLength(mockText, 'utf8'), + bytes: mockBytes, source: FileSources.text, }); }); - it('should throw an error when file cannot be read', () => { + it('should handle file read errors', async () => { const mockError = new Error('File not found'); - mockedFs.readFileSync.mockImplementation(() => { - throw mockError; - }); + mockedReadFileAsString.mockRejectedValue(mockError); - expect(() => parseTextNative(mockFile)).toThrow( - 'Failed to read file as text: Error: File not found', - ); + await expect(parseTextNative(mockFile)).rejects.toThrow('File not found'); }); }); @@ -115,7 +131,12 @@ describe('text', () => { it('should fall back to native parsing when RAG_API_URL is not defined', async () => { const mockText = 'Native parsing result'; - mockedFs.readFileSync.mockReturnValue(mockText); + const mockBytes = Buffer.byteLength(mockText, 'utf8'); + + mockedReadFileAsString.mockResolvedValue({ + content: mockText, + bytes: mockBytes, + }); const result = await parseText({ req: mockReq, @@ -125,7 +146,7 @@ describe('text', () => { expect(result).toEqual({ text: mockText, - bytes: Buffer.byteLength(mockText, 'utf8'), + bytes: mockBytes, source: FileSources.text, }); expect(mockedAxios.get).not.toHaveBeenCalled(); @@ -134,7 +155,12 @@ describe('text', () => { it('should fall back to native parsing when health check fails', async () => { process.env.RAG_API_URL = 'http://rag-api.test'; const mockText = 'Native parsing result'; - mockedFs.readFileSync.mockReturnValue(mockText); + const mockBytes = Buffer.byteLength(mockText, 'utf8'); + + mockedReadFileAsString.mockResolvedValue({ + content: mockText, + bytes: mockBytes, + }); mockedAxios.get.mockRejectedValue(new Error('Health check failed')); @@ -145,11 +171,11 @@ describe('text', () => { }); expect(mockedAxios.get).toHaveBeenCalledWith('http://rag-api.test/health', { - timeout: 5000, + timeout: 10000, }); expect(result).toEqual({ text: mockText, - bytes: Buffer.byteLength(mockText, 'utf8'), + bytes: mockBytes, source: FileSources.text, }); }); @@ -157,7 +183,12 @@ describe('text', () => { it('should fall back to native parsing when health check returns non-OK status', async () => { process.env.RAG_API_URL = 'http://rag-api.test'; const mockText = 'Native parsing result'; - mockedFs.readFileSync.mockReturnValue(mockText); + const mockBytes = Buffer.byteLength(mockText, 'utf8'); + + mockedReadFileAsString.mockResolvedValue({ + content: mockText, + bytes: mockBytes, + }); mockedAxios.get.mockResolvedValue({ status: 500, @@ -172,7 +203,7 @@ describe('text', () => { expect(result).toEqual({ text: mockText, - bytes: Buffer.byteLength(mockText, 'utf8'), + bytes: mockBytes, source: FileSources.text, }); }); @@ -207,7 +238,12 @@ describe('text', () => { it('should fall back to native parsing when RAG API response lacks text property', async () => { process.env.RAG_API_URL = 'http://rag-api.test'; const mockText = 'Native parsing result'; - mockedFs.readFileSync.mockReturnValue(mockText); + const mockBytes = Buffer.byteLength(mockText, 'utf8'); + + mockedReadFileAsString.mockResolvedValue({ + content: mockText, + bytes: mockBytes, + }); mockedAxios.get.mockResolvedValue({ status: 200, @@ -226,7 +262,7 @@ describe('text', () => { expect(result).toEqual({ text: mockText, - bytes: Buffer.byteLength(mockText, 'utf8'), + bytes: mockBytes, source: FileSources.text, }); }); @@ -234,7 +270,12 @@ describe('text', () => { it('should fall back to native parsing when user is undefined', async () => { process.env.RAG_API_URL = 'http://rag-api.test'; const mockText = 'Native parsing result'; - mockedFs.readFileSync.mockReturnValue(mockText); + const mockBytes = Buffer.byteLength(mockText, 'utf8'); + + mockedReadFileAsString.mockResolvedValue({ + content: mockText, + bytes: mockBytes, + }); const result = await parseText({ req: { user: undefined }, @@ -247,7 +288,7 @@ describe('text', () => { expect(mockedAxios.post).not.toHaveBeenCalled(); expect(result).toEqual({ text: mockText, - bytes: Buffer.byteLength(mockText, 'utf8'), + bytes: mockBytes, source: FileSources.text, }); }); diff --git a/packages/api/src/files/text.ts b/packages/api/src/files/text.ts index 41b4ca0ab..06e781bb5 100644 --- a/packages/api/src/files/text.ts +++ b/packages/api/src/files/text.ts @@ -1,18 +1,19 @@ -import fs from 'fs'; import axios from 'axios'; import FormData from 'form-data'; +import { createReadStream } from 'fs'; import { logger } from '@librechat/data-schemas'; import { FileSources } from 'librechat-data-provider'; import type { Request as ServerRequest } from 'express'; +import { logAxiosError, readFileAsString } from '~/utils'; import { generateShortLivedToken } from '~/crypto/jwt'; /** * Attempts to parse text using RAG API, falls back to native text parsing - * @param {Object} params - The parameters object - * @param {Express.Request} params.req - The Express request object - * @param {Express.Multer.File} params.file - The uploaded file - * @param {string} params.file_id - The file ID - * @returns {Promise<{text: string, bytes: number, source: string}>} + * @param params - The parameters object + * @param params.req - The Express request object + * @param params.file - The uploaded file + * @param params.file_id - The file ID + * @returns */ export async function parseText({ req, @@ -30,32 +31,33 @@ export async function parseText({ return parseTextNative(file); } - if (!req.user?.id) { + const userId = req.user?.id; + if (!userId) { logger.debug('[parseText] No user ID provided, falling back to native text parsing'); return parseTextNative(file); } try { const healthResponse = await axios.get(`${process.env.RAG_API_URL}/health`, { - timeout: 5000, + timeout: 10000, }); if (healthResponse?.statusText !== 'OK' && healthResponse?.status !== 200) { logger.debug('[parseText] RAG API health check failed, falling back to native parsing'); return parseTextNative(file); } } catch (healthError) { - logger.debug( - '[parseText] RAG API health check failed, falling back to native parsing', - healthError, - ); + logAxiosError({ + message: '[parseText] RAG API health check failed, falling back to native parsing:', + error: healthError, + }); return parseTextNative(file); } try { - const jwtToken = generateShortLivedToken(req.user.id); + const jwtToken = generateShortLivedToken(userId); const formData = new FormData(); formData.append('file_id', file_id); - formData.append('file', fs.createReadStream(file.path)); + formData.append('file', createReadStream(file.path)); const formHeaders = formData.getHeaders(); @@ -69,7 +71,7 @@ export async function parseText({ }); const responseData = response.data; - logger.debug('[parseText] Response from RAG API', responseData); + logger.debug(`[parseText] RAG API completed successfully (${response.status})`); if (!('text' in responseData)) { throw new Error('RAG API did not return parsed text'); @@ -81,7 +83,10 @@ export async function parseText({ source: FileSources.text, }; } catch (error) { - logger.warn('[parseText] RAG API text parsing failed, falling back to native parsing', error); + logAxiosError({ + message: '[parseText] RAG API text parsing failed, falling back to native parsing', + error, + }); return parseTextNative(file); } } @@ -89,25 +94,21 @@ export async function parseText({ /** * Native JavaScript text parsing fallback * Simple text file reading - complex formats handled by RAG API - * @param {Express.Multer.File} file - The uploaded file - * @returns {{text: string, bytes: number, source: string}} + * @param file - The uploaded file + * @returns */ -export function parseTextNative(file: Express.Multer.File): { +export async function parseTextNative(file: Express.Multer.File): Promise<{ text: string; bytes: number; source: string; -} { - try { - const text = fs.readFileSync(file.path, 'utf8'); - const bytes = Buffer.byteLength(text, 'utf8'); +}> { + const { content: text, bytes } = await readFileAsString(file.path, { + fileSize: file.size, + }); - return { - text, - bytes, - source: FileSources.text, - }; - } catch (error) { - console.error('[parseTextNative] Failed to parse file:', error); - throw new Error(`Failed to read file as text: ${error}`); - } + return { + text, + bytes, + source: FileSources.text, + }; } diff --git a/packages/api/src/utils/__tests__/files.test.ts b/packages/api/src/utils/__tests__/files.test.ts new file mode 100644 index 000000000..41332d263 --- /dev/null +++ b/packages/api/src/utils/__tests__/files.test.ts @@ -0,0 +1,414 @@ +import { createReadStream } from 'fs'; +import { readFile, stat } from 'fs/promises'; +import { Readable } from 'stream'; +import { readFileAsString, readFileAsBuffer, readJsonFile } from '../files'; + +jest.mock('fs'); +jest.mock('fs/promises'); + +describe('File utilities', () => { + const mockFilePath = '/test/file.txt'; + const smallContent = 'Hello, World!'; + const largeContent = 'x'.repeat(11 * 1024 * 1024); // 11MB of 'x' + + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe('readFileAsString', () => { + it('should read small files directly without streaming', async () => { + const fileSize = Buffer.byteLength(smallContent); + + (stat as jest.Mock).mockResolvedValue({ size: fileSize }); + (readFile as jest.Mock).mockResolvedValue(smallContent); + + const result = await readFileAsString(mockFilePath); + + expect(result).toEqual({ + content: smallContent, + bytes: fileSize, + }); + expect(stat).toHaveBeenCalledWith(mockFilePath); + expect(readFile).toHaveBeenCalledWith(mockFilePath, 'utf8'); + expect(createReadStream).not.toHaveBeenCalled(); + }); + + it('should use provided fileSize to avoid stat call', async () => { + const fileSize = Buffer.byteLength(smallContent); + + (readFile as jest.Mock).mockResolvedValue(smallContent); + + const result = await readFileAsString(mockFilePath, { fileSize }); + + expect(result).toEqual({ + content: smallContent, + bytes: fileSize, + }); + expect(stat).not.toHaveBeenCalled(); + expect(readFile).toHaveBeenCalledWith(mockFilePath, 'utf8'); + }); + + it('should stream large files', async () => { + const fileSize = Buffer.byteLength(largeContent); + + (stat as jest.Mock).mockResolvedValue({ size: fileSize }); + + // Create a mock readable stream + const chunks = [ + largeContent.substring(0, 5000000), + largeContent.substring(5000000, 10000000), + largeContent.substring(10000000), + ]; + + const mockStream = new Readable({ + read() { + if (chunks.length > 0) { + this.push(chunks.shift()); + } else { + this.push(null); // End stream + } + }, + }); + + (createReadStream as jest.Mock).mockReturnValue(mockStream); + + const result = await readFileAsString(mockFilePath); + + expect(result).toEqual({ + content: largeContent, + bytes: fileSize, + }); + expect(stat).toHaveBeenCalledWith(mockFilePath); + expect(createReadStream).toHaveBeenCalledWith(mockFilePath, { + encoding: 'utf8', + highWaterMark: 64 * 1024, + }); + expect(readFile).not.toHaveBeenCalled(); + }); + + it('should use custom encoding', async () => { + const fileSize = 100; + + (stat as jest.Mock).mockResolvedValue({ size: fileSize }); + (readFile as jest.Mock).mockResolvedValue(smallContent); + + await readFileAsString(mockFilePath, { encoding: 'latin1' }); + + expect(readFile).toHaveBeenCalledWith(mockFilePath, 'latin1'); + }); + + it('should respect custom stream threshold', async () => { + const customThreshold = 1024; // 1KB + const fileSize = 2048; // 2KB + + (stat as jest.Mock).mockResolvedValue({ size: fileSize }); + + const mockStream = new Readable({ + read() { + this.push('test content'); + this.push(null); + }, + }); + + (createReadStream as jest.Mock).mockReturnValue(mockStream); + + await readFileAsString(mockFilePath, { streamThreshold: customThreshold }); + + expect(createReadStream).toHaveBeenCalled(); + expect(readFile).not.toHaveBeenCalled(); + }); + + it('should handle empty files', async () => { + const fileSize = 0; + + (stat as jest.Mock).mockResolvedValue({ size: fileSize }); + (readFile as jest.Mock).mockResolvedValue(''); + + const result = await readFileAsString(mockFilePath); + + expect(result).toEqual({ + content: '', + bytes: 0, + }); + }); + + it('should propagate read errors', async () => { + const error = new Error('File not found'); + + (stat as jest.Mock).mockResolvedValue({ size: 100 }); + (readFile as jest.Mock).mockRejectedValue(error); + + await expect(readFileAsString(mockFilePath)).rejects.toThrow('File not found'); + }); + + it('should propagate stat errors when fileSize not provided', async () => { + const error = new Error('Permission denied'); + + (stat as jest.Mock).mockRejectedValue(error); + + await expect(readFileAsString(mockFilePath)).rejects.toThrow('Permission denied'); + }); + + it('should propagate stream errors', async () => { + const fileSize = 11 * 1024 * 1024; // 11MB + + (stat as jest.Mock).mockResolvedValue({ size: fileSize }); + + const mockStream = new Readable({ + read() { + this.emit('error', new Error('Stream error')); + }, + }); + + (createReadStream as jest.Mock).mockReturnValue(mockStream); + + await expect(readFileAsString(mockFilePath)).rejects.toThrow('Stream error'); + }); + }); + + describe('readFileAsBuffer', () => { + const smallBuffer = Buffer.from(smallContent); + const largeBuffer = Buffer.from(largeContent); + + it('should read small files directly without streaming', async () => { + const fileSize = smallBuffer.length; + + (stat as jest.Mock).mockResolvedValue({ size: fileSize }); + (readFile as jest.Mock).mockResolvedValue(smallBuffer); + + const result = await readFileAsBuffer(mockFilePath); + + expect(result).toEqual({ + content: smallBuffer, + bytes: fileSize, + }); + expect(stat).toHaveBeenCalledWith(mockFilePath); + expect(readFile).toHaveBeenCalledWith(mockFilePath); + expect(createReadStream).not.toHaveBeenCalled(); + }); + + it('should use provided fileSize to avoid stat call', async () => { + const fileSize = smallBuffer.length; + + (readFile as jest.Mock).mockResolvedValue(smallBuffer); + + const result = await readFileAsBuffer(mockFilePath, { fileSize }); + + expect(result).toEqual({ + content: smallBuffer, + bytes: fileSize, + }); + expect(stat).not.toHaveBeenCalled(); + }); + + it('should stream large files', async () => { + const fileSize = largeBuffer.length; + + (stat as jest.Mock).mockResolvedValue({ size: fileSize }); + + // Split large buffer into chunks + const chunk1 = largeBuffer.slice(0, 5000000); + const chunk2 = largeBuffer.slice(5000000, 10000000); + const chunk3 = largeBuffer.slice(10000000); + + const chunks = [chunk1, chunk2, chunk3]; + + const mockStream = new Readable({ + read() { + if (chunks.length > 0) { + this.push(chunks.shift()); + } else { + this.push(null); + } + }, + }); + + (createReadStream as jest.Mock).mockReturnValue(mockStream); + + const result = await readFileAsBuffer(mockFilePath); + + expect(result.bytes).toBe(fileSize); + expect(Buffer.compare(result.content, largeBuffer)).toBe(0); + expect(createReadStream).toHaveBeenCalledWith(mockFilePath, { + highWaterMark: 64 * 1024, + }); + expect(readFile).not.toHaveBeenCalled(); + }); + + it('should respect custom highWaterMark', async () => { + const fileSize = 11 * 1024 * 1024; // 11MB + const customHighWaterMark = 128 * 1024; // 128KB + + (stat as jest.Mock).mockResolvedValue({ size: fileSize }); + + const mockStream = new Readable({ + read() { + this.push(Buffer.from('test')); + this.push(null); + }, + }); + + (createReadStream as jest.Mock).mockReturnValue(mockStream); + + await readFileAsBuffer(mockFilePath, { highWaterMark: customHighWaterMark }); + + expect(createReadStream).toHaveBeenCalledWith(mockFilePath, { + highWaterMark: customHighWaterMark, + }); + }); + + it('should handle empty buffer files', async () => { + const emptyBuffer = Buffer.alloc(0); + + (stat as jest.Mock).mockResolvedValue({ size: 0 }); + (readFile as jest.Mock).mockResolvedValue(emptyBuffer); + + const result = await readFileAsBuffer(mockFilePath); + + expect(result).toEqual({ + content: emptyBuffer, + bytes: 0, + }); + }); + + it('should propagate errors', async () => { + const error = new Error('Access denied'); + + (stat as jest.Mock).mockResolvedValue({ size: 100 }); + (readFile as jest.Mock).mockRejectedValue(error); + + await expect(readFileAsBuffer(mockFilePath)).rejects.toThrow('Access denied'); + }); + }); + + describe('readJsonFile', () => { + const validJson = { name: 'test', value: 123, nested: { key: 'value' } }; + const jsonString = JSON.stringify(validJson); + + it('should parse valid JSON files', async () => { + (stat as jest.Mock).mockResolvedValue({ size: jsonString.length }); + (readFile as jest.Mock).mockResolvedValue(jsonString); + + const result = await readJsonFile(mockFilePath); + + expect(result).toEqual(validJson); + expect(readFile).toHaveBeenCalledWith(mockFilePath, 'utf8'); + }); + + it('should parse JSON with provided fileSize', async () => { + const fileSize = jsonString.length; + + (readFile as jest.Mock).mockResolvedValue(jsonString); + + const result = await readJsonFile(mockFilePath, { fileSize }); + + expect(result).toEqual(validJson); + expect(stat).not.toHaveBeenCalled(); + }); + + it('should handle JSON arrays', async () => { + const jsonArray = [1, 2, 3, { key: 'value' }]; + const arrayString = JSON.stringify(jsonArray); + + (stat as jest.Mock).mockResolvedValue({ size: arrayString.length }); + (readFile as jest.Mock).mockResolvedValue(arrayString); + + const result = await readJsonFile(mockFilePath); + + expect(result).toEqual(jsonArray); + }); + + it('should throw on invalid JSON', async () => { + const invalidJson = '{ invalid json }'; + + (stat as jest.Mock).mockResolvedValue({ size: invalidJson.length }); + (readFile as jest.Mock).mockResolvedValue(invalidJson); + + await expect(readJsonFile(mockFilePath)).rejects.toThrow(); + }); + + it('should throw on empty file', async () => { + (stat as jest.Mock).mockResolvedValue({ size: 0 }); + (readFile as jest.Mock).mockResolvedValue(''); + + await expect(readJsonFile(mockFilePath)).rejects.toThrow(); + }); + + it('should handle large JSON files with streaming', async () => { + const largeJson = { data: 'x'.repeat(11 * 1024 * 1024) }; // >10MB + const largeJsonString = JSON.stringify(largeJson); + const fileSize = largeJsonString.length; + + (stat as jest.Mock).mockResolvedValue({ size: fileSize }); + + // Create chunks for streaming + const chunks: string[] = []; + let offset = 0; + const chunkSize = 5 * 1024 * 1024; // 5MB chunks + + while (offset < largeJsonString.length) { + chunks.push(largeJsonString.slice(offset, offset + chunkSize)); + offset += chunkSize; + } + + const mockStream = new Readable({ + read() { + if (chunks.length > 0) { + this.push(chunks.shift()); + } else { + this.push(null); + } + }, + }); + + (createReadStream as jest.Mock).mockReturnValue(mockStream); + + const result = await readJsonFile(mockFilePath); + + expect(result).toEqual(largeJson); + expect(createReadStream).toHaveBeenCalled(); + expect(readFile).not.toHaveBeenCalled(); + }); + + it('should use custom stream threshold', async () => { + const customThreshold = 100; + const json = { test: 'x'.repeat(200) }; + const jsonStr = JSON.stringify(json); + const fileSize = jsonStr.length; + + (stat as jest.Mock).mockResolvedValue({ size: fileSize }); + + const mockStream = new Readable({ + read() { + this.push(jsonStr); + this.push(null); + }, + }); + + (createReadStream as jest.Mock).mockReturnValue(mockStream); + + await readJsonFile(mockFilePath, { streamThreshold: customThreshold }); + + expect(createReadStream).toHaveBeenCalled(); + }); + + it('should preserve type with generics', async () => { + interface TestType { + id: number; + name: string; + } + + const typedJson: TestType = { id: 1, name: 'test' }; + const jsonString = JSON.stringify(typedJson); + + (stat as jest.Mock).mockResolvedValue({ size: jsonString.length }); + (readFile as jest.Mock).mockResolvedValue(jsonString); + + const result = await readJsonFile(mockFilePath); + + expect(result).toEqual(typedJson); + expect(result.id).toBe(1); + expect(result.name).toBe('test'); + }); + }); +}); diff --git a/packages/api/src/utils/axios.ts b/packages/api/src/utils/axios.ts index d1275ada4..5d73955f0 100644 --- a/packages/api/src/utils/axios.ts +++ b/packages/api/src/utils/axios.ts @@ -9,12 +9,25 @@ import type { AxiosInstance, AxiosProxyConfig, AxiosError } from 'axios'; * @param options.error - The Axios error object. * @returns The log message. */ -export const logAxiosError = ({ message, error }: { message: string; error: AxiosError }) => { +export const logAxiosError = ({ + message, + error, +}: { + message: string; + error: AxiosError | Error | unknown; +}) => { let logMessage = message; try { - const stack = error.stack || 'No stack trace available'; + const stack = + error != null + ? (error as Error | AxiosError)?.stack || 'No stack trace available' + : 'No stack trace available'; + const errorMessage = + error != null + ? (error as Error | AxiosError)?.message || 'No error message available' + : 'No error message available'; - if (error.response?.status) { + if (axios.isAxiosError(error) && error.response && error.response?.status) { const { status, headers, data } = error.response; logMessage = `${message} The server responded with status ${status}: ${error.message}`; logger.error(logMessage, { @@ -23,18 +36,18 @@ export const logAxiosError = ({ message, error }: { message: string; error: Axio data, stack, }); - } else if (error.request) { + } else if (axios.isAxiosError(error) && error.request) { const { method, url } = error.config || {}; logMessage = `${message} No response received for ${method ? method.toUpperCase() : ''} ${url || ''}: ${error.message}`; logger.error(logMessage, { requestInfo: { method, url }, stack, }); - } else if (error?.message?.includes("Cannot read properties of undefined (reading 'status')")) { - logMessage = `${message} It appears the request timed out or was unsuccessful: ${error.message}`; + } else if (errorMessage?.includes("Cannot read properties of undefined (reading 'status')")) { + logMessage = `${message} It appears the request timed out or was unsuccessful: ${errorMessage}`; logger.error(logMessage, { stack }); } else { - logMessage = `${message} An error occurred while setting up the request: ${error.message}`; + logMessage = `${message} An error occurred while setting up the request: ${errorMessage}`; logger.error(logMessage, { stack }); } } catch (err: unknown) { diff --git a/packages/api/src/utils/files.ts b/packages/api/src/utils/files.ts index 0d8f111a2..2fa3b62ab 100644 --- a/packages/api/src/utils/files.ts +++ b/packages/api/src/utils/files.ts @@ -1,5 +1,7 @@ import path from 'path'; import crypto from 'node:crypto'; +import { createReadStream } from 'fs'; +import { readFile, stat } from 'fs/promises'; /** * Sanitize a filename by removing any directory components, replacing non-alphanumeric characters @@ -31,3 +33,122 @@ export function sanitizeFilename(inputName: string): string { return name; } + +/** + * Options for reading files + */ +export interface ReadFileOptions { + encoding?: BufferEncoding; + /** Size threshold in bytes. Files larger than this will be streamed. Default: 10MB */ + streamThreshold?: number; + /** Size of chunks when streaming. Default: 64KB */ + highWaterMark?: number; + /** File size in bytes if known (e.g. from multer). Avoids extra stat() call. */ + fileSize?: number; +} + +/** + * Result from reading a file + */ +export interface ReadFileResult { + content: T; + bytes: number; +} + +/** + * Reads a file asynchronously. Uses streaming for large files to avoid memory issues. + * + * @param filePath - Path to the file to read + * @param options - Options for reading the file + * @returns Promise resolving to the file contents and size + * @throws Error if the file cannot be read + */ +export async function readFileAsString( + filePath: string, + options: ReadFileOptions = {}, +): Promise> { + const { + encoding = 'utf8', + streamThreshold = 10 * 1024 * 1024, // 10MB + highWaterMark = 64 * 1024, // 64KB + fileSize, + } = options; + + // Get file size if not provided + const bytes = fileSize ?? (await stat(filePath)).size; + + // For large files, use streaming to avoid memory issues + if (bytes > streamThreshold) { + const chunks: string[] = []; + const stream = createReadStream(filePath, { + encoding, + highWaterMark, + }); + + for await (const chunk of stream) { + chunks.push(chunk as string); + } + + return { content: chunks.join(''), bytes }; + } + + // For smaller files, read directly + const content = await readFile(filePath, encoding); + return { content, bytes }; +} + +/** + * Reads a file as a Buffer asynchronously. Uses streaming for large files. + * + * @param filePath - Path to the file to read + * @param options - Options for reading the file + * @returns Promise resolving to the file contents and size + * @throws Error if the file cannot be read + */ +export async function readFileAsBuffer( + filePath: string, + options: Omit = {}, +): Promise> { + const { + streamThreshold = 10 * 1024 * 1024, // 10MB + highWaterMark = 64 * 1024, // 64KB + fileSize, + } = options; + + // Get file size if not provided + const bytes = fileSize ?? (await stat(filePath)).size; + + // For large files, use streaming to avoid memory issues + if (bytes > streamThreshold) { + const chunks: Buffer[] = []; + const stream = createReadStream(filePath, { + highWaterMark, + }); + + for await (const chunk of stream) { + chunks.push(chunk as Buffer); + } + + return { content: Buffer.concat(chunks), bytes }; + } + + // For smaller files, read directly + const content = await readFile(filePath); + return { content, bytes }; +} + +/** + * Reads a JSON file asynchronously + * + * @param filePath - Path to the JSON file to read + * @param options - Options for reading the file + * @returns Promise resolving to the parsed JSON object + * @throws Error if the file cannot be read or parsed + */ +export async function readJsonFile( + filePath: string, + options: Omit = {}, +): Promise { + const { content } = await readFileAsString(filePath, { ...options, encoding: 'utf8' }); + return JSON.parse(content); +} diff --git a/packages/api/src/utils/key.test.ts b/packages/api/src/utils/key.test.ts index 29f34adde..ae3de5270 100644 --- a/packages/api/src/utils/key.test.ts +++ b/packages/api/src/utils/key.test.ts @@ -1,6 +1,6 @@ -import fs from 'fs'; import path from 'path'; import axios from 'axios'; +import { readFileAsString } from './files'; import { loadServiceKey } from './key'; jest.mock('fs'); @@ -11,6 +11,10 @@ jest.mock('@librechat/data-schemas', () => ({ }, })); +jest.mock('./files', () => ({ + readFileAsString: jest.fn(), +})); + describe('loadServiceKey', () => { const mockServiceKey = { type: 'service_account', @@ -49,10 +53,13 @@ describe('loadServiceKey', () => { it('should load from file path', async () => { const filePath = '/path/to/service-key.json'; - (fs.readFileSync as jest.Mock).mockReturnValue(JSON.stringify(mockServiceKey)); + (readFileAsString as jest.Mock).mockResolvedValue({ + content: JSON.stringify(mockServiceKey), + bytes: JSON.stringify(mockServiceKey).length, + }); const result = await loadServiceKey(filePath); - expect(fs.readFileSync).toHaveBeenCalledWith(path.resolve(filePath), 'utf8'); + expect(readFileAsString).toHaveBeenCalledWith(path.resolve(filePath)); expect(result).toEqual(mockServiceKey); }); @@ -73,9 +80,7 @@ describe('loadServiceKey', () => { it('should handle file read errors', async () => { const filePath = '/path/to/nonexistent.json'; - (fs.readFileSync as jest.Mock).mockImplementation(() => { - throw new Error('File not found'); - }); + (readFileAsString as jest.Mock).mockRejectedValue(new Error('File not found')); const result = await loadServiceKey(filePath); expect(result).toBeNull(); diff --git a/packages/api/src/utils/key.ts b/packages/api/src/utils/key.ts index 086e74c06..13dabeaf5 100644 --- a/packages/api/src/utils/key.ts +++ b/packages/api/src/utils/key.ts @@ -1,7 +1,7 @@ -import fs from 'fs'; import path from 'path'; import axios from 'axios'; import { logger } from '@librechat/data-schemas'; +import { readFileAsString } from './files'; export interface GoogleServiceKey { type?: string; @@ -63,7 +63,7 @@ export async function loadServiceKey(keyPath: string): Promise Date: Sat, 20 Sep 2025 17:01:45 +0200 Subject: [PATCH 3/4] =?UTF-8?q?=F0=9F=94=90=20fix:=20Handle=20Multiple=20E?= =?UTF-8?q?mail=20Addresses=20in=20LDAP=20Auth=20(#9729)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/strategies/ldapStrategy.js | 3 +- api/strategies/ldapStrategy.spec.js | 186 ++++++++++++++++++++++++++++ 2 files changed, 188 insertions(+), 1 deletion(-) create mode 100644 api/strategies/ldapStrategy.spec.js diff --git a/api/strategies/ldapStrategy.js b/api/strategies/ldapStrategy.js index 2822bd8a2..17d54df4a 100644 --- a/api/strategies/ldapStrategy.js +++ b/api/strategies/ldapStrategy.js @@ -109,7 +109,8 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => { const username = (LDAP_USERNAME && userinfo[LDAP_USERNAME]) || userinfo.givenName || userinfo.mail; - const mail = (LDAP_EMAIL && userinfo[LDAP_EMAIL]) || userinfo.mail || username + '@ldap.local'; + let mail = (LDAP_EMAIL && userinfo[LDAP_EMAIL]) || userinfo.mail || username + '@ldap.local'; + mail = Array.isArray(mail) ? mail[0] : mail; if (!userinfo.mail && !(LDAP_EMAIL && userinfo[LDAP_EMAIL])) { logger.warn( diff --git a/api/strategies/ldapStrategy.spec.js b/api/strategies/ldapStrategy.spec.js new file mode 100644 index 000000000..d3d51a4cd --- /dev/null +++ b/api/strategies/ldapStrategy.spec.js @@ -0,0 +1,186 @@ +// --- Mocks --- +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + debug: jest.fn(), + error: jest.fn(), + }, +})); + +jest.mock('@librechat/api', () => ({ + // isEnabled used for TLS flags + isEnabled: jest.fn(() => false), + getBalanceConfig: jest.fn(() => ({ enabled: false })), +})); + +jest.mock('~/models', () => ({ + findUser: jest.fn(), + createUser: jest.fn(), + updateUser: jest.fn(), + countUsers: jest.fn(), +})); + +jest.mock('~/server/services/Config', () => ({ + getAppConfig: jest.fn().mockResolvedValue({}), +})); + +jest.mock('~/server/services/domains', () => ({ + isEmailDomainAllowed: jest.fn(() => true), +})); + +// Mock passport-ldapauth to capture verify callback +let verifyCallback; +jest.mock('passport-ldapauth', () => { + return jest.fn().mockImplementation((options, verify) => { + verifyCallback = verify; // capture the strategy verify function + return { name: 'ldap', options, verify }; + }); +}); + +const { ErrorTypes } = require('librechat-data-provider'); +const { findUser, createUser, updateUser, countUsers } = require('~/models'); +const { isEmailDomainAllowed } = require('~/server/services/domains'); + +// Helper to call the verify callback and wrap in a Promise for convenience +const callVerify = (userinfo) => + new Promise((resolve, reject) => { + verifyCallback(userinfo, (err, user, info) => { + if (err) return reject(err); + resolve({ user, info }); + }); + }); + +describe('ldapStrategy', () => { + beforeEach(() => { + jest.clearAllMocks(); + + // minimal required env for ldapStrategy module to export + process.env.LDAP_URL = 'ldap://example.com'; + process.env.LDAP_USER_SEARCH_BASE = 'ou=users,dc=example,dc=com'; + + // Unset optional envs to exercise defaults + delete process.env.LDAP_CA_CERT_PATH; + delete process.env.LDAP_FULL_NAME; + delete process.env.LDAP_ID; + delete process.env.LDAP_USERNAME; + delete process.env.LDAP_EMAIL; + delete process.env.LDAP_TLS_REJECT_UNAUTHORIZED; + delete process.env.LDAP_STARTTLS; + + // Default model/domain mocks + findUser.mockReset().mockResolvedValue(null); + createUser.mockReset().mockResolvedValue('newUserId'); + updateUser.mockReset().mockImplementation(async (id, user) => ({ _id: id, ...user })); + countUsers.mockReset().mockResolvedValue(0); + isEmailDomainAllowed.mockReset().mockReturnValue(true); + + // Ensure requiring the strategy sets up the verify callback + jest.isolateModules(() => { + require('./ldapStrategy'); + }); + }); + + it('uses the first email when LDAP returns multiple emails (array)', async () => { + const userinfo = { + uid: 'uid123', + givenName: 'Alice', + cn: 'Alice Doe', + mail: ['first@example.com', 'second@example.com'], + }; + + const { user } = await callVerify(userinfo); + + expect(user.email).toBe('first@example.com'); + expect(createUser).toHaveBeenCalledWith( + expect.objectContaining({ + provider: 'ldap', + ldapId: 'uid123', + username: 'Alice', + email: 'first@example.com', + emailVerified: true, + name: 'Alice Doe', + }), + expect.any(Object), + ); + }); + + it('blocks login if an existing user has a different provider', async () => { + findUser.mockResolvedValue({ _id: 'u1', email: 'first@example.com', provider: 'google' }); + + const userinfo = { + uid: 'uid123', + mail: 'first@example.com', + givenName: 'Alice', + cn: 'Alice Doe', + }; + + const { user, info } = await callVerify(userinfo); + + expect(user).toBe(false); + expect(info).toEqual({ message: ErrorTypes.AUTH_FAILED }); + expect(createUser).not.toHaveBeenCalled(); + }); + + it('updates an existing ldap user with current LDAP info', async () => { + const existing = { + _id: 'u2', + provider: 'ldap', + email: 'old@example.com', + ldapId: 'uid123', + username: 'olduser', + name: 'Old Name', + }; + findUser.mockResolvedValue(existing); + + const userinfo = { + uid: 'uid123', + mail: 'new@example.com', + givenName: 'NewFirst', + cn: 'NewFirst NewLast', + }; + + const { user } = await callVerify(userinfo); + + expect(createUser).not.toHaveBeenCalled(); + expect(updateUser).toHaveBeenCalledWith( + 'u2', + expect.objectContaining({ + provider: 'ldap', + ldapId: 'uid123', + email: 'new@example.com', + username: 'NewFirst', + name: 'NewFirst NewLast', + }), + ); + expect(user.email).toBe('new@example.com'); + }); + + it('falls back to username@ldap.local when no email attributes are present', async () => { + const userinfo = { + uid: 'uid999', + givenName: 'John', + cn: 'John Doe', + // no mail and no custom LDAP_EMAIL + }; + + const { user } = await callVerify(userinfo); + + expect(user.email).toBe('John@ldap.local'); + }); + + it('denies login if email domain is not allowed', async () => { + isEmailDomainAllowed.mockReturnValue(false); + + const userinfo = { + uid: 'uid123', + mail: 'notallowed@blocked.com', + givenName: 'Alice', + cn: 'Alice Doe', + }; + + const { user, info } = await callVerify(userinfo); + expect(user).toBe(false); + expect(info).toEqual({ message: 'Email domain not allowed' }); + }); +}); From 9d2aba5df582d68a8b8529895dd0da265d21a17f Mon Sep 17 00:00:00 2001 From: Federico Ruggi Date: Sat, 20 Sep 2025 17:06:23 +0200 Subject: [PATCH 4/4] =?UTF-8?q?=F0=9F=9B=A1=EF=B8=8F=20fix:=20Handle=20Nul?= =?UTF-8?q?l=20`MCPManager`=20In=20`OAuthReconnectionManager`=20(#9740)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../oauth/OAuthReconnectionManager.test.ts | 28 ++++++++++++++++ .../src/mcp/oauth/OAuthReconnectionManager.ts | 33 ++++++++++++++----- 2 files changed, 53 insertions(+), 8 deletions(-) diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts index 78fedb9c3..8f18df2f5 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts @@ -290,5 +290,33 @@ describe('OAuthReconnectionManager', () => { expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false); expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1'); }); + + it('should handle MCPManager not available gracefully', async () => { + const userId = 'user-123'; + + // Reset singleton first + (OAuthReconnectionManager as unknown as { instance: null }).instance = null; + + // Mock MCPManager.getInstance to throw (simulating no MCP manager available) + (MCPManager.getInstance as jest.Mock).mockImplementation(() => { + throw new Error('MCPManager has not been initialized.'); + }); + + // Create a reconnection manager without MCPManager available + const reconnectionTracker = new OAuthReconnectionTracker(); + const reconnectionManagerWithoutMCP = await OAuthReconnectionManager.createInstance( + flowManager, + tokenMethods, + reconnectionTracker, + ); + + // Verify that the method does not throw and completes successfully + await expect(reconnectionManagerWithoutMCP.reconnectServers(userId)).resolves.toBeUndefined(); + + // Verify that the method returns early without attempting any reconnections + expect(tokenMethods.findToken).not.toHaveBeenCalled(); + expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled(); + expect(mockMCPManager.disconnectUserConnection).not.toHaveBeenCalled(); + }); }); }); diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts index 48b751dfa..b819403a6 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts @@ -13,6 +13,7 @@ export class OAuthReconnectionManager { protected readonly flowManager: FlowStateManager; protected readonly tokenMethods: TokenMethods; + private readonly mcpManager: MCPManager | null; private readonly reconnectionsTracker: OAuthReconnectionTracker; @@ -46,6 +47,12 @@ export class OAuthReconnectionManager { this.flowManager = flowManager; this.tokenMethods = tokenMethods; this.reconnectionsTracker = reconnections ?? new OAuthReconnectionTracker(); + + try { + this.mcpManager = MCPManager.getInstance(); + } catch { + this.mcpManager = null; + } } public isReconnecting(userId: string, serverName: string): boolean { @@ -53,11 +60,17 @@ export class OAuthReconnectionManager { } public async reconnectServers(userId: string) { - const mcpManager = MCPManager.getInstance(); + // Check if MCPManager is available + if (this.mcpManager == null) { + logger.warn( + '[OAuthReconnectionManager] MCPManager not available, skipping OAuth MCP server reconnection', + ); + return; + } // 1. derive the servers to reconnect const serversToReconnect = []; - for (const serverName of mcpManager.getOAuthServers() ?? []) { + for (const serverName of this.mcpManager.getOAuthServers() ?? []) { const canReconnect = await this.canReconnect(userId, serverName); if (canReconnect) { serversToReconnect.push(serverName); @@ -81,23 +94,25 @@ export class OAuthReconnectionManager { } private async tryReconnect(userId: string, serverName: string) { - const mcpManager = MCPManager.getInstance(); + if (this.mcpManager == null) { + return; + } const logPrefix = `[tryReconnectOAuthMCPServer][User: ${userId}][${serverName}]`; logger.info(`${logPrefix} Attempting reconnection`); - const config = mcpManager.getRawConfig(serverName); + const config = this.mcpManager.getRawConfig(serverName); const cleanupOnFailedReconnect = () => { this.reconnectionsTracker.setFailed(userId, serverName); this.reconnectionsTracker.removeActive(userId, serverName); - mcpManager.disconnectUserConnection(userId, serverName); + this.mcpManager?.disconnectUserConnection(userId, serverName); }; try { // attempt to get connection (this will use existing tokens and refresh if needed) - const connection = await mcpManager.getUserConnection({ + const connection = await this.mcpManager.getUserConnection({ serverName, user: { id: userId } as TUser, flowManager: this.flowManager, @@ -125,7 +140,9 @@ export class OAuthReconnectionManager { } private async canReconnect(userId: string, serverName: string) { - const mcpManager = MCPManager.getInstance(); + if (this.mcpManager == null) { + return false; + } // if the server has failed reconnection, don't attempt to reconnect if (this.reconnectionsTracker.isFailed(userId, serverName)) { @@ -133,7 +150,7 @@ export class OAuthReconnectionManager { } // if the server is already connected, don't attempt to reconnect - const existingConnections = mcpManager.getUserConnections(userId); + const existingConnections = this.mcpManager.getUserConnections(userId); if (existingConnections?.has(serverName)) { const isConnected = await existingConnections.get(serverName)?.isConnected(); if (isConnected) {