Merge branch 'dev' into feat/prompt-enhancement

This commit is contained in:
Marco Beretta 2025-06-23 14:48:46 +02:00 committed by GitHub
commit 3d261a969d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
365 changed files with 23826 additions and 8790 deletions

View file

@ -1,8 +1,10 @@
const express = require('express');
const jwt = require('jsonwebtoken');
const { getAccessToken } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { getAccessToken } = require('~/server/services/TokenService');
const { logger, getFlowStateManager } = require('~/config');
const { findToken, updateToken, createToken } = require('~/models');
const { getFlowStateManager } = require('~/config');
const { getLogStores } = require('~/cache');
const router = express.Router();
@ -28,18 +30,19 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
try {
decodedState = jwt.verify(state, JWT_SECRET);
} catch (err) {
logger.error('Error verifying state parameter:', err);
await flowManager.failFlow(identifier, 'oauth', 'Invalid or expired state parameter');
return res.status(400).send('Invalid or expired state parameter');
return res.redirect('/oauth/error?error=invalid_state');
}
if (decodedState.action_id !== action_id) {
await flowManager.failFlow(identifier, 'oauth', 'Mismatched action ID in state parameter');
return res.status(400).send('Mismatched action ID in state parameter');
return res.redirect('/oauth/error?error=invalid_state');
}
if (!decodedState.user) {
await flowManager.failFlow(identifier, 'oauth', 'Invalid user ID in state parameter');
return res.status(400).send('Invalid user ID in state parameter');
return res.redirect('/oauth/error?error=invalid_state');
}
identifier = `${decodedState.user}:${action_id}`;
const flowState = await flowManager.getFlowState(identifier, 'oauth');
@ -47,90 +50,34 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
throw new Error('OAuth flow not found');
}
const tokenData = await getAccessToken({
code,
userId: decodedState.user,
identifier,
client_url: flowState.metadata.client_url,
redirect_uri: flowState.metadata.redirect_uri,
/** Encrypted values */
encrypted_oauth_client_id: flowState.metadata.encrypted_oauth_client_id,
encrypted_oauth_client_secret: flowState.metadata.encrypted_oauth_client_secret,
});
const tokenData = await getAccessToken(
{
code,
userId: decodedState.user,
identifier,
client_url: flowState.metadata.client_url,
redirect_uri: flowState.metadata.redirect_uri,
token_exchange_method: flowState.metadata.token_exchange_method,
/** Encrypted values */
encrypted_oauth_client_id: flowState.metadata.encrypted_oauth_client_id,
encrypted_oauth_client_secret: flowState.metadata.encrypted_oauth_client_secret,
},
{
findToken,
updateToken,
createToken,
},
);
await flowManager.completeFlow(identifier, 'oauth', tokenData);
res.send(`
<!DOCTYPE html>
<html>
<head>
<title>Authentication Successful</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
<style>
body {
font-family: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont;
background-color: rgb(249, 250, 251);
margin: 0;
padding: 2rem;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
}
.card {
background-color: white;
border-radius: 0.5rem;
padding: 2rem;
max-width: 28rem;
width: 100%;
box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
text-align: center;
}
.heading {
color: rgb(17, 24, 39);
font-size: 1.875rem;
font-weight: 700;
margin: 0 0 1rem;
}
.description {
color: rgb(75, 85, 99);
font-size: 0.875rem;
margin: 0.5rem 0;
}
.countdown {
color: rgb(99, 102, 241);
font-weight: 500;
}
</style>
</head>
<body>
<div class="card">
<h1 class="heading">Authentication Successful</h1>
<p class="description">
Your authentication was successful. This window will close in
<span class="countdown" id="countdown">3</span> seconds.
</p>
</div>
<script>
let secondsLeft = 3;
const countdownElement = document.getElementById('countdown');
const countdown = setInterval(() => {
secondsLeft--;
countdownElement.textContent = secondsLeft;
if (secondsLeft <= 0) {
clearInterval(countdown);
window.close();
}
}, 1000);
</script>
</body>
</html>
`);
/** Redirect to React success page */
const serverName = flowState.metadata?.action_name || `Action ${action_id}`;
const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`;
res.redirect(redirectUrl);
} catch (error) {
logger.error('Error in OAuth callback:', error);
await flowManager.failFlow(identifier, 'oauth', error);
res.status(500).send('Authentication failed. Please try again.');
res.redirect('/oauth/error?error=callback_failed');
}
});

View file

@ -1,10 +1,11 @@
const express = require('express');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, defaultSocialLogins, Constants } = require('librechat-data-provider');
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
const { getLdapConfig } = require('~/server/services/Config/ldap');
const { getProjectByName } = require('~/models/Project');
const { isEnabled } = require('~/server/utils');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config');
const router = express.Router();
const emailLoginEnabled =
@ -21,6 +22,7 @@ const publicSharedLinksEnabled =
router.get('/', async function (req, res) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG);
if (cachedStartupConfig) {
res.send(cachedStartupConfig);
@ -96,6 +98,18 @@ router.get('/', async function (req, res) {
bundlerURL: process.env.SANDPACK_BUNDLER_URL,
staticBundlerURL: process.env.SANDPACK_STATIC_BUNDLER_URL,
};
payload.mcpServers = {};
const config = await getCustomConfig();
if (config?.mcpServers != null) {
for (const serverName in config.mcpServers) {
const serverConfig = config.mcpServers[serverName];
payload.mcpServers[serverName] = {
customUserVars: serverConfig?.customUserVars || {},
};
}
}
/** @type {TCustomConfig['webSearch']} */
const webSearchConfig = req.app.locals.webSearch;
if (

View file

@ -65,8 +65,14 @@ router.post('/gen_title', async (req, res) => {
let title = await titleCache.get(key);
if (!title) {
await sleep(2500);
title = await titleCache.get(key);
// Retry every 1s for up to 20s
for (let i = 0; i < 20; i++) {
await sleep(1000);
title = await titleCache.get(key);
if (title) {
break;
}
}
}
if (title) {

View file

@ -2,8 +2,8 @@ const fs = require('fs');
const path = require('path');
const crypto = require('crypto');
const multer = require('multer');
const { sanitizeFilename } = require('@librechat/api');
const { fileConfig: defaultFileConfig, mergeFileConfig } = require('librechat-data-provider');
const { sanitizeFilename } = require('~/server/utils/handleText');
const { getCustomConfig } = require('~/server/services/Config');
const storage = multer.diskStorage({

View file

@ -0,0 +1,571 @@
/* eslint-disable no-unused-vars */
/* eslint-disable jest/no-done-callback */
const fs = require('fs');
const os = require('os');
const path = require('path');
const crypto = require('crypto');
const { createMulterInstance, storage, importFileFilter } = require('./multer');
// Mock only the config service that requires external dependencies
jest.mock('~/server/services/Config', () => ({
getCustomConfig: jest.fn(() =>
Promise.resolve({
fileConfig: {
endpoints: {
openAI: {
supportedMimeTypes: ['image/jpeg', 'image/png', 'application/pdf'],
},
default: {
supportedMimeTypes: ['image/jpeg', 'image/png', 'text/plain'],
},
},
serverFileSizeLimit: 10000000, // 10MB
},
}),
),
}));
describe('Multer Configuration', () => {
let tempDir;
let mockReq;
let mockFile;
beforeEach(() => {
// Create a temporary directory for each test
tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'multer-test-'));
mockReq = {
user: { id: 'test-user-123' },
app: {
locals: {
paths: {
uploads: tempDir,
},
},
},
body: {},
originalUrl: '/api/files/upload',
};
mockFile = {
originalname: 'test-file.jpg',
mimetype: 'image/jpeg',
size: 1024,
};
// Clear mocks
jest.clearAllMocks();
});
afterEach(() => {
// Clean up temporary directory
if (fs.existsSync(tempDir)) {
fs.rmSync(tempDir, { recursive: true, force: true });
}
});
describe('Storage Configuration', () => {
describe('destination function', () => {
it('should create the correct destination path', (done) => {
const cb = jest.fn((err, destination) => {
expect(err).toBeNull();
expect(destination).toBe(path.join(tempDir, 'temp', 'test-user-123'));
expect(fs.existsSync(destination)).toBe(true);
done();
});
storage.getDestination(mockReq, mockFile, cb);
});
it("should create directory recursively if it doesn't exist", (done) => {
const deepPath = path.join(tempDir, 'deep', 'nested', 'path');
mockReq.app.locals.paths.uploads = deepPath;
const cb = jest.fn((err, destination) => {
expect(err).toBeNull();
expect(destination).toBe(path.join(deepPath, 'temp', 'test-user-123'));
expect(fs.existsSync(destination)).toBe(true);
done();
});
storage.getDestination(mockReq, mockFile, cb);
});
});
describe('filename function', () => {
it('should generate a UUID for req.file_id', (done) => {
const cb = jest.fn((err, filename) => {
expect(err).toBeNull();
expect(mockReq.file_id).toBeDefined();
expect(mockReq.file_id).toMatch(
/^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i,
);
done();
});
storage.getFilename(mockReq, mockFile, cb);
});
it('should decode URI components in filename', (done) => {
const encodedFile = {
...mockFile,
originalname: encodeURIComponent('test file with spaces.jpg'),
};
const cb = jest.fn((err, filename) => {
expect(err).toBeNull();
expect(encodedFile.originalname).toBe('test file with spaces.jpg');
done();
});
storage.getFilename(mockReq, encodedFile, cb);
});
it('should call real sanitizeFilename with properly encoded filename', (done) => {
// Test with a properly URI-encoded filename that needs sanitization
const unsafeFile = {
...mockFile,
originalname: encodeURIComponent('test@#$%^&*()file with spaces!.jpg'),
};
const cb = jest.fn((err, filename) => {
expect(err).toBeNull();
// The actual sanitizeFilename should have cleaned this up after decoding
expect(filename).not.toContain('@');
expect(filename).not.toContain('#');
expect(filename).not.toContain('*');
expect(filename).not.toContain('!');
// Should still preserve dots and hyphens
expect(filename).toContain('.jpg');
done();
});
storage.getFilename(mockReq, unsafeFile, cb);
});
it('should handle very long filenames with actual crypto', (done) => {
const longFile = {
...mockFile,
originalname: 'a'.repeat(300) + '.jpg',
};
const cb = jest.fn((err, filename) => {
expect(err).toBeNull();
expect(filename.length).toBeLessThanOrEqual(255);
expect(filename).toMatch(/\.jpg$/); // Should still end with .jpg
// Should contain a hex suffix if truncated
if (filename.length === 255) {
expect(filename).toMatch(/-[a-f0-9]{6}\.jpg$/);
}
done();
});
storage.getFilename(mockReq, longFile, cb);
});
it('should generate unique file_id for each call', (done) => {
let firstFileId;
const firstCb = jest.fn((err, filename) => {
expect(err).toBeNull();
firstFileId = mockReq.file_id;
// Reset req for second call
delete mockReq.file_id;
const secondCb = jest.fn((err, filename) => {
expect(err).toBeNull();
expect(mockReq.file_id).toBeDefined();
expect(mockReq.file_id).not.toBe(firstFileId);
done();
});
storage.getFilename(mockReq, mockFile, secondCb);
});
storage.getFilename(mockReq, mockFile, firstCb);
});
});
});
describe('Import File Filter', () => {
it('should accept JSON files by mimetype', (done) => {
const jsonFile = {
...mockFile,
mimetype: 'application/json',
originalname: 'data.json',
};
const cb = jest.fn((err, result) => {
expect(err).toBeNull();
expect(result).toBe(true);
done();
});
importFileFilter(mockReq, jsonFile, cb);
});
it('should accept files with .json extension', (done) => {
const jsonFile = {
...mockFile,
mimetype: 'text/plain',
originalname: 'data.json',
};
const cb = jest.fn((err, result) => {
expect(err).toBeNull();
expect(result).toBe(true);
done();
});
importFileFilter(mockReq, jsonFile, cb);
});
it('should reject non-JSON files', (done) => {
const textFile = {
...mockFile,
mimetype: 'text/plain',
originalname: 'document.txt',
};
const cb = jest.fn((err, result) => {
expect(err).toBeInstanceOf(Error);
expect(err.message).toBe('Only JSON files are allowed');
expect(result).toBe(false);
done();
});
importFileFilter(mockReq, textFile, cb);
});
it('should handle files with uppercase .JSON extension', (done) => {
const jsonFile = {
...mockFile,
mimetype: 'text/plain',
originalname: 'DATA.JSON',
};
const cb = jest.fn((err, result) => {
expect(err).toBeNull();
expect(result).toBe(true);
done();
});
importFileFilter(mockReq, jsonFile, cb);
});
});
describe('File Filter with Real defaultFileConfig', () => {
it('should use real fileConfig.checkType for validation', async () => {
// Test with actual librechat-data-provider functions
const {
fileConfig,
imageMimeTypes,
applicationMimeTypes,
} = require('librechat-data-provider');
// Test that the real checkType function works with regex patterns
expect(fileConfig.checkType('image/jpeg', [imageMimeTypes])).toBe(true);
expect(fileConfig.checkType('video/mp4', [imageMimeTypes])).toBe(false);
expect(fileConfig.checkType('application/pdf', [applicationMimeTypes])).toBe(true);
expect(fileConfig.checkType('application/pdf', [])).toBe(false);
});
it('should handle audio files for speech-to-text endpoint with real config', async () => {
mockReq.originalUrl = '/api/speech/stt';
const multerInstance = await createMulterInstance();
expect(multerInstance).toBeDefined();
expect(typeof multerInstance.single).toBe('function');
});
it('should reject unsupported file types using real config', async () => {
// Mock defaultFileConfig for this specific test
const originalCheckType = require('librechat-data-provider').fileConfig.checkType;
const mockCheckType = jest.fn().mockReturnValue(false);
require('librechat-data-provider').fileConfig.checkType = mockCheckType;
try {
const multerInstance = await createMulterInstance();
expect(multerInstance).toBeDefined();
// Test the actual file filter behavior would reject unsupported files
expect(mockCheckType).toBeDefined();
} finally {
// Restore original function
require('librechat-data-provider').fileConfig.checkType = originalCheckType;
}
});
it('should use real mergeFileConfig function', async () => {
const { mergeFileConfig, mbToBytes } = require('librechat-data-provider');
// Test with actual merge function - note that it converts MB to bytes
const testConfig = {
serverFileSizeLimit: 5, // 5 MB
endpoints: {
custom: {
supportedMimeTypes: ['text/plain'],
},
},
};
const result = mergeFileConfig(testConfig);
// The function converts MB to bytes, so 5 MB becomes 5 * 1024 * 1024 bytes
expect(result.serverFileSizeLimit).toBe(mbToBytes(5));
expect(result.endpoints.custom.supportedMimeTypes).toBeDefined();
// Should still have the default endpoints
expect(result.endpoints.default).toBeDefined();
});
});
describe('createMulterInstance with Real Functions', () => {
it('should create a multer instance with correct configuration', async () => {
const multerInstance = await createMulterInstance();
expect(multerInstance).toBeDefined();
expect(typeof multerInstance.single).toBe('function');
expect(typeof multerInstance.array).toBe('function');
expect(typeof multerInstance.fields).toBe('function');
});
it('should use real config merging', async () => {
const { getCustomConfig } = require('~/server/services/Config');
const multerInstance = await createMulterInstance();
expect(getCustomConfig).toHaveBeenCalled();
expect(multerInstance).toBeDefined();
});
it('should create multer instance with expected interface', async () => {
const multerInstance = await createMulterInstance();
expect(multerInstance).toBeDefined();
expect(typeof multerInstance.single).toBe('function');
expect(typeof multerInstance.array).toBe('function');
expect(typeof multerInstance.fields).toBe('function');
});
});
describe('Real Crypto Integration', () => {
it('should use actual crypto.randomUUID()', (done) => {
// Spy on crypto.randomUUID to ensure it's called
const uuidSpy = jest.spyOn(crypto, 'randomUUID');
const cb = jest.fn((err, filename) => {
expect(err).toBeNull();
expect(uuidSpy).toHaveBeenCalled();
expect(mockReq.file_id).toMatch(
/^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i,
);
uuidSpy.mockRestore();
done();
});
storage.getFilename(mockReq, mockFile, cb);
});
it('should generate different UUIDs on subsequent calls', (done) => {
const uuids = [];
let callCount = 0;
const totalCalls = 5;
const cb = jest.fn((err, filename) => {
expect(err).toBeNull();
uuids.push(mockReq.file_id);
callCount++;
if (callCount === totalCalls) {
// Check that all UUIDs are unique
const uniqueUuids = new Set(uuids);
expect(uniqueUuids.size).toBe(totalCalls);
done();
} else {
// Reset for next call
delete mockReq.file_id;
storage.getFilename(mockReq, mockFile, cb);
}
});
// Start the chain
storage.getFilename(mockReq, mockFile, cb);
});
it('should generate cryptographically secure UUIDs', (done) => {
const generatedUuids = new Set();
let callCount = 0;
const totalCalls = 10;
const cb = jest.fn((err, filename) => {
expect(err).toBeNull();
// Verify UUID format and uniqueness
expect(mockReq.file_id).toMatch(
/^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i,
);
generatedUuids.add(mockReq.file_id);
callCount++;
if (callCount === totalCalls) {
// All UUIDs should be unique
expect(generatedUuids.size).toBe(totalCalls);
done();
} else {
// Reset for next call
delete mockReq.file_id;
storage.getFilename(mockReq, mockFile, cb);
}
});
// Start the chain
storage.getFilename(mockReq, mockFile, cb);
});
});
describe('Error Handling', () => {
it('should handle CVE-2024-28870: empty field name DoS vulnerability', async () => {
// Test for the CVE where empty field name could cause unhandled exception
const multerInstance = await createMulterInstance();
// Create a mock request with empty field name (the vulnerability scenario)
const mockReqWithEmptyField = {
...mockReq,
headers: {
'content-type': 'multipart/form-data',
},
};
const mockRes = {
status: jest.fn().mockReturnThis(),
json: jest.fn(),
end: jest.fn(),
};
// This should not crash or throw unhandled exceptions
const uploadMiddleware = multerInstance.single(''); // Empty field name
const mockNext = jest.fn((err) => {
// If there's an error, it should be handled gracefully, not crash
if (err) {
expect(err).toBeInstanceOf(Error);
// The error should be handled, not crash the process
}
});
// This should complete without crashing the process
expect(() => {
uploadMiddleware(mockReqWithEmptyField, mockRes, mockNext);
}).not.toThrow();
});
it('should handle file system errors when directory creation fails', (done) => {
// Test with a non-existent parent directory to simulate fs issues
const invalidPath = '/nonexistent/path/that/should/not/exist';
mockReq.app.locals.paths.uploads = invalidPath;
try {
// Call getDestination which should fail due to permission/path issues
storage.getDestination(mockReq, mockFile, (err, destination) => {
// If callback is reached, we didn't get the expected error
done(new Error('Expected mkdirSync to throw an error but callback was called'));
});
// If we get here without throwing, something unexpected happened
done(new Error('Expected mkdirSync to throw an error but no error was thrown'));
} catch (error) {
// This is the expected behavior - mkdirSync throws synchronously for invalid paths
expect(error.code).toBe('EACCES');
done();
}
});
it('should handle malformed filenames with real sanitization', (done) => {
const malformedFile = {
...mockFile,
originalname: null, // This should be handled gracefully
};
const cb = jest.fn((err, filename) => {
// The function should handle this gracefully
expect(typeof err === 'object' || err === null).toBe(true);
done();
});
try {
storage.getFilename(mockReq, malformedFile, cb);
} catch (error) {
// If it throws, that's also acceptable behavior
done();
}
});
it('should handle edge cases in filename sanitization', (done) => {
const edgeCaseFiles = [
{ originalname: '', expected: /_/ },
{ originalname: '.hidden', expected: /^_\.hidden/ },
{ originalname: '../../../etc/passwd', expected: /passwd/ },
{ originalname: 'file\x00name.txt', expected: /file_name\.txt/ },
];
let testCount = 0;
const testNextFile = (fileData) => {
const fileToTest = { ...mockFile, originalname: fileData.originalname };
const cb = jest.fn((err, filename) => {
expect(err).toBeNull();
expect(filename).toMatch(fileData.expected);
testCount++;
if (testCount === edgeCaseFiles.length) {
done();
} else {
testNextFile(edgeCaseFiles[testCount]);
}
});
storage.getFilename(mockReq, fileToTest, cb);
};
testNextFile(edgeCaseFiles[0]);
});
});
describe('Real Configuration Testing', () => {
it('should handle missing custom config gracefully with real mergeFileConfig', async () => {
const { getCustomConfig } = require('~/server/services/Config');
// Mock getCustomConfig to return undefined
getCustomConfig.mockResolvedValueOnce(undefined);
const multerInstance = await createMulterInstance();
expect(multerInstance).toBeDefined();
expect(typeof multerInstance.single).toBe('function');
});
it('should properly integrate real fileConfig with custom endpoints', async () => {
const { getCustomConfig } = require('~/server/services/Config');
// Mock a custom config with additional endpoints
getCustomConfig.mockResolvedValueOnce({
fileConfig: {
endpoints: {
anthropic: {
supportedMimeTypes: ['text/plain', 'image/png'],
},
},
serverFileSizeLimit: 20, // 20 MB
},
});
const multerInstance = await createMulterInstance();
expect(multerInstance).toBeDefined();
// Verify that getCustomConfig was called (we can't spy on the actual merge function easily)
expect(getCustomConfig).toHaveBeenCalled();
});
});
});

View file

@ -4,6 +4,7 @@ const tokenizer = require('./tokenizer');
const endpoints = require('./endpoints');
const staticRoute = require('./static');
const messages = require('./messages');
const memories = require('./memories');
const presets = require('./presets');
const prompts = require('./prompts');
const balance = require('./balance');
@ -26,6 +27,7 @@ const edit = require('./edit');
const keys = require('./keys');
const user = require('./user');
const ask = require('./ask');
const mcp = require('./mcp');
module.exports = {
ask,
@ -51,9 +53,11 @@ module.exports = {
presets,
balance,
messages,
memories,
endpoints,
tokenizer,
assistants,
categories,
staticRoute,
mcp,
};

205
api/server/routes/mcp.js Normal file
View file

@ -0,0 +1,205 @@
const { Router } = require('express');
const { MCPOAuthHandler } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { requireJwtAuth } = require('~/server/middleware');
const { getFlowStateManager } = require('~/config');
const { getLogStores } = require('~/cache');
const router = Router();
/**
* Initiate OAuth flow
* This endpoint is called when the user clicks the auth link in the UI
*/
router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
try {
const { serverName } = req.params;
const { userId, flowId } = req.query;
const user = req.user;
// Verify the userId matches the authenticated user
if (userId !== user.id) {
return res.status(403).json({ error: 'User mismatch' });
}
logger.debug('[MCP OAuth] Initiate request', { serverName, userId, flowId });
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
/** Flow state to retrieve OAuth config */
const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth');
if (!flowState) {
logger.error('[MCP OAuth] Flow state not found', { flowId });
return res.status(404).json({ error: 'Flow not found' });
}
const { serverUrl, oauth: oauthConfig } = flowState.metadata || {};
if (!serverUrl || !oauthConfig) {
logger.error('[MCP OAuth] Missing server URL or OAuth config in flow state');
return res.status(400).json({ error: 'Invalid flow state' });
}
const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow(
serverName,
serverUrl,
userId,
oauthConfig,
);
logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl });
// Redirect user to the authorization URL
res.redirect(authorizationUrl);
} catch (error) {
logger.error('[MCP OAuth] Failed to initiate OAuth', error);
res.status(500).json({ error: 'Failed to initiate OAuth' });
}
});
/**
* OAuth callback handler
* This handles the OAuth callback after the user has authorized the application
*/
router.get('/:serverName/oauth/callback', async (req, res) => {
try {
const { serverName } = req.params;
const { code, state, error: oauthError } = req.query;
logger.debug('[MCP OAuth] Callback received', {
serverName,
code: code ? 'present' : 'missing',
state,
error: oauthError,
});
if (oauthError) {
logger.error('[MCP OAuth] OAuth error received', { error: oauthError });
return res.redirect(`/oauth/error?error=${encodeURIComponent(String(oauthError))}`);
}
if (!code || typeof code !== 'string') {
logger.error('[MCP OAuth] Missing or invalid code');
return res.redirect('/oauth/error?error=missing_code');
}
if (!state || typeof state !== 'string') {
logger.error('[MCP OAuth] Missing or invalid state');
return res.redirect('/oauth/error?error=missing_state');
}
// Extract flow ID from state
const flowId = state;
logger.debug('[MCP OAuth] Using flow ID from state', { flowId });
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
logger.debug('[MCP OAuth] Getting flow state for flowId: ' + flowId);
const flowState = await MCPOAuthHandler.getFlowState(flowId, flowManager);
if (!flowState) {
logger.error('[MCP OAuth] Flow state not found for flowId:', flowId);
return res.redirect('/oauth/error?error=invalid_state');
}
logger.debug('[MCP OAuth] Flow state details', {
serverName: flowState.serverName,
userId: flowState.userId,
hasMetadata: !!flowState.metadata,
hasClientInfo: !!flowState.clientInfo,
hasCodeVerifier: !!flowState.codeVerifier,
});
// Complete the OAuth flow
logger.debug('[MCP OAuth] Completing OAuth flow');
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager);
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
// For system-level OAuth, we need to store the tokens and retry the connection
if (flowState.userId === 'system') {
logger.debug(`[MCP OAuth] System-level OAuth completed for ${serverName}`);
}
/** ID of the flow that the tool/connection is waiting for */
const toolFlowId = flowState.metadata?.toolFlowId;
if (toolFlowId) {
logger.debug('[MCP OAuth] Completing tool flow', { toolFlowId });
await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens);
}
/** Redirect to success page with flowId and serverName */
const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`;
res.redirect(redirectUrl);
} catch (error) {
logger.error('[MCP OAuth] OAuth callback error', error);
res.redirect('/oauth/error?error=callback_failed');
}
});
/**
* Get OAuth tokens for a completed flow
* This is primarily for user-level OAuth flows
*/
router.get('/oauth/tokens/:flowId', requireJwtAuth, async (req, res) => {
try {
const { flowId } = req.params;
const user = req.user;
if (!user?.id) {
return res.status(401).json({ error: 'User not authenticated' });
}
// Allow system flows or user-owned flows
if (!flowId.startsWith(`${user.id}:`) && !flowId.startsWith('system:')) {
return res.status(403).json({ error: 'Access denied' });
}
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth');
if (!flowState) {
return res.status(404).json({ error: 'Flow not found' });
}
if (flowState.status !== 'COMPLETED') {
return res.status(400).json({ error: 'Flow not completed' });
}
res.json({ tokens: flowState.result });
} catch (error) {
logger.error('[MCP OAuth] Failed to get tokens', error);
res.status(500).json({ error: 'Failed to get tokens' });
}
});
/**
* Check OAuth flow status
* This endpoint can be used to poll the status of an OAuth flow
*/
router.get('/oauth/status/:flowId', async (req, res) => {
try {
const { flowId } = req.params;
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth');
if (!flowState) {
return res.status(404).json({ error: 'Flow not found' });
}
res.json({
status: flowState.status,
completed: flowState.status === 'COMPLETED',
failed: flowState.status === 'FAILED',
error: flowState.error,
});
} catch (error) {
logger.error('[MCP OAuth] Failed to get flow status', error);
res.status(500).json({ error: 'Failed to get flow status' });
}
});
module.exports = router;

View file

@ -0,0 +1,231 @@
const express = require('express');
const { Tokenizer } = require('@librechat/api');
const { PermissionTypes, Permissions } = require('librechat-data-provider');
const {
getAllUserMemories,
toggleUserMemories,
createMemory,
setMemory,
deleteMemory,
} = require('~/models');
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
const router = express.Router();
const checkMemoryRead = generateCheckAccess(PermissionTypes.MEMORIES, [
Permissions.USE,
Permissions.READ,
]);
const checkMemoryCreate = generateCheckAccess(PermissionTypes.MEMORIES, [
Permissions.USE,
Permissions.CREATE,
]);
const checkMemoryUpdate = generateCheckAccess(PermissionTypes.MEMORIES, [
Permissions.USE,
Permissions.UPDATE,
]);
const checkMemoryDelete = generateCheckAccess(PermissionTypes.MEMORIES, [
Permissions.USE,
Permissions.UPDATE,
]);
const checkMemoryOptOut = generateCheckAccess(PermissionTypes.MEMORIES, [
Permissions.USE,
Permissions.OPT_OUT,
]);
router.use(requireJwtAuth);
/**
* GET /memories
* Returns all memories for the authenticated user, sorted by updated_at (newest first).
* Also includes memory usage percentage based on token limit.
*/
router.get('/', checkMemoryRead, async (req, res) => {
try {
const memories = await getAllUserMemories(req.user.id);
const sortedMemories = memories.sort(
(a, b) => new Date(b.updated_at).getTime() - new Date(a.updated_at).getTime(),
);
const totalTokens = memories.reduce((sum, memory) => {
return sum + (memory.tokenCount || 0);
}, 0);
const memoryConfig = req.app.locals?.memory;
const tokenLimit = memoryConfig?.tokenLimit;
let usagePercentage = null;
if (tokenLimit && tokenLimit > 0) {
usagePercentage = Math.min(100, Math.round((totalTokens / tokenLimit) * 100));
}
res.json({
memories: sortedMemories,
totalTokens,
tokenLimit: tokenLimit || null,
usagePercentage,
});
} catch (error) {
res.status(500).json({ error: error.message });
}
});
/**
* POST /memories
* Creates a new memory entry for the authenticated user.
* Body: { key: string, value: string }
* Returns 201 and { created: true, memory: <createdDoc> } when successful.
*/
router.post('/', checkMemoryCreate, async (req, res) => {
const { key, value } = req.body;
if (typeof key !== 'string' || key.trim() === '') {
return res.status(400).json({ error: 'Key is required and must be a non-empty string.' });
}
if (typeof value !== 'string' || value.trim() === '') {
return res.status(400).json({ error: 'Value is required and must be a non-empty string.' });
}
try {
const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');
const memories = await getAllUserMemories(req.user.id);
// Check token limit
const memoryConfig = req.app.locals?.memory;
const tokenLimit = memoryConfig?.tokenLimit;
if (tokenLimit) {
const currentTotalTokens = memories.reduce(
(sum, memory) => sum + (memory.tokenCount || 0),
0,
);
if (currentTotalTokens + tokenCount > tokenLimit) {
return res.status(400).json({
error: `Adding this memory would exceed the token limit of ${tokenLimit}. Current usage: ${currentTotalTokens} tokens.`,
});
}
}
const result = await createMemory({
userId: req.user.id,
key: key.trim(),
value: value.trim(),
tokenCount,
});
if (!result.ok) {
return res.status(500).json({ error: 'Failed to create memory.' });
}
const updatedMemories = await getAllUserMemories(req.user.id);
const newMemory = updatedMemories.find((m) => m.key === key.trim());
res.status(201).json({ created: true, memory: newMemory });
} catch (error) {
if (error.message && error.message.includes('already exists')) {
return res.status(409).json({ error: 'Memory with this key already exists.' });
}
res.status(500).json({ error: error.message });
}
});
/**
* PATCH /memories/preferences
* Updates the user's memory preferences (e.g., enabling/disabling memories).
* Body: { memories: boolean }
* Returns 200 and { updated: true, preferences: { memories: boolean } } when successful.
*/
router.patch('/preferences', checkMemoryOptOut, async (req, res) => {
const { memories } = req.body;
if (typeof memories !== 'boolean') {
return res.status(400).json({ error: 'memories must be a boolean value.' });
}
try {
const updatedUser = await toggleUserMemories(req.user.id, memories);
if (!updatedUser) {
return res.status(404).json({ error: 'User not found.' });
}
res.json({
updated: true,
preferences: {
memories: updatedUser.personalization?.memories ?? true,
},
});
} catch (error) {
res.status(500).json({ error: error.message });
}
});
/**
* PATCH /memories/:key
* Updates the value of an existing memory entry for the authenticated user.
* Body: { value: string }
* Returns 200 and { updated: true, memory: <updatedDoc> } when successful.
*/
router.patch('/:key', checkMemoryUpdate, async (req, res) => {
const { key } = req.params;
const { value } = req.body || {};
if (typeof value !== 'string' || value.trim() === '') {
return res.status(400).json({ error: 'Value is required and must be a non-empty string.' });
}
try {
const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');
const memories = await getAllUserMemories(req.user.id);
const existingMemory = memories.find((m) => m.key === key);
if (!existingMemory) {
return res.status(404).json({ error: 'Memory not found.' });
}
const result = await setMemory({
userId: req.user.id,
key,
value,
tokenCount,
});
if (!result.ok) {
return res.status(500).json({ error: 'Failed to update memory.' });
}
const updatedMemories = await getAllUserMemories(req.user.id);
const updatedMemory = updatedMemories.find((m) => m.key === key);
res.json({ updated: true, memory: updatedMemory });
} catch (error) {
res.status(500).json({ error: error.message });
}
});
/**
* DELETE /memories/:key
* Deletes a memory entry for the authenticated user.
* Returns 200 and { deleted: true } when successful.
*/
router.delete('/:key', checkMemoryDelete, async (req, res) => {
const { key } = req.params;
try {
const result = await deleteMemory({ userId: req.user.id, key });
if (!result.ok) {
return res.status(404).json({ error: 'Memory not found.' });
}
res.json({ deleted: true });
} catch (error) {
res.status(500).json({ error: error.message });
}
});
module.exports = router;

View file

@ -47,7 +47,9 @@ const oauthHandler = async (req, res) => {
router.get('/error', (req, res) => {
// A single error message is pushed by passport when authentication fails.
logger.error('Error in OAuth authentication:', { message: req.session.messages.pop() });
logger.error('Error in OAuth authentication:', {
message: req.session?.messages?.pop() || 'Unknown error',
});
// Redirect to login page with auth_failed parameter to prevent infinite redirect loops
res.redirect(`${domains.client}/login?redirect=false`);

View file

@ -1,6 +1,7 @@
const express = require('express');
const {
promptPermissionsSchema,
memoryPermissionsSchema,
agentPermissionsSchema,
PermissionTypes,
roleDefaults,
@ -118,4 +119,43 @@ router.put('/:roleName/agents', checkAdmin, async (req, res) => {
}
});
/**
* PUT /api/roles/:roleName/memories
* Update memory permissions for a specific role
*/
router.put('/:roleName/memories', checkAdmin, async (req, res) => {
const { roleName: _r } = req.params;
// TODO: TEMP, use a better parsing for roleName
const roleName = _r.toUpperCase();
/** @type {TRole['permissions']['MEMORIES']} */
const updates = req.body;
try {
const parsedUpdates = memoryPermissionsSchema.partial().parse(updates);
const role = await getRoleByName(roleName);
if (!role) {
return res.status(404).send({ message: 'Role not found' });
}
const currentPermissions =
role.permissions?.[PermissionTypes.MEMORIES] || role[PermissionTypes.MEMORIES] || {};
const mergedUpdates = {
permissions: {
...role.permissions,
[PermissionTypes.MEMORIES]: {
...currentPermissions,
...parsedUpdates,
},
},
};
const updatedRole = await updateRoleByName(roleName, mergedUpdates);
res.status(200).send(updatedRole);
} catch (error) {
return res.status(400).send({ message: 'Invalid memory permissions.', error: error.errors });
}
});
module.exports = router;

View file

@ -1,15 +1,15 @@
const express = require('express');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
getSharedLink,
getSharedMessages,
createSharedLink,
updateSharedLink,
getSharedLinks,
deleteSharedLink,
} = require('~/models/Share');
getSharedLinks,
getSharedLink,
} = require('~/models');
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
const { isEnabled } = require('~/server/utils');
const router = express.Router();
/**
@ -35,6 +35,7 @@ if (allowSharedLinks) {
res.status(404).end();
}
} catch (error) {
logger.error('Error getting shared messages:', error);
res.status(500).json({ message: 'Error getting shared messages' });
}
},
@ -54,9 +55,7 @@ router.get('/', requireJwtAuth, async (req, res) => {
sortDirection: ['asc', 'desc'].includes(req.query.sortDirection)
? req.query.sortDirection
: 'desc',
search: req.query.search
? decodeURIComponent(req.query.search.trim())
: undefined,
search: req.query.search ? decodeURIComponent(req.query.search.trim()) : undefined,
};
const result = await getSharedLinks(
@ -75,7 +74,7 @@ router.get('/', requireJwtAuth, async (req, res) => {
hasNextPage: result.hasNextPage,
});
} catch (error) {
console.error('Error getting shared links:', error);
logger.error('Error getting shared links:', error);
res.status(500).json({
message: 'Error getting shared links',
error: error.message,
@ -93,6 +92,7 @@ router.get('/link/:conversationId', requireJwtAuth, async (req, res) => {
conversationId: req.params.conversationId,
});
} catch (error) {
logger.error('Error getting shared link:', error);
res.status(500).json({ message: 'Error getting shared link' });
}
});
@ -106,6 +106,7 @@ router.post('/:conversationId', requireJwtAuth, async (req, res) => {
res.status(404).end();
}
} catch (error) {
logger.error('Error creating shared link:', error);
res.status(500).json({ message: 'Error creating shared link' });
}
});
@ -119,6 +120,7 @@ router.patch('/:shareId', requireJwtAuth, async (req, res) => {
res.status(404).end();
}
} catch (error) {
logger.error('Error updating shared link:', error);
res.status(500).json({ message: 'Error updating shared link' });
}
});
@ -133,7 +135,8 @@ router.delete('/:shareId', requireJwtAuth, async (req, res) => {
return res.status(200).json(result);
} catch (error) {
return res.status(400).json({ message: error.message });
logger.error('Error deleting shared link:', error);
return res.status(400).json({ message: 'Error deleting shared link' });
}
});