mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-01-10 20:48:54 +01:00
Merge branch 'main' into refactor/package-auth
This commit is contained in:
commit
02b9c9d447
340 changed files with 18559 additions and 14872 deletions
|
|
@ -10,6 +10,7 @@ const {
|
|||
validateVisionModel,
|
||||
} = require('librechat-data-provider');
|
||||
const { SplitStreamHandler: _Handler } = require('@librechat/agents');
|
||||
const { Tokenizer, createFetch, createStreamEventHandlers } = require('@librechat/api');
|
||||
const {
|
||||
truncateText,
|
||||
formatMessage,
|
||||
|
|
@ -26,8 +27,6 @@ const {
|
|||
const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { createFetch, createStreamEventHandlers } = require('./generators');
|
||||
const Tokenizer = require('~/server/services/Tokenizer');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { logger } = require('~/config');
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ const { Keyv } = require('keyv');
|
|||
const crypto = require('crypto');
|
||||
const { CohereClient } = require('cohere-ai');
|
||||
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
|
||||
const { constructAzureURL, genAzureChatCompletion } = require('@librechat/api');
|
||||
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
|
||||
const {
|
||||
ImageDetail,
|
||||
|
|
@ -10,9 +11,9 @@ const {
|
|||
CohereConstants,
|
||||
mapModelToAzureConfig,
|
||||
} = require('librechat-data-provider');
|
||||
const { extractBaseURL, constructAzureURL, genAzureChatCompletion } = require('~/utils');
|
||||
const { createContextHandlers } = require('./prompts');
|
||||
const { createCoherePayload } = require('./llm');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const { google } = require('googleapis');
|
||||
const { Tokenizer } = require('@librechat/api');
|
||||
const { concat } = require('@langchain/core/utils/stream');
|
||||
const { ChatVertexAI } = require('@langchain/google-vertexai');
|
||||
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
|
||||
|
|
@ -19,7 +20,6 @@ const {
|
|||
} = require('librechat-data-provider');
|
||||
const { getSafetySettings } = require('~/server/services/Endpoints/google/llm');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images');
|
||||
const Tokenizer = require('~/server/services/Tokenizer');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { sleep } = require('~/server/utils');
|
||||
|
|
@ -34,7 +34,8 @@ const BaseClient = require('./BaseClient');
|
|||
|
||||
const loc = process.env.GOOGLE_LOC || 'us-central1';
|
||||
const publisher = 'google';
|
||||
const endpointPrefix = `${loc}-aiplatform.googleapis.com`;
|
||||
const endpointPrefix =
|
||||
loc === 'global' ? 'aiplatform.googleapis.com' : `${loc}-aiplatform.googleapis.com`;
|
||||
|
||||
const settings = endpointSettings[EModelEndpoint.google];
|
||||
const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/;
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
const { z } = require('zod');
|
||||
const axios = require('axios');
|
||||
const { Ollama } = require('ollama');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const { deriveBaseURL, logAxiosError } = require('~/utils');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
const { deriveBaseURL } = require('~/utils');
|
||||
|
||||
const ollamaPayloadSchema = z.object({
|
||||
mirostat: z.number().optional(),
|
||||
|
|
@ -67,7 +68,7 @@ class OllamaClient {
|
|||
return models;
|
||||
} catch (error) {
|
||||
const logMessage =
|
||||
'Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn\'t start with `ollama` (case-insensitive).';
|
||||
"Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn't start with `ollama` (case-insensitive).";
|
||||
logAxiosError({ message: logMessage, error });
|
||||
return [];
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,14 @@
|
|||
const { OllamaClient } = require('./OllamaClient');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { SplitStreamHandler, CustomOpenAIClient: OpenAI } = require('@librechat/agents');
|
||||
const {
|
||||
isEnabled,
|
||||
Tokenizer,
|
||||
createFetch,
|
||||
constructAzureURL,
|
||||
genAzureChatCompletion,
|
||||
createStreamEventHandlers,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Constants,
|
||||
ImageDetail,
|
||||
|
|
@ -16,13 +24,6 @@ const {
|
|||
validateVisionModel,
|
||||
mapModelToAzureConfig,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
extractBaseURL,
|
||||
constructAzureURL,
|
||||
getModelMaxTokens,
|
||||
genAzureChatCompletion,
|
||||
getModelMaxOutputTokens,
|
||||
} = require('~/utils');
|
||||
const {
|
||||
truncateText,
|
||||
formatMessage,
|
||||
|
|
@ -30,10 +31,9 @@ const {
|
|||
titleInstruction,
|
||||
createContextHandlers,
|
||||
} = require('./prompts');
|
||||
const { extractBaseURL, getModelMaxTokens, getModelMaxOutputTokens } = require('~/utils');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { createFetch, createStreamEventHandlers } = require('./generators');
|
||||
const { addSpaceIfNeeded, isEnabled, sleep } = require('~/server/utils');
|
||||
const Tokenizer = require('~/server/services/Tokenizer');
|
||||
const { addSpaceIfNeeded, sleep } = require('~/server/utils');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { handleOpenAIErrors } = require('./tools/util');
|
||||
const { createLLM, RunManager } = require('./llm');
|
||||
|
|
|
|||
|
|
@ -1,71 +0,0 @@
|
|||
const fetch = require('node-fetch');
|
||||
const { GraphEvents } = require('@librechat/agents');
|
||||
const { logger, sendEvent } = require('~/config');
|
||||
const { sleep } = require('~/server/utils');
|
||||
|
||||
/**
|
||||
* Makes a function to make HTTP request and logs the process.
|
||||
* @param {Object} params
|
||||
* @param {boolean} [params.directEndpoint] - Whether to use a direct endpoint.
|
||||
* @param {string} [params.reverseProxyUrl] - The reverse proxy URL to use for the request.
|
||||
* @returns {Promise<Response>} - A promise that resolves to the response of the fetch request.
|
||||
*/
|
||||
function createFetch({ directEndpoint = false, reverseProxyUrl = '' }) {
|
||||
/**
|
||||
* Makes an HTTP request and logs the process.
|
||||
* @param {RequestInfo} url - The URL to make the request to. Can be a string or a Request object.
|
||||
* @param {RequestInit} [init] - Optional init options for the request.
|
||||
* @returns {Promise<Response>} - A promise that resolves to the response of the fetch request.
|
||||
*/
|
||||
return async (_url, init) => {
|
||||
let url = _url;
|
||||
if (directEndpoint) {
|
||||
url = reverseProxyUrl;
|
||||
}
|
||||
logger.debug(`Making request to ${url}`);
|
||||
if (typeof Bun !== 'undefined') {
|
||||
return await fetch(url, init);
|
||||
}
|
||||
return await fetch(url, init);
|
||||
};
|
||||
}
|
||||
|
||||
// Add this at the module level outside the class
|
||||
/**
|
||||
* Creates event handlers for stream events that don't capture client references
|
||||
* @param {Object} res - The response object to send events to
|
||||
* @returns {Object} Object containing handler functions
|
||||
*/
|
||||
function createStreamEventHandlers(res) {
|
||||
return {
|
||||
[GraphEvents.ON_RUN_STEP]: (event) => {
|
||||
if (res) {
|
||||
sendEvent(res, event);
|
||||
}
|
||||
},
|
||||
[GraphEvents.ON_MESSAGE_DELTA]: (event) => {
|
||||
if (res) {
|
||||
sendEvent(res, event);
|
||||
}
|
||||
},
|
||||
[GraphEvents.ON_REASONING_DELTA]: (event) => {
|
||||
if (res) {
|
||||
sendEvent(res, event);
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function createHandleLLMNewToken(streamRate) {
|
||||
return async () => {
|
||||
if (streamRate) {
|
||||
await sleep(streamRate);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
createFetch,
|
||||
createHandleLLMNewToken,
|
||||
createStreamEventHandlers,
|
||||
};
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
const { ChatOpenAI } = require('@langchain/openai');
|
||||
const { sanitizeModelName, constructAzureURL } = require('~/utils');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { isEnabled, sanitizeModelName, constructAzureURL } = require('@librechat/api');
|
||||
|
||||
/**
|
||||
* Creates a new instance of a language model (LLM) for chat interactions.
|
||||
|
|
|
|||
|
|
@ -33,7 +33,9 @@ jest.mock('~/models', () => ({
|
|||
const { getConvo, saveConvo } = require('~/models');
|
||||
|
||||
jest.mock('@librechat/agents', () => {
|
||||
const { Providers } = jest.requireActual('@librechat/agents');
|
||||
return {
|
||||
Providers,
|
||||
ChatOpenAI: jest.fn().mockImplementation(() => {
|
||||
return {};
|
||||
}),
|
||||
|
|
|
|||
|
|
@ -1,184 +0,0 @@
|
|||
require('dotenv').config();
|
||||
const fs = require('fs');
|
||||
const { z } = require('zod');
|
||||
const path = require('path');
|
||||
const yaml = require('js-yaml');
|
||||
const { createOpenAPIChain } = require('langchain/chains');
|
||||
const { DynamicStructuredTool } = require('@langchain/core/tools');
|
||||
const { ChatPromptTemplate, HumanMessagePromptTemplate } = require('@langchain/core/prompts');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
function addLinePrefix(text, prefix = '// ') {
|
||||
return text
|
||||
.split('\n')
|
||||
.map((line) => prefix + line)
|
||||
.join('\n');
|
||||
}
|
||||
|
||||
function createPrompt(name, functions) {
|
||||
const prefix = `// The ${name} tool has the following functions. Determine the desired or most optimal function for the user's query:`;
|
||||
const functionDescriptions = functions
|
||||
.map((func) => `// - ${func.name}: ${func.description}`)
|
||||
.join('\n');
|
||||
return `${prefix}\n${functionDescriptions}
|
||||
// You are an expert manager and scrum master. You must provide a detailed intent to better execute the function.
|
||||
// Always format as such: {{"func": "function_name", "intent": "intent and expected result"}}`;
|
||||
}
|
||||
|
||||
const AuthBearer = z
|
||||
.object({
|
||||
type: z.string().includes('service_http'),
|
||||
authorization_type: z.string().includes('bearer'),
|
||||
verification_tokens: z.object({
|
||||
openai: z.string(),
|
||||
}),
|
||||
})
|
||||
.catch(() => false);
|
||||
|
||||
const AuthDefinition = z
|
||||
.object({
|
||||
type: z.string(),
|
||||
authorization_type: z.string(),
|
||||
verification_tokens: z.object({
|
||||
openai: z.string(),
|
||||
}),
|
||||
})
|
||||
.catch(() => false);
|
||||
|
||||
async function readSpecFile(filePath) {
|
||||
try {
|
||||
const fileContents = await fs.promises.readFile(filePath, 'utf8');
|
||||
if (path.extname(filePath) === '.json') {
|
||||
return JSON.parse(fileContents);
|
||||
}
|
||||
return yaml.load(fileContents);
|
||||
} catch (e) {
|
||||
logger.error('[readSpecFile] error', e);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
async function getSpec(url) {
|
||||
const RegularUrl = z
|
||||
.string()
|
||||
.url()
|
||||
.catch(() => false);
|
||||
|
||||
if (RegularUrl.parse(url) && path.extname(url) === '.json') {
|
||||
const response = await fetch(url);
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
const ValidSpecPath = z
|
||||
.string()
|
||||
.url()
|
||||
.catch(async () => {
|
||||
const spec = path.join(__dirname, '..', '.well-known', 'openapi', url);
|
||||
if (!fs.existsSync(spec)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return await readSpecFile(spec);
|
||||
});
|
||||
|
||||
return ValidSpecPath.parse(url);
|
||||
}
|
||||
|
||||
async function createOpenAPIPlugin({ data, llm, user, message, memory, signal }) {
|
||||
let spec;
|
||||
try {
|
||||
spec = await getSpec(data.api.url);
|
||||
} catch (error) {
|
||||
logger.error('[createOpenAPIPlugin] getSpec error', error);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!spec) {
|
||||
logger.warn('[createOpenAPIPlugin] No spec found');
|
||||
return null;
|
||||
}
|
||||
|
||||
const headers = {};
|
||||
const { auth, name_for_model, description_for_model, description_for_human } = data;
|
||||
if (auth && AuthDefinition.parse(auth)) {
|
||||
logger.debug('[createOpenAPIPlugin] auth detected', auth);
|
||||
const { openai } = auth.verification_tokens;
|
||||
if (AuthBearer.parse(auth)) {
|
||||
headers.authorization = `Bearer ${openai}`;
|
||||
logger.debug('[createOpenAPIPlugin] added auth bearer', headers);
|
||||
}
|
||||
}
|
||||
|
||||
const chainOptions = { llm };
|
||||
|
||||
if (data.headers && data.headers['librechat_user_id']) {
|
||||
logger.debug('[createOpenAPIPlugin] id detected', headers);
|
||||
headers[data.headers['librechat_user_id']] = user;
|
||||
}
|
||||
|
||||
if (Object.keys(headers).length > 0) {
|
||||
logger.debug('[createOpenAPIPlugin] headers detected', headers);
|
||||
chainOptions.headers = headers;
|
||||
}
|
||||
|
||||
if (data.params) {
|
||||
logger.debug('[createOpenAPIPlugin] params detected', data.params);
|
||||
chainOptions.params = data.params;
|
||||
}
|
||||
|
||||
let history = '';
|
||||
if (memory) {
|
||||
logger.debug('[createOpenAPIPlugin] openAPI chain: memory detected', memory);
|
||||
const { history: chat_history } = await memory.loadMemoryVariables({});
|
||||
history = chat_history?.length > 0 ? `\n\n## Chat History:\n${chat_history}\n` : '';
|
||||
}
|
||||
|
||||
chainOptions.prompt = ChatPromptTemplate.fromMessages([
|
||||
HumanMessagePromptTemplate.fromTemplate(
|
||||
`# Use the provided API's to respond to this query:\n\n{query}\n\n## Instructions:\n${addLinePrefix(
|
||||
description_for_model,
|
||||
)}${history}`,
|
||||
),
|
||||
]);
|
||||
|
||||
const chain = await createOpenAPIChain(spec, chainOptions);
|
||||
|
||||
const { functions } = chain.chains[0].lc_kwargs.llmKwargs;
|
||||
|
||||
return new DynamicStructuredTool({
|
||||
name: name_for_model,
|
||||
description_for_model: `${addLinePrefix(description_for_human)}${createPrompt(
|
||||
name_for_model,
|
||||
functions,
|
||||
)}`,
|
||||
description: `${description_for_human}`,
|
||||
schema: z.object({
|
||||
func: z
|
||||
.string()
|
||||
.describe(
|
||||
`The function to invoke. The functions available are: ${functions
|
||||
.map((func) => func.name)
|
||||
.join(', ')}`,
|
||||
),
|
||||
intent: z
|
||||
.string()
|
||||
.describe('Describe your intent with the function and your expected result'),
|
||||
}),
|
||||
func: async ({ func = '', intent = '' }) => {
|
||||
const filteredFunctions = functions.filter((f) => f.name === func);
|
||||
chain.chains[0].lc_kwargs.llmKwargs.functions = filteredFunctions;
|
||||
const query = `${message}${func?.length > 0 ? `\n// Intent: ${intent}` : ''}`;
|
||||
const result = await chain.call({
|
||||
query,
|
||||
signal,
|
||||
});
|
||||
return result.response;
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getSpec,
|
||||
readSpecFile,
|
||||
createOpenAPIPlugin,
|
||||
};
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
const fs = require('fs');
|
||||
const { createOpenAPIPlugin, getSpec, readSpecFile } = require('./OpenAPIPlugin');
|
||||
|
||||
global.fetch = jest.fn().mockImplementationOnce(() => {
|
||||
return new Promise((resolve) => {
|
||||
resolve({
|
||||
ok: true,
|
||||
json: () => Promise.resolve({ key: 'value' }),
|
||||
});
|
||||
});
|
||||
});
|
||||
jest.mock('fs', () => ({
|
||||
promises: {
|
||||
readFile: jest.fn(),
|
||||
},
|
||||
existsSync: jest.fn(),
|
||||
}));
|
||||
|
||||
describe('readSpecFile', () => {
|
||||
it('reads JSON file correctly', async () => {
|
||||
fs.promises.readFile.mockResolvedValue(JSON.stringify({ test: 'value' }));
|
||||
const result = await readSpecFile('test.json');
|
||||
expect(result).toEqual({ test: 'value' });
|
||||
});
|
||||
|
||||
it('reads YAML file correctly', async () => {
|
||||
fs.promises.readFile.mockResolvedValue('test: value');
|
||||
const result = await readSpecFile('test.yaml');
|
||||
expect(result).toEqual({ test: 'value' });
|
||||
});
|
||||
|
||||
it('handles error correctly', async () => {
|
||||
fs.promises.readFile.mockRejectedValue(new Error('test error'));
|
||||
const result = await readSpecFile('test.json');
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getSpec', () => {
|
||||
it('fetches spec from url correctly', async () => {
|
||||
const parsedJson = await getSpec('https://www.instacart.com/.well-known/ai-plugin.json');
|
||||
const isObject = typeof parsedJson === 'object';
|
||||
expect(isObject).toEqual(true);
|
||||
});
|
||||
|
||||
it('reads spec from file correctly', async () => {
|
||||
fs.existsSync.mockReturnValue(true);
|
||||
fs.promises.readFile.mockResolvedValue(JSON.stringify({ test: 'value' }));
|
||||
const result = await getSpec('test.json');
|
||||
expect(result).toEqual({ test: 'value' });
|
||||
});
|
||||
|
||||
it('returns false when file does not exist', async () => {
|
||||
fs.existsSync.mockReturnValue(false);
|
||||
const result = await getSpec('test.json');
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('createOpenAPIPlugin', () => {
|
||||
it('returns null when getSpec throws an error', async () => {
|
||||
const result = await createOpenAPIPlugin({ data: { api: { url: 'invalid' } } });
|
||||
expect(result).toBe(null);
|
||||
});
|
||||
|
||||
it('returns null when no spec is found', async () => {
|
||||
const result = await createOpenAPIPlugin({});
|
||||
expect(result).toBe(null);
|
||||
});
|
||||
|
||||
// Add more tests here for different scenarios
|
||||
});
|
||||
|
|
@ -8,10 +8,10 @@ const { HttpsProxyAgent } = require('https-proxy-agent');
|
|||
const { FileContext, ContentTypes } = require('librechat-data-provider');
|
||||
const { getImageBasename } = require('~/server/services/Files/images');
|
||||
const extractBaseURL = require('~/utils/extractBaseURL');
|
||||
const { logger } = require('~/config');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
const displayMessage =
|
||||
'DALL-E displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.';
|
||||
"DALL-E displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.";
|
||||
class DALLE3 extends Tool {
|
||||
constructor(fields = {}) {
|
||||
super();
|
||||
|
|
|
|||
|
|
@ -4,12 +4,13 @@ const { v4 } = require('uuid');
|
|||
const OpenAI = require('openai');
|
||||
const FormData = require('form-data');
|
||||
const { tool } = require('@langchain/core/tools');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { ContentTypes, EImageOutputType } = require('librechat-data-provider');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { logAxiosError, extractBaseURL } = require('~/utils');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/** Default descriptions for image generation tool */
|
||||
const DEFAULT_IMAGE_GEN_DESCRIPTION = `
|
||||
|
|
|
|||
|
|
@ -1,10 +1,29 @@
|
|||
const OpenAI = require('openai');
|
||||
const DALLE3 = require('../DALLE3');
|
||||
|
||||
const { logger } = require('~/config');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
jest.mock('openai');
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => {
|
||||
return {
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('tiktoken', () => {
|
||||
return {
|
||||
encoding_for_model: jest.fn().mockReturnValue({
|
||||
encode: jest.fn(),
|
||||
decode: jest.fn(),
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
const processFileURL = jest.fn();
|
||||
|
||||
jest.mock('~/server/services/Files/images', () => ({
|
||||
|
|
@ -37,6 +56,11 @@ jest.mock('fs', () => {
|
|||
return {
|
||||
existsSync: jest.fn(),
|
||||
mkdirSync: jest.fn(),
|
||||
promises: {
|
||||
writeFile: jest.fn(),
|
||||
readFile: jest.fn(),
|
||||
unlink: jest.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
|
|||
query: z
|
||||
.string()
|
||||
.describe(
|
||||
'A natural language query to search for relevant information in the files. Be specific and use keywords related to the information you\'re looking for. The query will be used for semantic similarity matching against the file contents.',
|
||||
"A natural language query to search for relevant information in the files. Be specific and use keywords related to the information you're looking for. The query will be used for semantic similarity matching against the file contents.",
|
||||
),
|
||||
}),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
const axios = require('axios');
|
||||
const { EventSource } = require('eventsource');
|
||||
const { Time, CacheKeys } = require('librechat-data-provider');
|
||||
const { MCPManager, FlowStateManager } = require('librechat-mcp');
|
||||
const { Time } = require('librechat-data-provider');
|
||||
const { MCPManager, FlowStateManager } = require('@librechat/api');
|
||||
const logger = require('./winston');
|
||||
|
||||
global.EventSource = EventSource;
|
||||
|
|
@ -37,60 +36,8 @@ function getFlowStateManager(flowsCache) {
|
|||
return flowManager;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends message data in Server Sent Events format.
|
||||
* @param {ServerResponse} res - The server response.
|
||||
* @param {{ data: string | Record<string, unknown>, event?: string }} event - The message event.
|
||||
* @param {string} event.event - The type of event.
|
||||
* @param {string} event.data - The message to be sent.
|
||||
*/
|
||||
const sendEvent = (res, event) => {
|
||||
if (typeof event.data === 'string' && event.data.length === 0) {
|
||||
return;
|
||||
}
|
||||
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates and configures an Axios instance with optional proxy settings.
|
||||
*
|
||||
* @typedef {import('axios').AxiosInstance} AxiosInstance
|
||||
* @typedef {import('axios').AxiosProxyConfig} AxiosProxyConfig
|
||||
*
|
||||
* @returns {AxiosInstance} A configured Axios instance
|
||||
* @throws {Error} If there's an issue creating the Axios instance or parsing the proxy URL
|
||||
*/
|
||||
function createAxiosInstance() {
|
||||
const instance = axios.create();
|
||||
|
||||
if (process.env.proxy) {
|
||||
try {
|
||||
const url = new URL(process.env.proxy);
|
||||
|
||||
/** @type {AxiosProxyConfig} */
|
||||
const proxyConfig = {
|
||||
host: url.hostname.replace(/^\[|\]$/g, ''),
|
||||
protocol: url.protocol.replace(':', ''),
|
||||
};
|
||||
|
||||
if (url.port) {
|
||||
proxyConfig.port = parseInt(url.port, 10);
|
||||
}
|
||||
|
||||
instance.defaults.proxy = proxyConfig;
|
||||
} catch (error) {
|
||||
console.error('Error parsing proxy URL:', error);
|
||||
throw new Error(`Invalid proxy URL: ${process.env.proxy}`);
|
||||
}
|
||||
}
|
||||
|
||||
return instance;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
logger,
|
||||
sendEvent,
|
||||
getMCPManager,
|
||||
createAxiosInstance,
|
||||
getFlowStateManager,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,126 +0,0 @@
|
|||
const axios = require('axios');
|
||||
const { createAxiosInstance } = require('./index');
|
||||
|
||||
// Mock axios
|
||||
jest.mock('axios', () => ({
|
||||
interceptors: {
|
||||
request: { use: jest.fn(), eject: jest.fn() },
|
||||
response: { use: jest.fn(), eject: jest.fn() },
|
||||
},
|
||||
create: jest.fn().mockReturnValue({
|
||||
defaults: {
|
||||
proxy: null,
|
||||
},
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
post: jest.fn().mockResolvedValue({ data: {} }),
|
||||
put: jest.fn().mockResolvedValue({ data: {} }),
|
||||
delete: jest.fn().mockResolvedValue({ data: {} }),
|
||||
}),
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
post: jest.fn().mockResolvedValue({ data: {} }),
|
||||
put: jest.fn().mockResolvedValue({ data: {} }),
|
||||
delete: jest.fn().mockResolvedValue({ data: {} }),
|
||||
reset: jest.fn().mockImplementation(function () {
|
||||
this.get.mockClear();
|
||||
this.post.mockClear();
|
||||
this.put.mockClear();
|
||||
this.delete.mockClear();
|
||||
this.create.mockClear();
|
||||
}),
|
||||
}));
|
||||
|
||||
describe('createAxiosInstance', () => {
|
||||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset mocks
|
||||
jest.clearAllMocks();
|
||||
// Create a clean copy of process.env
|
||||
process.env = { ...originalEnv };
|
||||
// Default: no proxy
|
||||
delete process.env.proxy;
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
// Restore original process.env
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
test('creates an axios instance without proxy when no proxy env is set', () => {
|
||||
const instance = createAxiosInstance();
|
||||
|
||||
expect(axios.create).toHaveBeenCalledTimes(1);
|
||||
expect(instance.defaults.proxy).toBeNull();
|
||||
});
|
||||
|
||||
test('configures proxy correctly with hostname and protocol', () => {
|
||||
process.env.proxy = 'http://example.com';
|
||||
|
||||
const instance = createAxiosInstance();
|
||||
|
||||
expect(axios.create).toHaveBeenCalledTimes(1);
|
||||
expect(instance.defaults.proxy).toEqual({
|
||||
host: 'example.com',
|
||||
protocol: 'http',
|
||||
});
|
||||
});
|
||||
|
||||
test('configures proxy correctly with hostname, protocol and port', () => {
|
||||
process.env.proxy = 'https://proxy.example.com:8080';
|
||||
|
||||
const instance = createAxiosInstance();
|
||||
|
||||
expect(axios.create).toHaveBeenCalledTimes(1);
|
||||
expect(instance.defaults.proxy).toEqual({
|
||||
host: 'proxy.example.com',
|
||||
protocol: 'https',
|
||||
port: 8080,
|
||||
});
|
||||
});
|
||||
|
||||
test('handles proxy URLs with authentication', () => {
|
||||
process.env.proxy = 'http://user:pass@proxy.example.com:3128';
|
||||
|
||||
const instance = createAxiosInstance();
|
||||
|
||||
expect(axios.create).toHaveBeenCalledTimes(1);
|
||||
expect(instance.defaults.proxy).toEqual({
|
||||
host: 'proxy.example.com',
|
||||
protocol: 'http',
|
||||
port: 3128,
|
||||
// Note: The current implementation doesn't handle auth - if needed, add this functionality
|
||||
});
|
||||
});
|
||||
|
||||
test('throws error when proxy URL is invalid', () => {
|
||||
process.env.proxy = 'invalid-url';
|
||||
|
||||
expect(() => createAxiosInstance()).toThrow('Invalid proxy URL');
|
||||
expect(axios.create).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
// If you want to test the actual URL parsing more thoroughly
|
||||
test('handles edge case proxy URLs correctly', () => {
|
||||
// IPv6 address
|
||||
process.env.proxy = 'http://[::1]:8080';
|
||||
|
||||
let instance = createAxiosInstance();
|
||||
|
||||
expect(instance.defaults.proxy).toEqual({
|
||||
host: '::1',
|
||||
protocol: 'http',
|
||||
port: 8080,
|
||||
});
|
||||
|
||||
// URL with path (which should be ignored for proxy config)
|
||||
process.env.proxy = 'http://proxy.example.com:8080/some/path';
|
||||
|
||||
instance = createAxiosInstance();
|
||||
|
||||
expect(instance.defaults.proxy).toEqual({
|
||||
host: 'proxy.example.com',
|
||||
protocol: 'http',
|
||||
port: 8080,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,67 +1,20 @@
|
|||
/** @type {import('jest').Config} */
|
||||
module.exports = {
|
||||
// Define separate Jest projects
|
||||
projects: [
|
||||
// Default config for most tests
|
||||
{
|
||||
displayName: 'default',
|
||||
testEnvironment: 'node',
|
||||
clearMocks: true,
|
||||
roots: ['<rootDir>'],
|
||||
coverageDirectory: 'coverage',
|
||||
setupFiles: [
|
||||
'./test/jestSetup.js',
|
||||
'./test/__mocks__/logger.js',
|
||||
'./test/__mocks__/fetchEventSource.js',
|
||||
],
|
||||
moduleNameMapper: {
|
||||
'~/(.*)': '<rootDir>/$1',
|
||||
'~/data/auth.json': '<rootDir>/__mocks__/auth.mock.json',
|
||||
'^openid-client/passport$': '<rootDir>/test/__mocks__/openid-client-passport.js',
|
||||
'^openid-client$': '<rootDir>/test/__mocks__/openid-client.js',
|
||||
},
|
||||
transformIgnorePatterns: ['/node_modules/(?!(openid-client|oauth4webapi|jose)/).*/'],
|
||||
// testMatch: ['<rootDir>/**/*.spec.js', '<rootDir>/**/*.spec.ts'],
|
||||
testPathIgnorePatterns: [
|
||||
'<rootDir>/strategies/openidStrategy.spec.js',
|
||||
'<rootDir>/strategies/samlStrategy.spec.js',
|
||||
'<rootDir>/strategies/appleStrategy.test.js',
|
||||
],
|
||||
},
|
||||
|
||||
// Special config just for openidStrategy.spec.js
|
||||
{
|
||||
displayName: 'openid-strategy',
|
||||
testEnvironment: 'node',
|
||||
clearMocks: true,
|
||||
setupFiles: [
|
||||
'./test/jestSetup.js',
|
||||
'./test/__mocks__/logger.js',
|
||||
'./test/__mocks__/fetchEventSource.js',
|
||||
],
|
||||
moduleNameMapper: {
|
||||
'~/(.*)': '<rootDir>/$1',
|
||||
'~/data/auth.json': '<rootDir>/__mocks__/auth.mock.json',
|
||||
'^openid-client/passport$': '<rootDir>/test/__mocks__/openid-client-passport.js',
|
||||
'^openid-client$': '<rootDir>/test/__mocks__/openid-client.js',
|
||||
},
|
||||
transformIgnorePatterns: ['/node_modules/(?!(openid-client|oauth4webapi|jose)/).*/'],
|
||||
transform: {
|
||||
'^.+\\.tsx?$': [
|
||||
'ts-jest',
|
||||
{
|
||||
tsconfig: {
|
||||
esModuleInterop: true,
|
||||
allowSyntheticDefaultImports: true,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
testMatch: [
|
||||
'<rootDir>/strategies/openidStrategy.spec.js',
|
||||
'<rootDir>/strategies/samlStrategy.spec.js',
|
||||
'<rootDir>/strategies/appleStrategy.test.js',
|
||||
],
|
||||
},
|
||||
displayName: 'default',
|
||||
testEnvironment: 'node',
|
||||
clearMocks: true,
|
||||
roots: ['<rootDir>'],
|
||||
coverageDirectory: 'coverage',
|
||||
setupFiles: [
|
||||
'./test/jestSetup.js',
|
||||
'./test/__mocks__/logger.js',
|
||||
'./test/__mocks__/fetchEventSource.js',
|
||||
],
|
||||
moduleNameMapper: {
|
||||
'~/(.*)': '<rootDir>/$1',
|
||||
'~/data/auth.json': '<rootDir>/__mocks__/auth.mock.json',
|
||||
'^openid-client/passport$': '<rootDir>/test/__mocks__/openid-client-passport.js',
|
||||
'^openid-client$': '<rootDir>/test/__mocks__/openid-client.js',
|
||||
},
|
||||
transformIgnorePatterns: ['/node_modules/(?!(openid-client|oauth4webapi|jose)/).*/'],
|
||||
};
|
||||
|
|
|
|||
|
|
@ -170,7 +170,6 @@ const isDuplicateVersion = (updateData, currentData, versions, actionsHash = nul
|
|||
'created_at',
|
||||
'updated_at',
|
||||
'__v',
|
||||
'agent_ids',
|
||||
'versions',
|
||||
'actionsHash', // Exclude actionsHash from direct comparison
|
||||
];
|
||||
|
|
@ -260,11 +259,12 @@ const isDuplicateVersion = (updateData, currentData, versions, actionsHash = nul
|
|||
* @param {Object} [options] - Optional configuration object.
|
||||
* @param {string} [options.updatingUserId] - The ID of the user performing the update (used for tracking non-author updates).
|
||||
* @param {boolean} [options.forceVersion] - Force creation of a new version even if no fields changed.
|
||||
* @param {boolean} [options.skipVersioning] - Skip version creation entirely (useful for isolated operations like sharing).
|
||||
* @returns {Promise<Agent>} The updated or newly created agent document as a plain object.
|
||||
* @throws {Error} If the update would create a duplicate version
|
||||
*/
|
||||
const updateAgent = async (searchParameter, updateData, options = {}) => {
|
||||
const { updatingUserId = null, forceVersion = false } = options;
|
||||
const { updatingUserId = null, forceVersion = false, skipVersioning = false } = options;
|
||||
const mongoOptions = { new: true, upsert: false };
|
||||
|
||||
const currentAgent = await Agent.findOne(searchParameter);
|
||||
|
|
@ -301,10 +301,8 @@ const updateAgent = async (searchParameter, updateData, options = {}) => {
|
|||
}
|
||||
|
||||
const shouldCreateVersion =
|
||||
forceVersion ||
|
||||
(versions &&
|
||||
versions.length > 0 &&
|
||||
(Object.keys(directUpdates).length > 0 || $push || $pull || $addToSet));
|
||||
!skipVersioning &&
|
||||
(forceVersion || Object.keys(directUpdates).length > 0 || $push || $pull || $addToSet);
|
||||
|
||||
if (shouldCreateVersion) {
|
||||
const duplicateVersion = isDuplicateVersion(updateData, versionData, versions, actionsHash);
|
||||
|
|
@ -339,7 +337,7 @@ const updateAgent = async (searchParameter, updateData, options = {}) => {
|
|||
versionEntry.updatedBy = new mongoose.Types.ObjectId(updatingUserId);
|
||||
}
|
||||
|
||||
if (shouldCreateVersion || forceVersion) {
|
||||
if (shouldCreateVersion) {
|
||||
updateData.$push = {
|
||||
...($push || {}),
|
||||
versions: versionEntry,
|
||||
|
|
@ -550,7 +548,10 @@ const updateAgentProjects = async ({ user, agentId, projectIds, removeProjectIds
|
|||
delete updateQuery.author;
|
||||
}
|
||||
|
||||
const updatedAgent = await updateAgent(updateQuery, updateOps, { updatingUserId: user.id });
|
||||
const updatedAgent = await updateAgent(updateQuery, updateOps, {
|
||||
updatingUserId: user.id,
|
||||
skipVersioning: true,
|
||||
});
|
||||
if (updatedAgent) {
|
||||
return updatedAgent;
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -12,6 +12,10 @@ const comparePassword = async (user, candidatePassword) => {
|
|||
throw new Error('No user provided');
|
||||
}
|
||||
|
||||
if (!user.password) {
|
||||
throw new Error('No password, likely an email first registered via Social/OIDC login');
|
||||
}
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
bcrypt.compare(candidatePassword, user.password, (err, isMatch) => {
|
||||
if (err) {
|
||||
|
|
|
|||
|
|
@ -48,8 +48,9 @@
|
|||
"@langchain/google-genai": "^0.2.9",
|
||||
"@langchain/google-vertexai": "^0.2.9",
|
||||
"@langchain/textsplitters": "^0.1.0",
|
||||
"@librechat/agents": "^2.4.37",
|
||||
"@librechat/agents": "^2.4.38",
|
||||
"@librechat/auth": "*",
|
||||
"@librechat/api": "*",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@node-saml/passport-saml": "^5.0.0",
|
||||
"@waylaidwanderer/fetch-event-source": "^3.0.1",
|
||||
|
|
@ -82,15 +83,15 @@
|
|||
"keyv-file": "^5.1.2",
|
||||
"klona": "^2.0.6",
|
||||
"librechat-data-provider": "*",
|
||||
"librechat-mcp": "*",
|
||||
"lodash": "^4.17.21",
|
||||
"meilisearch": "^0.38.0",
|
||||
"memorystore": "^1.6.7",
|
||||
"mime": "^3.0.0",
|
||||
"module-alias": "^2.2.3",
|
||||
"mongoose": "^8.12.1",
|
||||
"multer": "^2.0.0",
|
||||
"multer": "^2.0.1",
|
||||
"nanoid": "^3.3.7",
|
||||
"node-fetch": "^2.7.0",
|
||||
"nodemailer": "^6.9.15",
|
||||
"ollama": "^0.5.0",
|
||||
"openai": "^4.96.2",
|
||||
|
|
@ -110,8 +111,9 @@
|
|||
"tiktoken": "^1.0.15",
|
||||
"traverse": "^0.6.7",
|
||||
"ua-parser-js": "^1.0.36",
|
||||
"undici": "^7.10.0",
|
||||
"winston": "^3.11.0",
|
||||
"winston-daily-rotate-file": "^4.7.1",
|
||||
"winston-daily-rotate-file": "^5.0.0",
|
||||
"youtube-transcript": "^1.2.1",
|
||||
"zod": "^3.22.4"
|
||||
},
|
||||
|
|
|
|||
|
|
@ -220,6 +220,9 @@ function disposeClient(client) {
|
|||
if (client.maxResponseTokens) {
|
||||
client.maxResponseTokens = null;
|
||||
}
|
||||
if (client.processMemory) {
|
||||
client.processMemory = null;
|
||||
}
|
||||
if (client.run) {
|
||||
// Break circular references in run
|
||||
if (client.run.Graph) {
|
||||
|
|
|
|||
|
|
@ -163,7 +163,11 @@ const deleteUserController = async (req, res) => {
|
|||
await Balance.deleteMany({ user: user._id }); // delete user balances
|
||||
await deletePresets(user.id); // delete user presets
|
||||
/* TODO: Delete Assistant Threads */
|
||||
await deleteConvos(user.id); // delete user convos
|
||||
try {
|
||||
await deleteConvos(user.id); // delete user convos
|
||||
} catch (error) {
|
||||
logger.error('[deleteUserController] Error deleting user convos, likely no convos', error);
|
||||
}
|
||||
await deleteUserPluginAuth(user.id, null, true); // delete user plugin auth
|
||||
await deleteUserById(user.id); // delete user
|
||||
await deleteAllSharedLinks(user.id); // delete user shared links
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
const { nanoid } = require('nanoid');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Tools, StepTypes, FileContext } = require('librechat-data-provider');
|
||||
const {
|
||||
EnvVar,
|
||||
|
|
@ -12,7 +14,6 @@ const {
|
|||
const { processCodeOutput } = require('~/server/services/Files/Code/process');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { saveBase64Image } = require('~/server/services/Files/process');
|
||||
const { logger, sendEvent } = require('~/config');
|
||||
|
||||
class ModelEndHandler {
|
||||
/**
|
||||
|
|
@ -240,9 +241,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
if (output.artifact[Tools.web_search]) {
|
||||
artifactPromises.push(
|
||||
(async () => {
|
||||
const name = `${output.name}_${output.tool_call_id}_${nanoid()}`;
|
||||
const attachment = {
|
||||
name,
|
||||
type: Tools.web_search,
|
||||
messageId: metadata.run_id,
|
||||
toolCallId: output.tool_call_id,
|
||||
|
|
|
|||
|
|
@ -1,13 +1,12 @@
|
|||
// const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
// const {
|
||||
// Constants,
|
||||
// ImageDetail,
|
||||
// EModelEndpoint,
|
||||
// resolveHeaders,
|
||||
// validateVisionModel,
|
||||
// mapModelToAzureConfig,
|
||||
// } = require('librechat-data-provider');
|
||||
require('events').EventEmitter.defaultMaxListeners = 100;
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
sendEvent,
|
||||
createRun,
|
||||
Tokenizer,
|
||||
memoryInstructions,
|
||||
createMemoryProcessor,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Callback,
|
||||
GraphEvents,
|
||||
|
|
@ -19,25 +18,30 @@ const {
|
|||
} = require('@librechat/agents');
|
||||
const {
|
||||
Constants,
|
||||
Permissions,
|
||||
VisionModes,
|
||||
ContentTypes,
|
||||
EModelEndpoint,
|
||||
KnownEndpoints,
|
||||
PermissionTypes,
|
||||
isAgentsEndpoint,
|
||||
AgentCapabilities,
|
||||
bedrockInputSchema,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { DynamicStructuredTool } = require('@langchain/core/tools');
|
||||
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
|
||||
const { getCustomEndpointConfig, checkCapability } = require('~/server/services/Config');
|
||||
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
|
||||
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
|
||||
const { setMemory, deleteMemory, getFormattedMemories } = require('~/models');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
|
||||
const Tokenizer = require('~/server/services/Tokenizer');
|
||||
const { checkAccess } = require('~/server/middleware/roles/access');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
const { logger, sendEvent } = require('~/config');
|
||||
const { createRun } = require('./run');
|
||||
const { loadAgent } = require('~/models/Agent');
|
||||
const { getMCPManager } = require('~/config');
|
||||
|
||||
/**
|
||||
* @param {ServerRequest} req
|
||||
|
|
@ -57,12 +61,8 @@ const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deep
|
|||
|
||||
const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi];
|
||||
|
||||
// const { processMemory, memoryInstructions } = require('~/server/services/Endpoints/agents/memory');
|
||||
// const { getFormattedMemories } = require('~/models/Memory');
|
||||
// const { getCurrentDateTime } = require('~/utils');
|
||||
|
||||
function createTokenCounter(encoding) {
|
||||
return (message) => {
|
||||
return function (message) {
|
||||
const countTokens = (text) => Tokenizer.getTokenCount(text, encoding);
|
||||
return getTokenCountForMessage(message, countTokens);
|
||||
};
|
||||
|
|
@ -123,6 +123,8 @@ class AgentClient extends BaseClient {
|
|||
this.usage;
|
||||
/** @type {Record<string, number>} */
|
||||
this.indexTokenCountMap = {};
|
||||
/** @type {(messages: BaseMessage[]) => Promise<void>} */
|
||||
this.processMemory;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -137,55 +139,10 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* Checks if the model is a vision model based on request attachments and sets the appropriate options:
|
||||
* - Sets `this.modelOptions.model` to `gpt-4-vision-preview` if the request is a vision request.
|
||||
* - Sets `this.isVisionModel` to `true` if vision request.
|
||||
* - Deletes `this.modelOptions.stop` if vision request.
|
||||
* `AgentClient` is not opinionated about vision requests, so we don't do anything here
|
||||
* @param {MongoFile[]} attachments
|
||||
*/
|
||||
checkVisionRequest(attachments) {
|
||||
// if (!attachments) {
|
||||
// return;
|
||||
// }
|
||||
// const availableModels = this.options.modelsConfig?.[this.options.endpoint];
|
||||
// if (!availableModels) {
|
||||
// return;
|
||||
// }
|
||||
// let visionRequestDetected = false;
|
||||
// for (const file of attachments) {
|
||||
// if (file?.type?.includes('image')) {
|
||||
// visionRequestDetected = true;
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
// if (!visionRequestDetected) {
|
||||
// return;
|
||||
// }
|
||||
// this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
|
||||
// if (this.isVisionModel) {
|
||||
// delete this.modelOptions.stop;
|
||||
// return;
|
||||
// }
|
||||
// for (const model of availableModels) {
|
||||
// if (!validateVisionModel({ model, availableModels })) {
|
||||
// continue;
|
||||
// }
|
||||
// this.modelOptions.model = model;
|
||||
// this.isVisionModel = true;
|
||||
// delete this.modelOptions.stop;
|
||||
// return;
|
||||
// }
|
||||
// if (!availableModels.includes(this.defaultVisionModel)) {
|
||||
// return;
|
||||
// }
|
||||
// if (!validateVisionModel({ model: this.defaultVisionModel, availableModels })) {
|
||||
// return;
|
||||
// }
|
||||
// this.modelOptions.model = this.defaultVisionModel;
|
||||
// this.isVisionModel = true;
|
||||
// delete this.modelOptions.stop;
|
||||
}
|
||||
checkVisionRequest() {}
|
||||
|
||||
getSaveOptions() {
|
||||
// TODO:
|
||||
|
|
@ -269,24 +226,6 @@ class AgentClient extends BaseClient {
|
|||
.filter(Boolean)
|
||||
.join('\n')
|
||||
.trim();
|
||||
// this.systemMessage = getCurrentDateTime();
|
||||
// const { withKeys, withoutKeys } = await getFormattedMemories({
|
||||
// userId: this.options.req.user.id,
|
||||
// });
|
||||
// processMemory({
|
||||
// userId: this.options.req.user.id,
|
||||
// message: this.options.req.body.text,
|
||||
// parentMessageId,
|
||||
// memory: withKeys,
|
||||
// thread_id: this.conversationId,
|
||||
// }).catch((error) => {
|
||||
// logger.error('Memory Agent failed to process memory', error);
|
||||
// });
|
||||
|
||||
// this.systemMessage += '\n\n' + memoryInstructions;
|
||||
// if (withoutKeys) {
|
||||
// this.systemMessage += `\n\n# Existing memory about the user:\n${withoutKeys}`;
|
||||
// }
|
||||
|
||||
if (this.options.attachments) {
|
||||
const attachments = await this.options.attachments;
|
||||
|
|
@ -370,6 +309,37 @@ class AgentClient extends BaseClient {
|
|||
systemContent = this.augmentedPrompt + systemContent;
|
||||
}
|
||||
|
||||
// Inject MCP server instructions if available
|
||||
const ephemeralAgent = this.options.req.body.ephemeralAgent;
|
||||
let mcpServers = [];
|
||||
|
||||
// Check for ephemeral agent MCP servers
|
||||
if (ephemeralAgent && ephemeralAgent.mcp && ephemeralAgent.mcp.length > 0) {
|
||||
mcpServers = ephemeralAgent.mcp;
|
||||
}
|
||||
// Check for regular agent MCP tools
|
||||
else if (this.options.agent && this.options.agent.tools) {
|
||||
mcpServers = this.options.agent.tools
|
||||
.filter(
|
||||
(tool) =>
|
||||
tool instanceof DynamicStructuredTool && tool.name.includes(Constants.mcp_delimiter),
|
||||
)
|
||||
.map((tool) => tool.name.split(Constants.mcp_delimiter).pop())
|
||||
.filter(Boolean);
|
||||
}
|
||||
|
||||
if (mcpServers.length > 0) {
|
||||
try {
|
||||
const mcpInstructions = getMCPManager().formatInstructionsForContext(mcpServers);
|
||||
if (mcpInstructions) {
|
||||
systemContent = [systemContent, mcpInstructions].filter(Boolean).join('\n\n');
|
||||
logger.debug('[AgentClient] Injected MCP instructions for servers:', mcpServers);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[AgentClient] Failed to inject MCP instructions:', error);
|
||||
}
|
||||
}
|
||||
|
||||
if (systemContent) {
|
||||
this.options.agent.instructions = systemContent;
|
||||
}
|
||||
|
|
@ -399,9 +369,150 @@ class AgentClient extends BaseClient {
|
|||
opts.getReqData({ promptTokens });
|
||||
}
|
||||
|
||||
const withoutKeys = await this.useMemory();
|
||||
if (withoutKeys) {
|
||||
systemContent += `${memoryInstructions}\n\n# Existing memory about the user:\n${withoutKeys}`;
|
||||
}
|
||||
|
||||
if (systemContent) {
|
||||
this.options.agent.instructions = systemContent;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns {Promise<string | undefined>}
|
||||
*/
|
||||
async useMemory() {
|
||||
const user = this.options.req.user;
|
||||
if (user.personalization?.memories === false) {
|
||||
return;
|
||||
}
|
||||
const hasAccess = await checkAccess(user, PermissionTypes.MEMORIES, [Permissions.USE]);
|
||||
|
||||
if (!hasAccess) {
|
||||
logger.debug(
|
||||
`[api/server/controllers/agents/client.js #useMemory] User ${user.id} does not have USE permission for memories`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
/** @type {TCustomConfig['memory']} */
|
||||
const memoryConfig = this.options.req?.app?.locals?.memory;
|
||||
if (!memoryConfig || memoryConfig.disabled === true) {
|
||||
return;
|
||||
}
|
||||
|
||||
/** @type {Agent} */
|
||||
let prelimAgent;
|
||||
const allowedProviders = new Set(
|
||||
this.options.req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders,
|
||||
);
|
||||
try {
|
||||
if (memoryConfig.agent?.id != null && memoryConfig.agent.id !== this.options.agent.id) {
|
||||
prelimAgent = await loadAgent({
|
||||
req: this.options.req,
|
||||
agent_id: memoryConfig.agent.id,
|
||||
endpoint: EModelEndpoint.agents,
|
||||
});
|
||||
} else if (
|
||||
memoryConfig.agent?.id == null &&
|
||||
memoryConfig.agent?.model != null &&
|
||||
memoryConfig.agent?.provider != null
|
||||
) {
|
||||
prelimAgent = { id: Constants.EPHEMERAL_AGENT_ID, ...memoryConfig.agent };
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #useMemory] Error loading agent for memory',
|
||||
error,
|
||||
);
|
||||
}
|
||||
|
||||
const agent = await initializeAgent({
|
||||
req: this.options.req,
|
||||
res: this.options.res,
|
||||
agent: prelimAgent,
|
||||
allowedProviders,
|
||||
});
|
||||
|
||||
if (!agent) {
|
||||
logger.warn(
|
||||
'[api/server/controllers/agents/client.js #useMemory] No agent found for memory',
|
||||
memoryConfig,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const llmConfig = Object.assign(
|
||||
{
|
||||
provider: agent.provider,
|
||||
model: agent.model,
|
||||
},
|
||||
agent.model_parameters,
|
||||
);
|
||||
|
||||
/** @type {import('@librechat/api').MemoryConfig} */
|
||||
const config = {
|
||||
validKeys: memoryConfig.validKeys,
|
||||
instructions: agent.instructions,
|
||||
llmConfig,
|
||||
tokenLimit: memoryConfig.tokenLimit,
|
||||
};
|
||||
|
||||
const userId = this.options.req.user.id + '';
|
||||
const messageId = this.responseMessageId + '';
|
||||
const conversationId = this.conversationId + '';
|
||||
const [withoutKeys, processMemory] = await createMemoryProcessor({
|
||||
userId,
|
||||
config,
|
||||
messageId,
|
||||
conversationId,
|
||||
memoryMethods: {
|
||||
setMemory,
|
||||
deleteMemory,
|
||||
getFormattedMemories,
|
||||
},
|
||||
res: this.options.res,
|
||||
});
|
||||
|
||||
this.processMemory = processMemory;
|
||||
return withoutKeys;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {BaseMessage[]} messages
|
||||
* @returns {Promise<void | (TAttachment | null)[]>}
|
||||
*/
|
||||
async runMemory(messages) {
|
||||
try {
|
||||
if (this.processMemory == null) {
|
||||
return;
|
||||
}
|
||||
/** @type {TCustomConfig['memory']} */
|
||||
const memoryConfig = this.options.req?.app?.locals?.memory;
|
||||
const messageWindowSize = memoryConfig?.messageWindowSize ?? 5;
|
||||
|
||||
let messagesToProcess = [...messages];
|
||||
if (messages.length > messageWindowSize) {
|
||||
for (let i = messages.length - messageWindowSize; i >= 0; i--) {
|
||||
const potentialWindow = messages.slice(i, i + messageWindowSize);
|
||||
if (potentialWindow[0]?.role === 'user') {
|
||||
messagesToProcess = [...potentialWindow];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (messagesToProcess.length === messages.length) {
|
||||
messagesToProcess = [...messages.slice(-messageWindowSize)];
|
||||
}
|
||||
}
|
||||
return await this.processMemory(messagesToProcess);
|
||||
} catch (error) {
|
||||
logger.error('Memory Agent failed to process memory', error);
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {sendCompletion} */
|
||||
async sendCompletion(payload, opts = {}) {
|
||||
await this.chatCompletion({
|
||||
|
|
@ -544,100 +655,13 @@ class AgentClient extends BaseClient {
|
|||
let config;
|
||||
/** @type {ReturnType<createRun>} */
|
||||
let run;
|
||||
/** @type {Promise<(TAttachment | null)[] | undefined>} */
|
||||
let memoryPromise;
|
||||
try {
|
||||
if (!abortController) {
|
||||
abortController = new AbortController();
|
||||
}
|
||||
|
||||
// if (this.options.headers) {
|
||||
// opts.defaultHeaders = { ...opts.defaultHeaders, ...this.options.headers };
|
||||
// }
|
||||
|
||||
// if (this.options.proxy) {
|
||||
// opts.httpAgent = new HttpsProxyAgent(this.options.proxy);
|
||||
// }
|
||||
|
||||
// if (this.isVisionModel) {
|
||||
// modelOptions.max_tokens = 4000;
|
||||
// }
|
||||
|
||||
// /** @type {TAzureConfig | undefined} */
|
||||
// const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI];
|
||||
|
||||
// if (
|
||||
// (this.azure && this.isVisionModel && azureConfig) ||
|
||||
// (azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI)
|
||||
// ) {
|
||||
// const { modelGroupMap, groupMap } = azureConfig;
|
||||
// const {
|
||||
// azureOptions,
|
||||
// baseURL,
|
||||
// headers = {},
|
||||
// serverless,
|
||||
// } = mapModelToAzureConfig({
|
||||
// modelName: modelOptions.model,
|
||||
// modelGroupMap,
|
||||
// groupMap,
|
||||
// });
|
||||
// opts.defaultHeaders = resolveHeaders(headers);
|
||||
// this.langchainProxy = extractBaseURL(baseURL);
|
||||
// this.apiKey = azureOptions.azureOpenAIApiKey;
|
||||
|
||||
// const groupName = modelGroupMap[modelOptions.model].group;
|
||||
// this.options.addParams = azureConfig.groupMap[groupName].addParams;
|
||||
// this.options.dropParams = azureConfig.groupMap[groupName].dropParams;
|
||||
// // Note: `forcePrompt` not re-assigned as only chat models are vision models
|
||||
|
||||
// this.azure = !serverless && azureOptions;
|
||||
// this.azureEndpoint =
|
||||
// !serverless && genAzureChatCompletion(this.azure, modelOptions.model, this);
|
||||
// }
|
||||
|
||||
// if (this.azure || this.options.azure) {
|
||||
// /* Azure Bug, extremely short default `max_tokens` response */
|
||||
// if (!modelOptions.max_tokens && modelOptions.model === 'gpt-4-vision-preview') {
|
||||
// modelOptions.max_tokens = 4000;
|
||||
// }
|
||||
|
||||
// /* Azure does not accept `model` in the body, so we need to remove it. */
|
||||
// delete modelOptions.model;
|
||||
|
||||
// opts.baseURL = this.langchainProxy
|
||||
// ? constructAzureURL({
|
||||
// baseURL: this.langchainProxy,
|
||||
// azureOptions: this.azure,
|
||||
// })
|
||||
// : this.azureEndpoint.split(/(?<!\/)\/(chat|completion)\//)[0];
|
||||
|
||||
// opts.defaultQuery = { 'api-version': this.azure.azureOpenAIApiVersion };
|
||||
// opts.defaultHeaders = { ...opts.defaultHeaders, 'api-key': this.apiKey };
|
||||
// }
|
||||
|
||||
// if (process.env.OPENAI_ORGANIZATION) {
|
||||
// opts.organization = process.env.OPENAI_ORGANIZATION;
|
||||
// }
|
||||
|
||||
// if (this.options.addParams && typeof this.options.addParams === 'object') {
|
||||
// modelOptions = {
|
||||
// ...modelOptions,
|
||||
// ...this.options.addParams,
|
||||
// };
|
||||
// logger.debug('[api/server/controllers/agents/client.js #chatCompletion] added params', {
|
||||
// addParams: this.options.addParams,
|
||||
// modelOptions,
|
||||
// });
|
||||
// }
|
||||
|
||||
// if (this.options.dropParams && Array.isArray(this.options.dropParams)) {
|
||||
// this.options.dropParams.forEach((param) => {
|
||||
// delete modelOptions[param];
|
||||
// });
|
||||
// logger.debug('[api/server/controllers/agents/client.js #chatCompletion] dropped params', {
|
||||
// dropParams: this.options.dropParams,
|
||||
// modelOptions,
|
||||
// });
|
||||
// }
|
||||
|
||||
/** @type {TCustomConfig['endpoints']['agents']} */
|
||||
const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents];
|
||||
|
||||
|
|
@ -647,6 +671,7 @@ class AgentClient extends BaseClient {
|
|||
last_agent_index: this.agentConfigs?.size ?? 0,
|
||||
user_id: this.user ?? this.options.req.user?.id,
|
||||
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
|
||||
user: this.options.req.user,
|
||||
},
|
||||
recursionLimit: agentsEConfig?.recursionLimit,
|
||||
signal: abortController.signal,
|
||||
|
|
@ -734,6 +759,10 @@ class AgentClient extends BaseClient {
|
|||
messages = addCacheControl(messages);
|
||||
}
|
||||
|
||||
if (i === 0) {
|
||||
memoryPromise = this.runMemory(messages);
|
||||
}
|
||||
|
||||
run = await createRun({
|
||||
agent,
|
||||
req: this.options.req,
|
||||
|
|
@ -769,10 +798,9 @@ class AgentClient extends BaseClient {
|
|||
run.Graph.contentData = contentData;
|
||||
}
|
||||
|
||||
const encoding = this.getEncoding();
|
||||
await run.processStream({ messages }, config, {
|
||||
keepContent: i !== 0,
|
||||
tokenCounter: createTokenCounter(encoding),
|
||||
tokenCounter: createTokenCounter(this.getEncoding()),
|
||||
indexTokenCountMap: currentIndexCountMap,
|
||||
maxContextTokens: agent.maxContextTokens,
|
||||
callbacks: {
|
||||
|
|
@ -887,6 +915,12 @@ class AgentClient extends BaseClient {
|
|||
});
|
||||
|
||||
try {
|
||||
if (memoryPromise) {
|
||||
const attachments = await memoryPromise;
|
||||
if (attachments && attachments.length > 0) {
|
||||
this.artifactPromises.push(...attachments);
|
||||
}
|
||||
}
|
||||
await this.recordCollectedUsage({ context: 'message' });
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
|
|
@ -895,6 +929,12 @@ class AgentClient extends BaseClient {
|
|||
);
|
||||
}
|
||||
} catch (err) {
|
||||
if (memoryPromise) {
|
||||
const attachments = await memoryPromise;
|
||||
if (attachments && attachments.length > 0) {
|
||||
this.artifactPromises.push(...attachments);
|
||||
}
|
||||
}
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #sendCompletion] Operation aborted',
|
||||
err,
|
||||
|
|
|
|||
|
|
@ -1,94 +0,0 @@
|
|||
const { Run, Providers } = require('@librechat/agents');
|
||||
const { providerEndpointMap, KnownEndpoints } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* @typedef {import('@librechat/agents').t} t
|
||||
* @typedef {import('@librechat/agents').StandardGraphConfig} StandardGraphConfig
|
||||
* @typedef {import('@librechat/agents').StreamEventData} StreamEventData
|
||||
* @typedef {import('@librechat/agents').EventHandler} EventHandler
|
||||
* @typedef {import('@librechat/agents').GraphEvents} GraphEvents
|
||||
* @typedef {import('@librechat/agents').LLMConfig} LLMConfig
|
||||
* @typedef {import('@librechat/agents').IState} IState
|
||||
*/
|
||||
|
||||
const customProviders = new Set([
|
||||
Providers.XAI,
|
||||
Providers.OLLAMA,
|
||||
Providers.DEEPSEEK,
|
||||
Providers.OPENROUTER,
|
||||
]);
|
||||
|
||||
/**
|
||||
* Creates a new Run instance with custom handlers and configuration.
|
||||
*
|
||||
* @param {Object} options - The options for creating the Run instance.
|
||||
* @param {ServerRequest} [options.req] - The server request.
|
||||
* @param {string | undefined} [options.runId] - Optional run ID; otherwise, a new run ID will be generated.
|
||||
* @param {Agent} options.agent - The agent for this run.
|
||||
* @param {AbortSignal} options.signal - The signal for this run.
|
||||
* @param {Record<GraphEvents, EventHandler> | undefined} [options.customHandlers] - Custom event handlers.
|
||||
* @param {boolean} [options.streaming=true] - Whether to use streaming.
|
||||
* @param {boolean} [options.streamUsage=true] - Whether to stream usage information.
|
||||
* @returns {Promise<Run<IState>>} A promise that resolves to a new Run instance.
|
||||
*/
|
||||
async function createRun({
|
||||
runId,
|
||||
agent,
|
||||
signal,
|
||||
customHandlers,
|
||||
streaming = true,
|
||||
streamUsage = true,
|
||||
}) {
|
||||
const provider = providerEndpointMap[agent.provider] ?? agent.provider;
|
||||
/** @type {LLMConfig} */
|
||||
const llmConfig = Object.assign(
|
||||
{
|
||||
provider,
|
||||
streaming,
|
||||
streamUsage,
|
||||
},
|
||||
agent.model_parameters,
|
||||
);
|
||||
|
||||
/** Resolves issues with new OpenAI usage field */
|
||||
if (
|
||||
customProviders.has(agent.provider) ||
|
||||
(agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider)
|
||||
) {
|
||||
llmConfig.streamUsage = false;
|
||||
llmConfig.usage = true;
|
||||
}
|
||||
|
||||
/** @type {'reasoning_content' | 'reasoning'} */
|
||||
let reasoningKey;
|
||||
if (
|
||||
llmConfig.configuration?.baseURL?.includes(KnownEndpoints.openrouter) ||
|
||||
(agent.endpoint && agent.endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
|
||||
) {
|
||||
reasoningKey = 'reasoning';
|
||||
}
|
||||
|
||||
/** @type {StandardGraphConfig} */
|
||||
const graphConfig = {
|
||||
signal,
|
||||
llmConfig,
|
||||
reasoningKey,
|
||||
tools: agent.tools,
|
||||
instructions: agent.instructions,
|
||||
additional_instructions: agent.additional_instructions,
|
||||
// toolEnd: agent.end_after_tools,
|
||||
};
|
||||
|
||||
// TEMPORARY FOR TESTING
|
||||
if (agent.provider === Providers.ANTHROPIC || agent.provider === Providers.BEDROCK) {
|
||||
graphConfig.streamBuffer = 2000;
|
||||
}
|
||||
|
||||
return Run.create({
|
||||
runId,
|
||||
graphConfig,
|
||||
customHandlers,
|
||||
});
|
||||
}
|
||||
|
||||
module.exports = { createRun };
|
||||
|
|
@ -18,6 +18,7 @@ const {
|
|||
} = require('~/models/Agent');
|
||||
const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { resizeAvatar } = require('@librechat/auth');
|
||||
const { refreshS3Url } = require('~/server/services/Files/S3/crud');
|
||||
const { updateAction, getActions } = require('~/models/Action');
|
||||
const { updateAgentProjects } = require('~/models/Agent');
|
||||
|
|
@ -168,12 +169,18 @@ const updateAgentHandler = async (req, res) => {
|
|||
});
|
||||
}
|
||||
|
||||
/** @type {boolean} */
|
||||
const isProjectUpdate = (projectIds?.length ?? 0) > 0 || (removeProjectIds?.length ?? 0) > 0;
|
||||
|
||||
let updatedAgent =
|
||||
Object.keys(updateData).length > 0
|
||||
? await updateAgent({ id }, updateData, { updatingUserId: req.user.id })
|
||||
? await updateAgent({ id }, updateData, {
|
||||
updatingUserId: req.user.id,
|
||||
skipVersioning: isProjectUpdate,
|
||||
})
|
||||
: existingAgent;
|
||||
|
||||
if (projectIds || removeProjectIds) {
|
||||
if (isProjectUpdate) {
|
||||
updatedAgent = await updateAgentProjects({
|
||||
user: req.user,
|
||||
agentId: id,
|
||||
|
|
@ -373,12 +380,27 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
|||
}
|
||||
|
||||
const buffer = await fs.readFile(req.file.path);
|
||||
const image = await uploadImageBuffer({
|
||||
req,
|
||||
context: FileContext.avatar,
|
||||
metadata: { buffer },
|
||||
|
||||
const fileStrategy = req.app.locals.fileStrategy;
|
||||
|
||||
const resizedBuffer = await resizeAvatar({
|
||||
userId: req.user.id,
|
||||
input: buffer,
|
||||
});
|
||||
|
||||
const { processAvatar } = getStrategyFunctions(fileStrategy);
|
||||
const avatarUrl = await processAvatar({
|
||||
buffer: resizedBuffer,
|
||||
userId: req.user.id,
|
||||
manual: 'false',
|
||||
agentId: agent_id,
|
||||
});
|
||||
|
||||
const image = {
|
||||
filepath: avatarUrl,
|
||||
source: fileStrategy,
|
||||
};
|
||||
|
||||
let _avatar;
|
||||
try {
|
||||
const agent = await getAgent({ id: agent_id });
|
||||
|
|
@ -403,7 +425,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
|||
const data = {
|
||||
avatar: {
|
||||
filepath: image.filepath,
|
||||
source: req.app.locals.fileStrategy,
|
||||
source: image.source,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ const startServer = async () => {
|
|||
app.use('/api/agents', routes.agents);
|
||||
app.use('/api/banner', routes.banner);
|
||||
app.use('/api/bedrock', routes.bedrock);
|
||||
|
||||
app.use('/api/memories', routes.memories);
|
||||
app.use('/api/tags', routes.tags);
|
||||
|
||||
app.use((req, res) => {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const checkAdmin = require('./checkAdmin');
|
||||
const { checkAccess, generateCheckAccess } = require('./generateCheckAccess');
|
||||
const checkAdmin = require('./admin');
|
||||
const { checkAccess, generateCheckAccess } = require('./access');
|
||||
|
||||
module.exports = {
|
||||
checkAdmin,
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
|
|||
identifier,
|
||||
client_url: flowState.metadata.client_url,
|
||||
redirect_uri: flowState.metadata.redirect_uri,
|
||||
token_exchange_method: flowState.metadata.token_exchange_method,
|
||||
/** Encrypted values */
|
||||
encrypted_oauth_client_id: flowState.metadata.encrypted_oauth_client_id,
|
||||
encrypted_oauth_client_secret: flowState.metadata.encrypted_oauth_client_secret,
|
||||
|
|
|
|||
|
|
@ -65,8 +65,14 @@ router.post('/gen_title', async (req, res) => {
|
|||
let title = await titleCache.get(key);
|
||||
|
||||
if (!title) {
|
||||
await sleep(2500);
|
||||
title = await titleCache.get(key);
|
||||
// Retry every 1s for up to 20s
|
||||
for (let i = 0; i < 20; i++) {
|
||||
await sleep(1000);
|
||||
title = await titleCache.get(key);
|
||||
if (title) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (title) {
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ const fs = require('fs');
|
|||
const path = require('path');
|
||||
const crypto = require('crypto');
|
||||
const multer = require('multer');
|
||||
const { sanitizeFilename } = require('@librechat/api');
|
||||
const { fileConfig: defaultFileConfig, mergeFileConfig } = require('librechat-data-provider');
|
||||
const { sanitizeFilename } = require('~/server/utils/handleText');
|
||||
const { getCustomConfig } = require('~/server/services/Config');
|
||||
|
||||
const storage = multer.diskStorage({
|
||||
|
|
|
|||
571
api/server/routes/files/multer.spec.js
Normal file
571
api/server/routes/files/multer.spec.js
Normal file
|
|
@ -0,0 +1,571 @@
|
|||
/* eslint-disable no-unused-vars */
|
||||
/* eslint-disable jest/no-done-callback */
|
||||
const fs = require('fs');
|
||||
const os = require('os');
|
||||
const path = require('path');
|
||||
const crypto = require('crypto');
|
||||
const { createMulterInstance, storage, importFileFilter } = require('./multer');
|
||||
|
||||
// Mock only the config service that requires external dependencies
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getCustomConfig: jest.fn(() =>
|
||||
Promise.resolve({
|
||||
fileConfig: {
|
||||
endpoints: {
|
||||
openAI: {
|
||||
supportedMimeTypes: ['image/jpeg', 'image/png', 'application/pdf'],
|
||||
},
|
||||
default: {
|
||||
supportedMimeTypes: ['image/jpeg', 'image/png', 'text/plain'],
|
||||
},
|
||||
},
|
||||
serverFileSizeLimit: 10000000, // 10MB
|
||||
},
|
||||
}),
|
||||
),
|
||||
}));
|
||||
|
||||
describe('Multer Configuration', () => {
|
||||
let tempDir;
|
||||
let mockReq;
|
||||
let mockFile;
|
||||
|
||||
beforeEach(() => {
|
||||
// Create a temporary directory for each test
|
||||
tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'multer-test-'));
|
||||
|
||||
mockReq = {
|
||||
user: { id: 'test-user-123' },
|
||||
app: {
|
||||
locals: {
|
||||
paths: {
|
||||
uploads: tempDir,
|
||||
},
|
||||
},
|
||||
},
|
||||
body: {},
|
||||
originalUrl: '/api/files/upload',
|
||||
};
|
||||
|
||||
mockFile = {
|
||||
originalname: 'test-file.jpg',
|
||||
mimetype: 'image/jpeg',
|
||||
size: 1024,
|
||||
};
|
||||
|
||||
// Clear mocks
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// Clean up temporary directory
|
||||
if (fs.existsSync(tempDir)) {
|
||||
fs.rmSync(tempDir, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
describe('Storage Configuration', () => {
|
||||
describe('destination function', () => {
|
||||
it('should create the correct destination path', (done) => {
|
||||
const cb = jest.fn((err, destination) => {
|
||||
expect(err).toBeNull();
|
||||
expect(destination).toBe(path.join(tempDir, 'temp', 'test-user-123'));
|
||||
expect(fs.existsSync(destination)).toBe(true);
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getDestination(mockReq, mockFile, cb);
|
||||
});
|
||||
|
||||
it("should create directory recursively if it doesn't exist", (done) => {
|
||||
const deepPath = path.join(tempDir, 'deep', 'nested', 'path');
|
||||
mockReq.app.locals.paths.uploads = deepPath;
|
||||
|
||||
const cb = jest.fn((err, destination) => {
|
||||
expect(err).toBeNull();
|
||||
expect(destination).toBe(path.join(deepPath, 'temp', 'test-user-123'));
|
||||
expect(fs.existsSync(destination)).toBe(true);
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getDestination(mockReq, mockFile, cb);
|
||||
});
|
||||
});
|
||||
|
||||
describe('filename function', () => {
|
||||
it('should generate a UUID for req.file_id', (done) => {
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
expect(mockReq.file_id).toBeDefined();
|
||||
expect(mockReq.file_id).toMatch(
|
||||
/^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i,
|
||||
);
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, mockFile, cb);
|
||||
});
|
||||
|
||||
it('should decode URI components in filename', (done) => {
|
||||
const encodedFile = {
|
||||
...mockFile,
|
||||
originalname: encodeURIComponent('test file with spaces.jpg'),
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
expect(encodedFile.originalname).toBe('test file with spaces.jpg');
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, encodedFile, cb);
|
||||
});
|
||||
|
||||
it('should call real sanitizeFilename with properly encoded filename', (done) => {
|
||||
// Test with a properly URI-encoded filename that needs sanitization
|
||||
const unsafeFile = {
|
||||
...mockFile,
|
||||
originalname: encodeURIComponent('test@#$%^&*()file with spaces!.jpg'),
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
// The actual sanitizeFilename should have cleaned this up after decoding
|
||||
expect(filename).not.toContain('@');
|
||||
expect(filename).not.toContain('#');
|
||||
expect(filename).not.toContain('*');
|
||||
expect(filename).not.toContain('!');
|
||||
// Should still preserve dots and hyphens
|
||||
expect(filename).toContain('.jpg');
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, unsafeFile, cb);
|
||||
});
|
||||
|
||||
it('should handle very long filenames with actual crypto', (done) => {
|
||||
const longFile = {
|
||||
...mockFile,
|
||||
originalname: 'a'.repeat(300) + '.jpg',
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
expect(filename.length).toBeLessThanOrEqual(255);
|
||||
expect(filename).toMatch(/\.jpg$/); // Should still end with .jpg
|
||||
// Should contain a hex suffix if truncated
|
||||
if (filename.length === 255) {
|
||||
expect(filename).toMatch(/-[a-f0-9]{6}\.jpg$/);
|
||||
}
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, longFile, cb);
|
||||
});
|
||||
|
||||
it('should generate unique file_id for each call', (done) => {
|
||||
let firstFileId;
|
||||
|
||||
const firstCb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
firstFileId = mockReq.file_id;
|
||||
|
||||
// Reset req for second call
|
||||
delete mockReq.file_id;
|
||||
|
||||
const secondCb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
expect(mockReq.file_id).toBeDefined();
|
||||
expect(mockReq.file_id).not.toBe(firstFileId);
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, mockFile, secondCb);
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, mockFile, firstCb);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Import File Filter', () => {
|
||||
it('should accept JSON files by mimetype', (done) => {
|
||||
const jsonFile = {
|
||||
...mockFile,
|
||||
mimetype: 'application/json',
|
||||
originalname: 'data.json',
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, result) => {
|
||||
expect(err).toBeNull();
|
||||
expect(result).toBe(true);
|
||||
done();
|
||||
});
|
||||
|
||||
importFileFilter(mockReq, jsonFile, cb);
|
||||
});
|
||||
|
||||
it('should accept files with .json extension', (done) => {
|
||||
const jsonFile = {
|
||||
...mockFile,
|
||||
mimetype: 'text/plain',
|
||||
originalname: 'data.json',
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, result) => {
|
||||
expect(err).toBeNull();
|
||||
expect(result).toBe(true);
|
||||
done();
|
||||
});
|
||||
|
||||
importFileFilter(mockReq, jsonFile, cb);
|
||||
});
|
||||
|
||||
it('should reject non-JSON files', (done) => {
|
||||
const textFile = {
|
||||
...mockFile,
|
||||
mimetype: 'text/plain',
|
||||
originalname: 'document.txt',
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, result) => {
|
||||
expect(err).toBeInstanceOf(Error);
|
||||
expect(err.message).toBe('Only JSON files are allowed');
|
||||
expect(result).toBe(false);
|
||||
done();
|
||||
});
|
||||
|
||||
importFileFilter(mockReq, textFile, cb);
|
||||
});
|
||||
|
||||
it('should handle files with uppercase .JSON extension', (done) => {
|
||||
const jsonFile = {
|
||||
...mockFile,
|
||||
mimetype: 'text/plain',
|
||||
originalname: 'DATA.JSON',
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, result) => {
|
||||
expect(err).toBeNull();
|
||||
expect(result).toBe(true);
|
||||
done();
|
||||
});
|
||||
|
||||
importFileFilter(mockReq, jsonFile, cb);
|
||||
});
|
||||
});
|
||||
|
||||
describe('File Filter with Real defaultFileConfig', () => {
|
||||
it('should use real fileConfig.checkType for validation', async () => {
|
||||
// Test with actual librechat-data-provider functions
|
||||
const {
|
||||
fileConfig,
|
||||
imageMimeTypes,
|
||||
applicationMimeTypes,
|
||||
} = require('librechat-data-provider');
|
||||
|
||||
// Test that the real checkType function works with regex patterns
|
||||
expect(fileConfig.checkType('image/jpeg', [imageMimeTypes])).toBe(true);
|
||||
expect(fileConfig.checkType('video/mp4', [imageMimeTypes])).toBe(false);
|
||||
expect(fileConfig.checkType('application/pdf', [applicationMimeTypes])).toBe(true);
|
||||
expect(fileConfig.checkType('application/pdf', [])).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle audio files for speech-to-text endpoint with real config', async () => {
|
||||
mockReq.originalUrl = '/api/speech/stt';
|
||||
|
||||
const multerInstance = await createMulterInstance();
|
||||
expect(multerInstance).toBeDefined();
|
||||
expect(typeof multerInstance.single).toBe('function');
|
||||
});
|
||||
|
||||
it('should reject unsupported file types using real config', async () => {
|
||||
// Mock defaultFileConfig for this specific test
|
||||
const originalCheckType = require('librechat-data-provider').fileConfig.checkType;
|
||||
const mockCheckType = jest.fn().mockReturnValue(false);
|
||||
require('librechat-data-provider').fileConfig.checkType = mockCheckType;
|
||||
|
||||
try {
|
||||
const multerInstance = await createMulterInstance();
|
||||
expect(multerInstance).toBeDefined();
|
||||
|
||||
// Test the actual file filter behavior would reject unsupported files
|
||||
expect(mockCheckType).toBeDefined();
|
||||
} finally {
|
||||
// Restore original function
|
||||
require('librechat-data-provider').fileConfig.checkType = originalCheckType;
|
||||
}
|
||||
});
|
||||
|
||||
it('should use real mergeFileConfig function', async () => {
|
||||
const { mergeFileConfig, mbToBytes } = require('librechat-data-provider');
|
||||
|
||||
// Test with actual merge function - note that it converts MB to bytes
|
||||
const testConfig = {
|
||||
serverFileSizeLimit: 5, // 5 MB
|
||||
endpoints: {
|
||||
custom: {
|
||||
supportedMimeTypes: ['text/plain'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = mergeFileConfig(testConfig);
|
||||
|
||||
// The function converts MB to bytes, so 5 MB becomes 5 * 1024 * 1024 bytes
|
||||
expect(result.serverFileSizeLimit).toBe(mbToBytes(5));
|
||||
expect(result.endpoints.custom.supportedMimeTypes).toBeDefined();
|
||||
// Should still have the default endpoints
|
||||
expect(result.endpoints.default).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('createMulterInstance with Real Functions', () => {
|
||||
it('should create a multer instance with correct configuration', async () => {
|
||||
const multerInstance = await createMulterInstance();
|
||||
|
||||
expect(multerInstance).toBeDefined();
|
||||
expect(typeof multerInstance.single).toBe('function');
|
||||
expect(typeof multerInstance.array).toBe('function');
|
||||
expect(typeof multerInstance.fields).toBe('function');
|
||||
});
|
||||
|
||||
it('should use real config merging', async () => {
|
||||
const { getCustomConfig } = require('~/server/services/Config');
|
||||
|
||||
const multerInstance = await createMulterInstance();
|
||||
|
||||
expect(getCustomConfig).toHaveBeenCalled();
|
||||
expect(multerInstance).toBeDefined();
|
||||
});
|
||||
|
||||
it('should create multer instance with expected interface', async () => {
|
||||
const multerInstance = await createMulterInstance();
|
||||
|
||||
expect(multerInstance).toBeDefined();
|
||||
expect(typeof multerInstance.single).toBe('function');
|
||||
expect(typeof multerInstance.array).toBe('function');
|
||||
expect(typeof multerInstance.fields).toBe('function');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Real Crypto Integration', () => {
|
||||
it('should use actual crypto.randomUUID()', (done) => {
|
||||
// Spy on crypto.randomUUID to ensure it's called
|
||||
const uuidSpy = jest.spyOn(crypto, 'randomUUID');
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
expect(uuidSpy).toHaveBeenCalled();
|
||||
expect(mockReq.file_id).toMatch(
|
||||
/^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i,
|
||||
);
|
||||
|
||||
uuidSpy.mockRestore();
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, mockFile, cb);
|
||||
});
|
||||
|
||||
it('should generate different UUIDs on subsequent calls', (done) => {
|
||||
const uuids = [];
|
||||
let callCount = 0;
|
||||
const totalCalls = 5;
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
uuids.push(mockReq.file_id);
|
||||
callCount++;
|
||||
|
||||
if (callCount === totalCalls) {
|
||||
// Check that all UUIDs are unique
|
||||
const uniqueUuids = new Set(uuids);
|
||||
expect(uniqueUuids.size).toBe(totalCalls);
|
||||
done();
|
||||
} else {
|
||||
// Reset for next call
|
||||
delete mockReq.file_id;
|
||||
storage.getFilename(mockReq, mockFile, cb);
|
||||
}
|
||||
});
|
||||
|
||||
// Start the chain
|
||||
storage.getFilename(mockReq, mockFile, cb);
|
||||
});
|
||||
|
||||
it('should generate cryptographically secure UUIDs', (done) => {
|
||||
const generatedUuids = new Set();
|
||||
let callCount = 0;
|
||||
const totalCalls = 10;
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
|
||||
// Verify UUID format and uniqueness
|
||||
expect(mockReq.file_id).toMatch(
|
||||
/^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i,
|
||||
);
|
||||
|
||||
generatedUuids.add(mockReq.file_id);
|
||||
callCount++;
|
||||
|
||||
if (callCount === totalCalls) {
|
||||
// All UUIDs should be unique
|
||||
expect(generatedUuids.size).toBe(totalCalls);
|
||||
done();
|
||||
} else {
|
||||
// Reset for next call
|
||||
delete mockReq.file_id;
|
||||
storage.getFilename(mockReq, mockFile, cb);
|
||||
}
|
||||
});
|
||||
|
||||
// Start the chain
|
||||
storage.getFilename(mockReq, mockFile, cb);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should handle CVE-2024-28870: empty field name DoS vulnerability', async () => {
|
||||
// Test for the CVE where empty field name could cause unhandled exception
|
||||
const multerInstance = await createMulterInstance();
|
||||
|
||||
// Create a mock request with empty field name (the vulnerability scenario)
|
||||
const mockReqWithEmptyField = {
|
||||
...mockReq,
|
||||
headers: {
|
||||
'content-type': 'multipart/form-data',
|
||||
},
|
||||
};
|
||||
|
||||
const mockRes = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
end: jest.fn(),
|
||||
};
|
||||
|
||||
// This should not crash or throw unhandled exceptions
|
||||
const uploadMiddleware = multerInstance.single(''); // Empty field name
|
||||
|
||||
const mockNext = jest.fn((err) => {
|
||||
// If there's an error, it should be handled gracefully, not crash
|
||||
if (err) {
|
||||
expect(err).toBeInstanceOf(Error);
|
||||
// The error should be handled, not crash the process
|
||||
}
|
||||
});
|
||||
|
||||
// This should complete without crashing the process
|
||||
expect(() => {
|
||||
uploadMiddleware(mockReqWithEmptyField, mockRes, mockNext);
|
||||
}).not.toThrow();
|
||||
});
|
||||
|
||||
it('should handle file system errors when directory creation fails', (done) => {
|
||||
// Test with a non-existent parent directory to simulate fs issues
|
||||
const invalidPath = '/nonexistent/path/that/should/not/exist';
|
||||
mockReq.app.locals.paths.uploads = invalidPath;
|
||||
|
||||
try {
|
||||
// Call getDestination which should fail due to permission/path issues
|
||||
storage.getDestination(mockReq, mockFile, (err, destination) => {
|
||||
// If callback is reached, we didn't get the expected error
|
||||
done(new Error('Expected mkdirSync to throw an error but callback was called'));
|
||||
});
|
||||
// If we get here without throwing, something unexpected happened
|
||||
done(new Error('Expected mkdirSync to throw an error but no error was thrown'));
|
||||
} catch (error) {
|
||||
// This is the expected behavior - mkdirSync throws synchronously for invalid paths
|
||||
expect(error.code).toBe('EACCES');
|
||||
done();
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle malformed filenames with real sanitization', (done) => {
|
||||
const malformedFile = {
|
||||
...mockFile,
|
||||
originalname: null, // This should be handled gracefully
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
// The function should handle this gracefully
|
||||
expect(typeof err === 'object' || err === null).toBe(true);
|
||||
done();
|
||||
});
|
||||
|
||||
try {
|
||||
storage.getFilename(mockReq, malformedFile, cb);
|
||||
} catch (error) {
|
||||
// If it throws, that's also acceptable behavior
|
||||
done();
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle edge cases in filename sanitization', (done) => {
|
||||
const edgeCaseFiles = [
|
||||
{ originalname: '', expected: /_/ },
|
||||
{ originalname: '.hidden', expected: /^_\.hidden/ },
|
||||
{ originalname: '../../../etc/passwd', expected: /passwd/ },
|
||||
{ originalname: 'file\x00name.txt', expected: /file_name\.txt/ },
|
||||
];
|
||||
|
||||
let testCount = 0;
|
||||
|
||||
const testNextFile = (fileData) => {
|
||||
const fileToTest = { ...mockFile, originalname: fileData.originalname };
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
expect(filename).toMatch(fileData.expected);
|
||||
|
||||
testCount++;
|
||||
if (testCount === edgeCaseFiles.length) {
|
||||
done();
|
||||
} else {
|
||||
testNextFile(edgeCaseFiles[testCount]);
|
||||
}
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, fileToTest, cb);
|
||||
};
|
||||
|
||||
testNextFile(edgeCaseFiles[0]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Real Configuration Testing', () => {
|
||||
it('should handle missing custom config gracefully with real mergeFileConfig', async () => {
|
||||
const { getCustomConfig } = require('~/server/services/Config');
|
||||
|
||||
// Mock getCustomConfig to return undefined
|
||||
getCustomConfig.mockResolvedValueOnce(undefined);
|
||||
|
||||
const multerInstance = await createMulterInstance();
|
||||
expect(multerInstance).toBeDefined();
|
||||
expect(typeof multerInstance.single).toBe('function');
|
||||
});
|
||||
|
||||
it('should properly integrate real fileConfig with custom endpoints', async () => {
|
||||
const { getCustomConfig } = require('~/server/services/Config');
|
||||
|
||||
// Mock a custom config with additional endpoints
|
||||
getCustomConfig.mockResolvedValueOnce({
|
||||
fileConfig: {
|
||||
endpoints: {
|
||||
anthropic: {
|
||||
supportedMimeTypes: ['text/plain', 'image/png'],
|
||||
},
|
||||
},
|
||||
serverFileSizeLimit: 20, // 20 MB
|
||||
},
|
||||
});
|
||||
|
||||
const multerInstance = await createMulterInstance();
|
||||
expect(multerInstance).toBeDefined();
|
||||
|
||||
// Verify that getCustomConfig was called (we can't spy on the actual merge function easily)
|
||||
expect(getCustomConfig).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -4,6 +4,7 @@ const tokenizer = require('./tokenizer');
|
|||
const endpoints = require('./endpoints');
|
||||
const staticRoute = require('./static');
|
||||
const messages = require('./messages');
|
||||
const memories = require('./memories');
|
||||
const presets = require('./presets');
|
||||
const prompts = require('./prompts');
|
||||
const balance = require('./balance');
|
||||
|
|
@ -51,6 +52,7 @@ module.exports = {
|
|||
presets,
|
||||
balance,
|
||||
messages,
|
||||
memories,
|
||||
endpoints,
|
||||
tokenizer,
|
||||
assistants,
|
||||
|
|
|
|||
231
api/server/routes/memories.js
Normal file
231
api/server/routes/memories.js
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
const express = require('express');
|
||||
const { Tokenizer } = require('@librechat/api');
|
||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const {
|
||||
getAllUserMemories,
|
||||
toggleUserMemories,
|
||||
createMemory,
|
||||
setMemory,
|
||||
deleteMemory,
|
||||
} = require('~/models');
|
||||
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const checkMemoryRead = generateCheckAccess(PermissionTypes.MEMORIES, [
|
||||
Permissions.USE,
|
||||
Permissions.READ,
|
||||
]);
|
||||
const checkMemoryCreate = generateCheckAccess(PermissionTypes.MEMORIES, [
|
||||
Permissions.USE,
|
||||
Permissions.CREATE,
|
||||
]);
|
||||
const checkMemoryUpdate = generateCheckAccess(PermissionTypes.MEMORIES, [
|
||||
Permissions.USE,
|
||||
Permissions.UPDATE,
|
||||
]);
|
||||
const checkMemoryDelete = generateCheckAccess(PermissionTypes.MEMORIES, [
|
||||
Permissions.USE,
|
||||
Permissions.UPDATE,
|
||||
]);
|
||||
const checkMemoryOptOut = generateCheckAccess(PermissionTypes.MEMORIES, [
|
||||
Permissions.USE,
|
||||
Permissions.OPT_OUT,
|
||||
]);
|
||||
|
||||
router.use(requireJwtAuth);
|
||||
|
||||
/**
|
||||
* GET /memories
|
||||
* Returns all memories for the authenticated user, sorted by updated_at (newest first).
|
||||
* Also includes memory usage percentage based on token limit.
|
||||
*/
|
||||
router.get('/', checkMemoryRead, async (req, res) => {
|
||||
try {
|
||||
const memories = await getAllUserMemories(req.user.id);
|
||||
|
||||
const sortedMemories = memories.sort(
|
||||
(a, b) => new Date(b.updated_at).getTime() - new Date(a.updated_at).getTime(),
|
||||
);
|
||||
|
||||
const totalTokens = memories.reduce((sum, memory) => {
|
||||
return sum + (memory.tokenCount || 0);
|
||||
}, 0);
|
||||
|
||||
const memoryConfig = req.app.locals?.memory;
|
||||
const tokenLimit = memoryConfig?.tokenLimit;
|
||||
|
||||
let usagePercentage = null;
|
||||
if (tokenLimit && tokenLimit > 0) {
|
||||
usagePercentage = Math.min(100, Math.round((totalTokens / tokenLimit) * 100));
|
||||
}
|
||||
|
||||
res.json({
|
||||
memories: sortedMemories,
|
||||
totalTokens,
|
||||
tokenLimit: tokenLimit || null,
|
||||
usagePercentage,
|
||||
});
|
||||
} catch (error) {
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* POST /memories
|
||||
* Creates a new memory entry for the authenticated user.
|
||||
* Body: { key: string, value: string }
|
||||
* Returns 201 and { created: true, memory: <createdDoc> } when successful.
|
||||
*/
|
||||
router.post('/', checkMemoryCreate, async (req, res) => {
|
||||
const { key, value } = req.body;
|
||||
|
||||
if (typeof key !== 'string' || key.trim() === '') {
|
||||
return res.status(400).json({ error: 'Key is required and must be a non-empty string.' });
|
||||
}
|
||||
|
||||
if (typeof value !== 'string' || value.trim() === '') {
|
||||
return res.status(400).json({ error: 'Value is required and must be a non-empty string.' });
|
||||
}
|
||||
|
||||
try {
|
||||
const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');
|
||||
|
||||
const memories = await getAllUserMemories(req.user.id);
|
||||
|
||||
// Check token limit
|
||||
const memoryConfig = req.app.locals?.memory;
|
||||
const tokenLimit = memoryConfig?.tokenLimit;
|
||||
|
||||
if (tokenLimit) {
|
||||
const currentTotalTokens = memories.reduce(
|
||||
(sum, memory) => sum + (memory.tokenCount || 0),
|
||||
0,
|
||||
);
|
||||
if (currentTotalTokens + tokenCount > tokenLimit) {
|
||||
return res.status(400).json({
|
||||
error: `Adding this memory would exceed the token limit of ${tokenLimit}. Current usage: ${currentTotalTokens} tokens.`,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const result = await createMemory({
|
||||
userId: req.user.id,
|
||||
key: key.trim(),
|
||||
value: value.trim(),
|
||||
tokenCount,
|
||||
});
|
||||
|
||||
if (!result.ok) {
|
||||
return res.status(500).json({ error: 'Failed to create memory.' });
|
||||
}
|
||||
|
||||
const updatedMemories = await getAllUserMemories(req.user.id);
|
||||
const newMemory = updatedMemories.find((m) => m.key === key.trim());
|
||||
|
||||
res.status(201).json({ created: true, memory: newMemory });
|
||||
} catch (error) {
|
||||
if (error.message && error.message.includes('already exists')) {
|
||||
return res.status(409).json({ error: 'Memory with this key already exists.' });
|
||||
}
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* PATCH /memories/preferences
|
||||
* Updates the user's memory preferences (e.g., enabling/disabling memories).
|
||||
* Body: { memories: boolean }
|
||||
* Returns 200 and { updated: true, preferences: { memories: boolean } } when successful.
|
||||
*/
|
||||
router.patch('/preferences', checkMemoryOptOut, async (req, res) => {
|
||||
const { memories } = req.body;
|
||||
|
||||
if (typeof memories !== 'boolean') {
|
||||
return res.status(400).json({ error: 'memories must be a boolean value.' });
|
||||
}
|
||||
|
||||
try {
|
||||
const updatedUser = await toggleUserMemories(req.user.id, memories);
|
||||
|
||||
if (!updatedUser) {
|
||||
return res.status(404).json({ error: 'User not found.' });
|
||||
}
|
||||
|
||||
res.json({
|
||||
updated: true,
|
||||
preferences: {
|
||||
memories: updatedUser.personalization?.memories ?? true,
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* PATCH /memories/:key
|
||||
* Updates the value of an existing memory entry for the authenticated user.
|
||||
* Body: { value: string }
|
||||
* Returns 200 and { updated: true, memory: <updatedDoc> } when successful.
|
||||
*/
|
||||
router.patch('/:key', checkMemoryUpdate, async (req, res) => {
|
||||
const { key } = req.params;
|
||||
const { value } = req.body || {};
|
||||
|
||||
if (typeof value !== 'string' || value.trim() === '') {
|
||||
return res.status(400).json({ error: 'Value is required and must be a non-empty string.' });
|
||||
}
|
||||
|
||||
try {
|
||||
const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');
|
||||
|
||||
const memories = await getAllUserMemories(req.user.id);
|
||||
const existingMemory = memories.find((m) => m.key === key);
|
||||
|
||||
if (!existingMemory) {
|
||||
return res.status(404).json({ error: 'Memory not found.' });
|
||||
}
|
||||
|
||||
const result = await setMemory({
|
||||
userId: req.user.id,
|
||||
key,
|
||||
value,
|
||||
tokenCount,
|
||||
});
|
||||
|
||||
if (!result.ok) {
|
||||
return res.status(500).json({ error: 'Failed to update memory.' });
|
||||
}
|
||||
|
||||
const updatedMemories = await getAllUserMemories(req.user.id);
|
||||
const updatedMemory = updatedMemories.find((m) => m.key === key);
|
||||
|
||||
res.json({ updated: true, memory: updatedMemory });
|
||||
} catch (error) {
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* DELETE /memories/:key
|
||||
* Deletes a memory entry for the authenticated user.
|
||||
* Returns 200 and { deleted: true } when successful.
|
||||
*/
|
||||
router.delete('/:key', checkMemoryDelete, async (req, res) => {
|
||||
const { key } = req.params;
|
||||
|
||||
try {
|
||||
const result = await deleteMemory({ userId: req.user.id, key });
|
||||
|
||||
if (!result.ok) {
|
||||
return res.status(404).json({ error: 'Memory not found.' });
|
||||
}
|
||||
|
||||
res.json({ deleted: true });
|
||||
} catch (error) {
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
const express = require('express');
|
||||
const {
|
||||
promptPermissionsSchema,
|
||||
memoryPermissionsSchema,
|
||||
agentPermissionsSchema,
|
||||
PermissionTypes,
|
||||
roleDefaults,
|
||||
|
|
@ -118,4 +119,43 @@ router.put('/:roleName/agents', checkAdmin, async (req, res) => {
|
|||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* PUT /api/roles/:roleName/memories
|
||||
* Update memory permissions for a specific role
|
||||
*/
|
||||
router.put('/:roleName/memories', checkAdmin, async (req, res) => {
|
||||
const { roleName: _r } = req.params;
|
||||
// TODO: TEMP, use a better parsing for roleName
|
||||
const roleName = _r.toUpperCase();
|
||||
/** @type {TRole['permissions']['MEMORIES']} */
|
||||
const updates = req.body;
|
||||
|
||||
try {
|
||||
const parsedUpdates = memoryPermissionsSchema.partial().parse(updates);
|
||||
|
||||
const role = await getRoleByName(roleName);
|
||||
if (!role) {
|
||||
return res.status(404).send({ message: 'Role not found' });
|
||||
}
|
||||
|
||||
const currentPermissions =
|
||||
role.permissions?.[PermissionTypes.MEMORIES] || role[PermissionTypes.MEMORIES] || {};
|
||||
|
||||
const mergedUpdates = {
|
||||
permissions: {
|
||||
...role.permissions,
|
||||
[PermissionTypes.MEMORIES]: {
|
||||
...currentPermissions,
|
||||
...parsedUpdates,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const updatedRole = await updateRoleByName(roleName, mergedUpdates);
|
||||
res.status(200).send(updatedRole);
|
||||
} catch (error) {
|
||||
return res.status(400).send({ message: 'Invalid memory permissions.', error: error.errors });
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
const jwt = require('jsonwebtoken');
|
||||
const { nanoid } = require('nanoid');
|
||||
const { tool } = require('@langchain/core/tools');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { GraphEvents, sleep } = require('@librechat/agents');
|
||||
const { sendEvent, logAxiosError } = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
CacheKeys,
|
||||
|
|
@ -13,11 +15,10 @@ const {
|
|||
actionDomainSeparator,
|
||||
} = require('librechat-data-provider');
|
||||
const { refreshAccessToken } = require('~/server/services/TokenService');
|
||||
const { logger, getFlowStateManager, sendEvent } = require('~/config');
|
||||
const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
|
||||
const { getActions, deleteActions } = require('~/models/Action');
|
||||
const { deleteAssistant } = require('~/models/Assistant');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { getFlowStateManager } = require('~/config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { findToken } = require('~/models');
|
||||
|
||||
|
|
@ -208,6 +209,7 @@ async function createActionTool({
|
|||
userId: userId,
|
||||
client_url: metadata.auth.client_url,
|
||||
redirect_uri: `${process.env.DOMAIN_SERVER}/api/actions/${action_id}/oauth/callback`,
|
||||
token_exchange_method: metadata.auth.token_exchange_method,
|
||||
/** Encrypted values */
|
||||
encrypted_oauth_client_id: encrypted.oauth_client_id,
|
||||
encrypted_oauth_client_secret: encrypted.oauth_client_secret,
|
||||
|
|
@ -262,6 +264,7 @@ async function createActionTool({
|
|||
refresh_token,
|
||||
client_url: metadata.auth.client_url,
|
||||
encrypted_oauth_client_id: encrypted.oauth_client_id,
|
||||
token_exchange_method: metadata.auth.token_exchange_method,
|
||||
encrypted_oauth_client_secret: encrypted.oauth_client_secret,
|
||||
});
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@ const {
|
|||
loadOCRConfig,
|
||||
processMCPEnv,
|
||||
EModelEndpoint,
|
||||
loadMemoryConfig,
|
||||
getConfigDefaults,
|
||||
loadWebSearchConfig,
|
||||
} = require('librechat-data-provider');
|
||||
const { agentsConfigSetup } = require('@librechat/api');
|
||||
const {
|
||||
checkHealth,
|
||||
checkConfig,
|
||||
|
|
@ -24,7 +26,6 @@ const { azureConfigSetup } = require('./start/azureOpenAI');
|
|||
const { processModelSpecs } = require('./start/modelSpecs');
|
||||
const { initializeS3 } = require('./Files/S3/initialize');
|
||||
const { loadAndFormatTools } = require('./ToolService');
|
||||
const { agentsConfigSetup } = require('./start/agents');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { initializeRoles } = require('~/models');
|
||||
const { getMCPManager } = require('~/config');
|
||||
|
|
@ -44,6 +45,7 @@ const AppService = async (app) => {
|
|||
const ocr = loadOCRConfig(config.ocr);
|
||||
const webSearch = loadWebSearchConfig(config.webSearch);
|
||||
checkWebSearchConfig(webSearch);
|
||||
const memory = loadMemoryConfig(config.memory);
|
||||
const filteredTools = config.filteredTools;
|
||||
const includedTools = config.includedTools;
|
||||
const fileStrategy = config.fileStrategy ?? configDefaults.fileStrategy;
|
||||
|
|
@ -88,6 +90,7 @@ const AppService = async (app) => {
|
|||
const defaultLocals = {
|
||||
ocr,
|
||||
paths,
|
||||
memory,
|
||||
webSearch,
|
||||
fileStrategy,
|
||||
socialLogins,
|
||||
|
|
@ -100,8 +103,13 @@ const AppService = async (app) => {
|
|||
balance,
|
||||
};
|
||||
|
||||
const agentsDefaults = agentsConfigSetup(config);
|
||||
|
||||
if (!Object.keys(config).length) {
|
||||
app.locals = defaultLocals;
|
||||
app.locals = {
|
||||
...defaultLocals,
|
||||
[EModelEndpoint.agents]: agentsDefaults,
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -136,9 +144,7 @@ const AppService = async (app) => {
|
|||
);
|
||||
}
|
||||
|
||||
if (endpoints?.[EModelEndpoint.agents]) {
|
||||
endpointLocals[EModelEndpoint.agents] = agentsConfigSetup(config);
|
||||
}
|
||||
endpointLocals[EModelEndpoint.agents] = agentsConfigSetup(config, agentsDefaults);
|
||||
|
||||
const endpointKeys = [
|
||||
EModelEndpoint.openAI,
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@ const {
|
|||
FileSources,
|
||||
EModelEndpoint,
|
||||
EImageOutputType,
|
||||
AgentCapabilities,
|
||||
defaultSocialLogins,
|
||||
validateAzureGroups,
|
||||
defaultAgentCapabilities,
|
||||
deprecatedAzureVariables,
|
||||
conflictingAzureVariables,
|
||||
} = require('librechat-data-provider');
|
||||
|
|
@ -151,6 +153,11 @@ describe('AppService', () => {
|
|||
safeSearch: 1,
|
||||
serperApiKey: '${SERPER_API_KEY}',
|
||||
},
|
||||
memory: undefined,
|
||||
agents: {
|
||||
disableBuilder: false,
|
||||
capabilities: expect.arrayContaining([...defaultAgentCapabilities]),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -268,6 +275,71 @@ describe('AppService', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should correctly configure Agents endpoint based on custom config', async () => {
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
endpoints: {
|
||||
[EModelEndpoint.agents]: {
|
||||
disableBuilder: true,
|
||||
recursionLimit: 10,
|
||||
maxRecursionLimit: 20,
|
||||
allowedProviders: ['openai', 'anthropic'],
|
||||
capabilities: [AgentCapabilities.tools, AgentCapabilities.actions],
|
||||
},
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
await AppService(app);
|
||||
|
||||
expect(app.locals).toHaveProperty(EModelEndpoint.agents);
|
||||
expect(app.locals[EModelEndpoint.agents]).toEqual(
|
||||
expect.objectContaining({
|
||||
disableBuilder: true,
|
||||
recursionLimit: 10,
|
||||
maxRecursionLimit: 20,
|
||||
allowedProviders: expect.arrayContaining(['openai', 'anthropic']),
|
||||
capabilities: expect.arrayContaining([AgentCapabilities.tools, AgentCapabilities.actions]),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should configure Agents endpoint with defaults when no config is provided', async () => {
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({}));
|
||||
|
||||
await AppService(app);
|
||||
|
||||
expect(app.locals).toHaveProperty(EModelEndpoint.agents);
|
||||
expect(app.locals[EModelEndpoint.agents]).toEqual(
|
||||
expect.objectContaining({
|
||||
disableBuilder: false,
|
||||
capabilities: expect.arrayContaining([...defaultAgentCapabilities]),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should configure Agents endpoint with defaults when endpoints exist but agents is not defined', async () => {
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
endpoints: {
|
||||
[EModelEndpoint.openAI]: {
|
||||
titleConvo: true,
|
||||
},
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
await AppService(app);
|
||||
|
||||
expect(app.locals).toHaveProperty(EModelEndpoint.agents);
|
||||
expect(app.locals[EModelEndpoint.agents]).toEqual(
|
||||
expect.objectContaining({
|
||||
disableBuilder: false,
|
||||
capabilities: expect.arrayContaining([...defaultAgentCapabilities]),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should correctly configure minimum Azure OpenAI Assistant values', async () => {
|
||||
const assistantGroups = [azureGroups[0], { ...azureGroups[1], assistants: true }];
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
|
||||
|
|
|
|||
196
api/server/services/Endpoints/agents/agent.js
Normal file
196
api/server/services/Endpoints/agents/agent.js
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
const { Providers } = require('@librechat/agents');
|
||||
const { primeResources, optionalChainWithEmptyCheck } = require('@librechat/api');
|
||||
const {
|
||||
ErrorTypes,
|
||||
EModelEndpoint,
|
||||
EToolResources,
|
||||
replaceSpecialVars,
|
||||
providerEndpointMap,
|
||||
} = require('librechat-data-provider');
|
||||
const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize');
|
||||
const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options');
|
||||
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
|
||||
const initCustom = require('~/server/services/Endpoints/custom/initialize');
|
||||
const initGoogle = require('~/server/services/Endpoints/google/initialize');
|
||||
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
|
||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||
const { processFiles } = require('~/server/services/Files/process');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { getToolFilesByIds } = require('~/models/File');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getFiles } = require('~/models/File');
|
||||
|
||||
const providerConfigMap = {
|
||||
[Providers.XAI]: initCustom,
|
||||
[Providers.OLLAMA]: initCustom,
|
||||
[Providers.DEEPSEEK]: initCustom,
|
||||
[Providers.OPENROUTER]: initCustom,
|
||||
[EModelEndpoint.openAI]: initOpenAI,
|
||||
[EModelEndpoint.google]: initGoogle,
|
||||
[EModelEndpoint.azureOpenAI]: initOpenAI,
|
||||
[EModelEndpoint.anthropic]: initAnthropic,
|
||||
[EModelEndpoint.bedrock]: getBedrockOptions,
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {ServerResponse} params.res
|
||||
* @param {Agent} params.agent
|
||||
* @param {string | null} [params.conversationId]
|
||||
* @param {Array<IMongoFile>} [params.requestFiles]
|
||||
* @param {typeof import('~/server/services/ToolService').loadAgentTools | undefined} [params.loadTools]
|
||||
* @param {TEndpointOption} [params.endpointOption]
|
||||
* @param {Set<string>} [params.allowedProviders]
|
||||
* @param {boolean} [params.isInitialAgent]
|
||||
* @returns {Promise<Agent & { tools: StructuredTool[], attachments: Array<MongoFile>, toolContextMap: Record<string, unknown>, maxContextTokens: number }>}
|
||||
*/
|
||||
const initializeAgent = async ({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
isInitialAgent = false,
|
||||
}) => {
|
||||
if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) {
|
||||
throw new Error(
|
||||
`{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`,
|
||||
);
|
||||
}
|
||||
let currentFiles;
|
||||
|
||||
if (
|
||||
isInitialAgent &&
|
||||
conversationId != null &&
|
||||
(agent.model_parameters?.resendFiles ?? true) === true
|
||||
) {
|
||||
const fileIds = (await getConvoFiles(conversationId)) ?? [];
|
||||
/** @type {Set<EToolResources>} */
|
||||
const toolResourceSet = new Set();
|
||||
for (const tool of agent.tools) {
|
||||
if (EToolResources[tool]) {
|
||||
toolResourceSet.add(EToolResources[tool]);
|
||||
}
|
||||
}
|
||||
const toolFiles = await getToolFilesByIds(fileIds, toolResourceSet);
|
||||
if (requestFiles.length || toolFiles.length) {
|
||||
currentFiles = await processFiles(requestFiles.concat(toolFiles));
|
||||
}
|
||||
} else if (isInitialAgent && requestFiles.length) {
|
||||
currentFiles = await processFiles(requestFiles);
|
||||
}
|
||||
|
||||
const { attachments, tool_resources } = await primeResources({
|
||||
req,
|
||||
getFiles,
|
||||
attachments: currentFiles,
|
||||
tool_resources: agent.tool_resources,
|
||||
requestFileSet: new Set(requestFiles?.map((file) => file.file_id)),
|
||||
});
|
||||
|
||||
const provider = agent.provider;
|
||||
const { tools, toolContextMap } =
|
||||
(await loadTools?.({
|
||||
req,
|
||||
res,
|
||||
provider,
|
||||
agentId: agent.id,
|
||||
tools: agent.tools,
|
||||
model: agent.model,
|
||||
tool_resources,
|
||||
})) ?? {};
|
||||
|
||||
agent.endpoint = provider;
|
||||
let getOptions = providerConfigMap[provider];
|
||||
if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) {
|
||||
agent.provider = provider.toLowerCase();
|
||||
getOptions = providerConfigMap[agent.provider];
|
||||
} else if (!getOptions) {
|
||||
const customEndpointConfig = await getCustomEndpointConfig(provider);
|
||||
if (!customEndpointConfig) {
|
||||
throw new Error(`Provider ${provider} not supported`);
|
||||
}
|
||||
getOptions = initCustom;
|
||||
agent.provider = Providers.OPENAI;
|
||||
}
|
||||
const model_parameters = Object.assign(
|
||||
{},
|
||||
agent.model_parameters ?? { model: agent.model },
|
||||
isInitialAgent === true ? endpointOption?.model_parameters : {},
|
||||
);
|
||||
const _endpointOption =
|
||||
isInitialAgent === true
|
||||
? Object.assign({}, endpointOption, { model_parameters })
|
||||
: { model_parameters };
|
||||
|
||||
const options = await getOptions({
|
||||
req,
|
||||
res,
|
||||
optionsOnly: true,
|
||||
overrideEndpoint: provider,
|
||||
overrideModel: agent.model,
|
||||
endpointOption: _endpointOption,
|
||||
});
|
||||
|
||||
if (
|
||||
agent.endpoint === EModelEndpoint.azureOpenAI &&
|
||||
options.llmConfig?.azureOpenAIApiInstanceName == null
|
||||
) {
|
||||
agent.provider = Providers.OPENAI;
|
||||
}
|
||||
|
||||
if (options.provider != null) {
|
||||
agent.provider = options.provider;
|
||||
}
|
||||
|
||||
/** @type {import('@librechat/agents').ClientOptions} */
|
||||
agent.model_parameters = Object.assign(model_parameters, options.llmConfig);
|
||||
if (options.configOptions) {
|
||||
agent.model_parameters.configuration = options.configOptions;
|
||||
}
|
||||
|
||||
if (!agent.model_parameters.model) {
|
||||
agent.model_parameters.model = agent.model;
|
||||
}
|
||||
|
||||
if (agent.instructions && agent.instructions !== '') {
|
||||
agent.instructions = replaceSpecialVars({
|
||||
text: agent.instructions,
|
||||
user: req.user,
|
||||
});
|
||||
}
|
||||
|
||||
if (typeof agent.artifacts === 'string' && agent.artifacts !== '') {
|
||||
agent.additional_instructions = generateArtifactsPrompt({
|
||||
endpoint: agent.provider,
|
||||
artifacts: agent.artifacts,
|
||||
});
|
||||
}
|
||||
|
||||
const tokensModel =
|
||||
agent.provider === EModelEndpoint.azureOpenAI ? agent.model : agent.model_parameters.model;
|
||||
const maxTokens = optionalChainWithEmptyCheck(
|
||||
agent.model_parameters.maxOutputTokens,
|
||||
agent.model_parameters.maxTokens,
|
||||
0,
|
||||
);
|
||||
const maxContextTokens = optionalChainWithEmptyCheck(
|
||||
agent.model_parameters.maxContextTokens,
|
||||
agent.max_context_tokens,
|
||||
getModelMaxTokens(tokensModel, providerEndpointMap[provider]),
|
||||
4096,
|
||||
);
|
||||
return {
|
||||
...agent,
|
||||
tools,
|
||||
attachments,
|
||||
toolContextMap,
|
||||
maxContextTokens: (maxContextTokens - maxTokens) * 0.9,
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = { initializeAgent };
|
||||
|
|
@ -1,294 +1,41 @@
|
|||
const { createContentAggregator, Providers } = require('@librechat/agents');
|
||||
const {
|
||||
Constants,
|
||||
ErrorTypes,
|
||||
EModelEndpoint,
|
||||
EToolResources,
|
||||
getResponseSender,
|
||||
AgentCapabilities,
|
||||
replaceSpecialVars,
|
||||
providerEndpointMap,
|
||||
} = require('librechat-data-provider');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createContentAggregator } = require('@librechat/agents');
|
||||
const { Constants, EModelEndpoint, getResponseSender } = require('librechat-data-provider');
|
||||
const {
|
||||
getDefaultHandlers,
|
||||
createToolEndCallback,
|
||||
} = require('~/server/controllers/agents/callbacks');
|
||||
const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize');
|
||||
const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options');
|
||||
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
|
||||
const initCustom = require('~/server/services/Endpoints/custom/initialize');
|
||||
const initGoogle = require('~/server/services/Endpoints/google/initialize');
|
||||
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
|
||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||
const { processFiles } = require('~/server/services/Files/process');
|
||||
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
||||
const { loadAgentTools } = require('~/server/services/ToolService');
|
||||
const AgentClient = require('~/server/controllers/agents/client');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { getToolFilesByIds } = require('~/models/File');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const providerConfigMap = {
|
||||
[Providers.XAI]: initCustom,
|
||||
[Providers.OLLAMA]: initCustom,
|
||||
[Providers.DEEPSEEK]: initCustom,
|
||||
[Providers.OPENROUTER]: initCustom,
|
||||
[EModelEndpoint.openAI]: initOpenAI,
|
||||
[EModelEndpoint.google]: initGoogle,
|
||||
[EModelEndpoint.azureOpenAI]: initOpenAI,
|
||||
[EModelEndpoint.anthropic]: initAnthropic,
|
||||
[EModelEndpoint.bedrock]: getBedrockOptions,
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {Object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {Promise<Array<MongoFile | null>> | undefined} [params.attachments]
|
||||
* @param {Set<string>} params.requestFileSet
|
||||
* @param {AgentToolResources | undefined} [params.tool_resources]
|
||||
* @returns {Promise<{ attachments: Array<MongoFile | undefined> | undefined, tool_resources: AgentToolResources | undefined }>}
|
||||
*/
|
||||
const primeResources = async ({
|
||||
req,
|
||||
attachments: _attachments,
|
||||
tool_resources: _tool_resources,
|
||||
requestFileSet,
|
||||
}) => {
|
||||
try {
|
||||
/** @type {Array<MongoFile | undefined> | undefined} */
|
||||
let attachments;
|
||||
const tool_resources = _tool_resources ?? {};
|
||||
const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes(
|
||||
AgentCapabilities.ocr,
|
||||
);
|
||||
if (tool_resources[EToolResources.ocr]?.file_ids && isOCREnabled) {
|
||||
const context = await getFiles(
|
||||
{
|
||||
file_id: { $in: tool_resources.ocr.file_ids },
|
||||
},
|
||||
{},
|
||||
{},
|
||||
);
|
||||
attachments = (attachments ?? []).concat(context);
|
||||
function createToolLoader() {
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {ServerResponse} params.res
|
||||
* @param {string} params.agentId
|
||||
* @param {string[]} params.tools
|
||||
* @param {string} params.provider
|
||||
* @param {string} params.model
|
||||
* @param {AgentToolResources} params.tool_resources
|
||||
* @returns {Promise<{ tools: StructuredTool[], toolContextMap: Record<string, unknown> } | undefined>}
|
||||
*/
|
||||
return async function loadTools({ req, res, agentId, tools, provider, model, tool_resources }) {
|
||||
const agent = { id: agentId, tools, provider, model };
|
||||
try {
|
||||
return await loadAgentTools({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
tool_resources,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error loading tools for agent ' + agentId, error);
|
||||
}
|
||||
if (!_attachments) {
|
||||
return { attachments, tool_resources };
|
||||
}
|
||||
/** @type {Array<MongoFile | undefined> | undefined} */
|
||||
const files = await _attachments;
|
||||
if (!attachments) {
|
||||
/** @type {Array<MongoFile | undefined>} */
|
||||
attachments = [];
|
||||
}
|
||||
|
||||
for (const file of files) {
|
||||
if (!file) {
|
||||
continue;
|
||||
}
|
||||
if (file.metadata?.fileIdentifier) {
|
||||
const execute_code = tool_resources[EToolResources.execute_code] ?? {};
|
||||
if (!execute_code.files) {
|
||||
tool_resources[EToolResources.execute_code] = { ...execute_code, files: [] };
|
||||
}
|
||||
tool_resources[EToolResources.execute_code].files.push(file);
|
||||
} else if (file.embedded === true) {
|
||||
const file_search = tool_resources[EToolResources.file_search] ?? {};
|
||||
if (!file_search.files) {
|
||||
tool_resources[EToolResources.file_search] = { ...file_search, files: [] };
|
||||
}
|
||||
tool_resources[EToolResources.file_search].files.push(file);
|
||||
} else if (
|
||||
requestFileSet.has(file.file_id) &&
|
||||
file.type.startsWith('image') &&
|
||||
file.height &&
|
||||
file.width
|
||||
) {
|
||||
const image_edit = tool_resources[EToolResources.image_edit] ?? {};
|
||||
if (!image_edit.files) {
|
||||
tool_resources[EToolResources.image_edit] = { ...image_edit, files: [] };
|
||||
}
|
||||
tool_resources[EToolResources.image_edit].files.push(file);
|
||||
}
|
||||
|
||||
attachments.push(file);
|
||||
}
|
||||
return { attachments, tool_resources };
|
||||
} catch (error) {
|
||||
logger.error('Error priming resources', error);
|
||||
return { attachments: _attachments, tool_resources: _tool_resources };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {...string | number} values
|
||||
* @returns {string | number | undefined}
|
||||
*/
|
||||
function optionalChainWithEmptyCheck(...values) {
|
||||
for (const value of values) {
|
||||
if (value !== undefined && value !== null && value !== '') {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return values[values.length - 1];
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {ServerResponse} params.res
|
||||
* @param {Agent} params.agent
|
||||
* @param {Set<string>} [params.allowedProviders]
|
||||
* @param {object} [params.endpointOption]
|
||||
* @param {boolean} [params.isInitialAgent]
|
||||
* @returns {Promise<Agent>}
|
||||
*/
|
||||
const initializeAgentOptions = async ({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
isInitialAgent = false,
|
||||
}) => {
|
||||
if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) {
|
||||
throw new Error(
|
||||
`{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`,
|
||||
);
|
||||
}
|
||||
let currentFiles;
|
||||
/** @type {Array<MongoFile>} */
|
||||
const requestFiles = req.body.files ?? [];
|
||||
if (
|
||||
isInitialAgent &&
|
||||
req.body.conversationId != null &&
|
||||
(agent.model_parameters?.resendFiles ?? true) === true
|
||||
) {
|
||||
const fileIds = (await getConvoFiles(req.body.conversationId)) ?? [];
|
||||
/** @type {Set<EToolResources>} */
|
||||
const toolResourceSet = new Set();
|
||||
for (const tool of agent.tools) {
|
||||
if (EToolResources[tool]) {
|
||||
toolResourceSet.add(EToolResources[tool]);
|
||||
}
|
||||
}
|
||||
const toolFiles = await getToolFilesByIds(fileIds, toolResourceSet);
|
||||
if (requestFiles.length || toolFiles.length) {
|
||||
currentFiles = await processFiles(requestFiles.concat(toolFiles));
|
||||
}
|
||||
} else if (isInitialAgent && requestFiles.length) {
|
||||
currentFiles = await processFiles(requestFiles);
|
||||
}
|
||||
|
||||
const { attachments, tool_resources } = await primeResources({
|
||||
req,
|
||||
attachments: currentFiles,
|
||||
tool_resources: agent.tool_resources,
|
||||
requestFileSet: new Set(requestFiles.map((file) => file.file_id)),
|
||||
});
|
||||
|
||||
const provider = agent.provider;
|
||||
const { tools, toolContextMap } = await loadAgentTools({
|
||||
req,
|
||||
res,
|
||||
agent: {
|
||||
id: agent.id,
|
||||
tools: agent.tools,
|
||||
provider,
|
||||
model: agent.model,
|
||||
},
|
||||
tool_resources,
|
||||
});
|
||||
|
||||
agent.endpoint = provider;
|
||||
let getOptions = providerConfigMap[provider];
|
||||
if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) {
|
||||
agent.provider = provider.toLowerCase();
|
||||
getOptions = providerConfigMap[agent.provider];
|
||||
} else if (!getOptions) {
|
||||
const customEndpointConfig = await getCustomEndpointConfig(provider);
|
||||
if (!customEndpointConfig) {
|
||||
throw new Error(`Provider ${provider} not supported`);
|
||||
}
|
||||
getOptions = initCustom;
|
||||
agent.provider = Providers.OPENAI;
|
||||
}
|
||||
const model_parameters = Object.assign(
|
||||
{},
|
||||
agent.model_parameters ?? { model: agent.model },
|
||||
isInitialAgent === true ? endpointOption?.model_parameters : {},
|
||||
);
|
||||
const _endpointOption =
|
||||
isInitialAgent === true
|
||||
? Object.assign({}, endpointOption, { model_parameters })
|
||||
: { model_parameters };
|
||||
|
||||
const options = await getOptions({
|
||||
req,
|
||||
res,
|
||||
optionsOnly: true,
|
||||
overrideEndpoint: provider,
|
||||
overrideModel: agent.model,
|
||||
endpointOption: _endpointOption,
|
||||
});
|
||||
|
||||
if (
|
||||
agent.endpoint === EModelEndpoint.azureOpenAI &&
|
||||
options.llmConfig?.azureOpenAIApiInstanceName == null
|
||||
) {
|
||||
agent.provider = Providers.OPENAI;
|
||||
}
|
||||
|
||||
if (options.provider != null) {
|
||||
agent.provider = options.provider;
|
||||
}
|
||||
|
||||
/** @type {import('@librechat/agents').ClientOptions} */
|
||||
agent.model_parameters = Object.assign(model_parameters, options.llmConfig);
|
||||
if (options.configOptions) {
|
||||
agent.model_parameters.configuration = options.configOptions;
|
||||
}
|
||||
|
||||
if (!agent.model_parameters.model) {
|
||||
agent.model_parameters.model = agent.model;
|
||||
}
|
||||
|
||||
if (agent.instructions && agent.instructions !== '') {
|
||||
agent.instructions = replaceSpecialVars({
|
||||
text: agent.instructions,
|
||||
user: req.user,
|
||||
});
|
||||
}
|
||||
|
||||
if (typeof agent.artifacts === 'string' && agent.artifacts !== '') {
|
||||
agent.additional_instructions = generateArtifactsPrompt({
|
||||
endpoint: agent.provider,
|
||||
artifacts: agent.artifacts,
|
||||
});
|
||||
}
|
||||
|
||||
const tokensModel =
|
||||
agent.provider === EModelEndpoint.azureOpenAI ? agent.model : agent.model_parameters.model;
|
||||
const maxTokens = optionalChainWithEmptyCheck(
|
||||
agent.model_parameters.maxOutputTokens,
|
||||
agent.model_parameters.maxTokens,
|
||||
0,
|
||||
);
|
||||
const maxContextTokens = optionalChainWithEmptyCheck(
|
||||
agent.model_parameters.maxContextTokens,
|
||||
agent.max_context_tokens,
|
||||
getModelMaxTokens(tokensModel, providerEndpointMap[provider]),
|
||||
4096,
|
||||
);
|
||||
return {
|
||||
...agent,
|
||||
tools,
|
||||
attachments,
|
||||
toolContextMap,
|
||||
maxContextTokens: (maxContextTokens - maxTokens) * 0.9,
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
if (!endpointOption) {
|
||||
|
|
@ -313,7 +60,6 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
throw new Error('No agent promise provided');
|
||||
}
|
||||
|
||||
// Initialize primary agent
|
||||
const primaryAgent = await endpointOption.agent;
|
||||
if (!primaryAgent) {
|
||||
throw new Error('Agent not found');
|
||||
|
|
@ -323,10 +69,18 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
/** @type {Set<string>} */
|
||||
const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders);
|
||||
|
||||
// Handle primary agent
|
||||
const primaryConfig = await initializeAgentOptions({
|
||||
const loadTools = createToolLoader();
|
||||
/** @type {Array<MongoFile>} */
|
||||
const requestFiles = req.body.files ?? [];
|
||||
/** @type {string} */
|
||||
const conversationId = req.body.conversationId;
|
||||
|
||||
const primaryConfig = await initializeAgent({
|
||||
req,
|
||||
res,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
agent: primaryAgent,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
|
|
@ -340,10 +94,13 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
if (!agent) {
|
||||
throw new Error(`Agent ${agentId} not found`);
|
||||
}
|
||||
const config = await initializeAgentOptions({
|
||||
const config = await initializeAgent({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
const OpenAI = require('openai');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { constructAzureURL, isUserProvided } = require('@librechat/api');
|
||||
const {
|
||||
ErrorTypes,
|
||||
EModelEndpoint,
|
||||
|
|
@ -12,8 +13,6 @@ const {
|
|||
checkUserKeyExpiry,
|
||||
} = require('~/server/services/UserService');
|
||||
const OpenAIClient = require('~/app/clients/OpenAIClient');
|
||||
const { isUserProvided } = require('~/server/utils');
|
||||
const { constructAzureURL } = require('~/utils');
|
||||
|
||||
class Files {
|
||||
constructor(client) {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { createHandleLLMNewToken } = require('@librechat/api');
|
||||
const {
|
||||
AuthType,
|
||||
Constants,
|
||||
|
|
@ -8,7 +9,6 @@ const {
|
|||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
const { createHandleLLMNewToken } = require('~/app/clients/generators');
|
||||
|
||||
const getOptions = async ({ req, overrideModel, endpointOption }) => {
|
||||
const {
|
||||
|
|
|
|||
|
|
@ -6,10 +6,9 @@ const {
|
|||
extractEnvVariable,
|
||||
} = require('librechat-data-provider');
|
||||
const { Providers } = require('@librechat/agents');
|
||||
const { getOpenAIConfig, createHandleLLMNewToken } = require('@librechat/api');
|
||||
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm');
|
||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||
const { createHandleLLMNewToken } = require('~/app/clients/generators');
|
||||
const { fetchModels } = require('~/server/services/ModelService');
|
||||
const OpenAIClient = require('~/app/clients/OpenAIClient');
|
||||
const { isUserProvided } = require('~/server/utils');
|
||||
|
|
@ -144,7 +143,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
|
|||
clientOptions,
|
||||
);
|
||||
clientOptions.modelOptions.user = req.user.id;
|
||||
const options = getLLMConfig(apiKey, clientOptions, endpoint);
|
||||
const options = getOpenAIConfig(apiKey, clientOptions, endpoint);
|
||||
if (!customOptions.streamRate) {
|
||||
return options;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,9 +25,9 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
|
|||
const credentials = isUserProvided
|
||||
? userKey
|
||||
: {
|
||||
[AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey,
|
||||
[AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
|
||||
};
|
||||
[AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey,
|
||||
[AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
|
||||
};
|
||||
|
||||
let clientOptions = {};
|
||||
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ function getLLMConfig(credentials, options = {}) {
|
|||
// Extract from credentials
|
||||
const serviceKeyRaw = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {};
|
||||
const serviceKey =
|
||||
typeof serviceKeyRaw === 'string' ? JSON.parse(serviceKeyRaw) : serviceKeyRaw ?? {};
|
||||
typeof serviceKeyRaw === 'string' ? JSON.parse(serviceKeyRaw) : (serviceKeyRaw ?? {});
|
||||
|
||||
const project_id = serviceKey?.project_id ?? null;
|
||||
const apiKey = creds[AuthKeys.GOOGLE_API_KEY] ?? null;
|
||||
|
|
@ -156,10 +156,6 @@ function getLLMConfig(credentials, options = {}) {
|
|||
}
|
||||
|
||||
if (authHeader) {
|
||||
/**
|
||||
* NOTE: NOT SUPPORTED BY LANGCHAIN GENAI CLIENT,
|
||||
* REQUIRES PR IN https://github.com/langchain-ai/langchainjs
|
||||
*/
|
||||
llmConfig.customHeaders = {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
const {
|
||||
EModelEndpoint,
|
||||
mapModelToAzureConfig,
|
||||
resolveHeaders,
|
||||
mapModelToAzureConfig,
|
||||
} = require('librechat-data-provider');
|
||||
const { isEnabled, isUserProvided, getAzureCredentials } = require('@librechat/api');
|
||||
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
const { isEnabled, isUserProvided } = require('~/server/utils');
|
||||
const { getAzureCredentials } = require('~/utils');
|
||||
const { PluginsClient } = require('~/app');
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
|
|
|
|||
|
|
@ -114,11 +114,11 @@ describe('gptPlugins/initializeClient', () => {
|
|||
test('should initialize PluginsClient with Azure credentials when PLUGINS_USE_AZURE is true', async () => {
|
||||
process.env.AZURE_API_KEY = 'test-azure-api-key';
|
||||
(process.env.AZURE_OPENAI_API_INSTANCE_NAME = 'some-value'),
|
||||
(process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'some-value'),
|
||||
(process.env.AZURE_OPENAI_API_VERSION = 'some-value'),
|
||||
(process.env.AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME = 'some-value'),
|
||||
(process.env.AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME = 'some-value'),
|
||||
(process.env.PLUGINS_USE_AZURE = 'true');
|
||||
(process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'some-value'),
|
||||
(process.env.AZURE_OPENAI_API_VERSION = 'some-value'),
|
||||
(process.env.AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME = 'some-value'),
|
||||
(process.env.AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME = 'some-value'),
|
||||
(process.env.PLUGINS_USE_AZURE = 'true');
|
||||
process.env.DEBUG_PLUGINS = 'false';
|
||||
process.env.OPENAI_SUMMARIZE = 'false';
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,15 @@ const {
|
|||
resolveHeaders,
|
||||
mapModelToAzureConfig,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
isEnabled,
|
||||
isUserProvided,
|
||||
getOpenAIConfig,
|
||||
getAzureCredentials,
|
||||
createHandleLLMNewToken,
|
||||
} = require('@librechat/api');
|
||||
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm');
|
||||
const { createHandleLLMNewToken } = require('~/app/clients/generators');
|
||||
const { isEnabled, isUserProvided } = require('~/server/utils');
|
||||
const OpenAIClient = require('~/app/clients/OpenAIClient');
|
||||
const { getAzureCredentials } = require('~/utils');
|
||||
|
||||
const initializeClient = async ({
|
||||
req,
|
||||
|
|
@ -140,7 +143,7 @@ const initializeClient = async ({
|
|||
modelOptions.model = modelName;
|
||||
clientOptions = Object.assign({ modelOptions }, clientOptions);
|
||||
clientOptions.modelOptions.user = req.user.id;
|
||||
const options = getLLMConfig(apiKey, clientOptions);
|
||||
const options = getOpenAIConfig(apiKey, clientOptions);
|
||||
const streamRate = clientOptions.streamRate;
|
||||
if (!streamRate) {
|
||||
return options;
|
||||
|
|
|
|||
|
|
@ -1,170 +0,0 @@
|
|||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { KnownEndpoints } = require('librechat-data-provider');
|
||||
const { sanitizeModelName, constructAzureURL } = require('~/utils');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
|
||||
/**
|
||||
* Generates configuration options for creating a language model (LLM) instance.
|
||||
* @param {string} apiKey - The API key for authentication.
|
||||
* @param {Object} options - Additional options for configuring the LLM.
|
||||
* @param {Object} [options.modelOptions] - Model-specific options.
|
||||
* @param {string} [options.modelOptions.model] - The name of the model to use.
|
||||
* @param {string} [options.modelOptions.user] - The user ID
|
||||
* @param {number} [options.modelOptions.temperature] - Controls randomness in output generation (0-2).
|
||||
* @param {number} [options.modelOptions.top_p] - Controls diversity via nucleus sampling (0-1).
|
||||
* @param {number} [options.modelOptions.frequency_penalty] - Reduces repetition of token sequences (-2 to 2).
|
||||
* @param {number} [options.modelOptions.presence_penalty] - Encourages discussing new topics (-2 to 2).
|
||||
* @param {number} [options.modelOptions.max_tokens] - The maximum number of tokens to generate.
|
||||
* @param {string[]} [options.modelOptions.stop] - Sequences where the API will stop generating further tokens.
|
||||
* @param {string} [options.reverseProxyUrl] - URL for a reverse proxy, if used.
|
||||
* @param {boolean} [options.useOpenRouter] - Flag to use OpenRouter API.
|
||||
* @param {Object} [options.headers] - Additional headers for API requests.
|
||||
* @param {string} [options.proxy] - Proxy server URL.
|
||||
* @param {Object} [options.azure] - Azure-specific configurations.
|
||||
* @param {boolean} [options.streaming] - Whether to use streaming mode.
|
||||
* @param {Object} [options.addParams] - Additional parameters to add to the model options.
|
||||
* @param {string[]} [options.dropParams] - Parameters to remove from the model options.
|
||||
* @param {string|null} [endpoint=null] - The endpoint name
|
||||
* @returns {Object} Configuration options for creating an LLM instance.
|
||||
*/
|
||||
function getLLMConfig(apiKey, options = {}, endpoint = null) {
|
||||
let {
|
||||
modelOptions = {},
|
||||
reverseProxyUrl,
|
||||
defaultQuery,
|
||||
headers,
|
||||
proxy,
|
||||
azure,
|
||||
streaming = true,
|
||||
addParams,
|
||||
dropParams,
|
||||
} = options;
|
||||
|
||||
/** @type {OpenAIClientOptions} */
|
||||
let llmConfig = {
|
||||
streaming,
|
||||
};
|
||||
|
||||
Object.assign(llmConfig, modelOptions);
|
||||
|
||||
if (addParams && typeof addParams === 'object') {
|
||||
Object.assign(llmConfig, addParams);
|
||||
}
|
||||
/** Note: OpenAI Web Search models do not support any known parameters besdies `max_tokens` */
|
||||
if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) {
|
||||
const searchExcludeParams = [
|
||||
'frequency_penalty',
|
||||
'presence_penalty',
|
||||
'temperature',
|
||||
'top_p',
|
||||
'top_k',
|
||||
'stop',
|
||||
'logit_bias',
|
||||
'seed',
|
||||
'response_format',
|
||||
'n',
|
||||
'logprobs',
|
||||
'user',
|
||||
];
|
||||
|
||||
dropParams = dropParams || [];
|
||||
dropParams = [...new Set([...dropParams, ...searchExcludeParams])];
|
||||
}
|
||||
|
||||
if (dropParams && Array.isArray(dropParams)) {
|
||||
dropParams.forEach((param) => {
|
||||
if (llmConfig[param]) {
|
||||
llmConfig[param] = undefined;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let useOpenRouter;
|
||||
/** @type {OpenAIClientOptions['configuration']} */
|
||||
const configOptions = {};
|
||||
if (
|
||||
(reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter)) ||
|
||||
(endpoint && endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
|
||||
) {
|
||||
useOpenRouter = true;
|
||||
llmConfig.include_reasoning = true;
|
||||
configOptions.baseURL = reverseProxyUrl;
|
||||
configOptions.defaultHeaders = Object.assign(
|
||||
{
|
||||
'HTTP-Referer': 'https://librechat.ai',
|
||||
'X-Title': 'LibreChat',
|
||||
},
|
||||
headers,
|
||||
);
|
||||
} else if (reverseProxyUrl) {
|
||||
configOptions.baseURL = reverseProxyUrl;
|
||||
if (headers) {
|
||||
configOptions.defaultHeaders = headers;
|
||||
}
|
||||
}
|
||||
|
||||
if (defaultQuery) {
|
||||
configOptions.defaultQuery = defaultQuery;
|
||||
}
|
||||
|
||||
if (proxy) {
|
||||
const proxyAgent = new HttpsProxyAgent(proxy);
|
||||
Object.assign(configOptions, {
|
||||
httpAgent: proxyAgent,
|
||||
httpsAgent: proxyAgent,
|
||||
});
|
||||
}
|
||||
|
||||
if (azure) {
|
||||
const useModelName = isEnabled(process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME);
|
||||
azure.azureOpenAIApiDeploymentName = useModelName
|
||||
? sanitizeModelName(llmConfig.model)
|
||||
: azure.azureOpenAIApiDeploymentName;
|
||||
|
||||
if (process.env.AZURE_OPENAI_DEFAULT_MODEL) {
|
||||
llmConfig.model = process.env.AZURE_OPENAI_DEFAULT_MODEL;
|
||||
}
|
||||
|
||||
if (configOptions.baseURL) {
|
||||
const azureURL = constructAzureURL({
|
||||
baseURL: configOptions.baseURL,
|
||||
azureOptions: azure,
|
||||
});
|
||||
azure.azureOpenAIBasePath = azureURL.split(`/${azure.azureOpenAIApiDeploymentName}`)[0];
|
||||
}
|
||||
|
||||
Object.assign(llmConfig, azure);
|
||||
llmConfig.model = llmConfig.azureOpenAIApiDeploymentName;
|
||||
} else {
|
||||
llmConfig.apiKey = apiKey;
|
||||
// Object.assign(llmConfig, {
|
||||
// configuration: { apiKey },
|
||||
// });
|
||||
}
|
||||
|
||||
if (process.env.OPENAI_ORGANIZATION && this.azure) {
|
||||
llmConfig.organization = process.env.OPENAI_ORGANIZATION;
|
||||
}
|
||||
|
||||
if (useOpenRouter && llmConfig.reasoning_effort != null) {
|
||||
llmConfig.reasoning = {
|
||||
effort: llmConfig.reasoning_effort,
|
||||
};
|
||||
delete llmConfig.reasoning_effort;
|
||||
}
|
||||
|
||||
if (llmConfig?.['max_tokens'] != null) {
|
||||
/** @type {number} */
|
||||
llmConfig.maxTokens = llmConfig['max_tokens'];
|
||||
delete llmConfig['max_tokens'];
|
||||
}
|
||||
|
||||
return {
|
||||
/** @type {OpenAIClientOptions} */
|
||||
llmConfig,
|
||||
/** @type {OpenAIClientOptions['configuration']} */
|
||||
configOptions,
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = { getLLMConfig };
|
||||
|
|
@ -2,9 +2,9 @@ const axios = require('axios');
|
|||
const fs = require('fs').promises;
|
||||
const FormData = require('form-data');
|
||||
const { Readable } = require('stream');
|
||||
const { genAzureEndpoint } = require('@librechat/api');
|
||||
const { extractEnvVariable, STTProviders } = require('librechat-data-provider');
|
||||
const { getCustomConfig } = require('~/server/services/Config');
|
||||
const { genAzureEndpoint } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
const axios = require('axios');
|
||||
const { genAzureEndpoint } = require('@librechat/api');
|
||||
const { extractEnvVariable, TTSProviders } = require('librechat-data-provider');
|
||||
const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
|
||||
const { getCustomConfig } = require('~/server/services/Config');
|
||||
const { genAzureEndpoint } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -91,24 +91,44 @@ async function prepareAzureImageURL(req, file) {
|
|||
* @param {Buffer} params.buffer - The avatar image buffer.
|
||||
* @param {string} params.userId - The user's id.
|
||||
* @param {string} params.manual - Flag to indicate manual update.
|
||||
* @param {string} [params.agentId] - Optional agent ID if this is an agent avatar.
|
||||
* @param {string} [params.basePath='images'] - The base folder within the container.
|
||||
* @param {string} [params.containerName] - The Azure Blob container name.
|
||||
* @returns {Promise<string>} The URL of the avatar.
|
||||
*/
|
||||
async function processAzureAvatar({ buffer, userId, manual, basePath = 'images', containerName }) {
|
||||
async function processAzureAvatar({
|
||||
buffer,
|
||||
userId,
|
||||
manual,
|
||||
agentId,
|
||||
basePath = 'images',
|
||||
containerName,
|
||||
}) {
|
||||
try {
|
||||
const metadata = await sharp(buffer).metadata();
|
||||
const extension = metadata.format === 'gif' ? 'gif' : 'png';
|
||||
const timestamp = new Date().getTime();
|
||||
|
||||
/** Unique filename with timestamp and optional agent ID */
|
||||
const fileName = agentId
|
||||
? `agent-${agentId}-avatar-${timestamp}.${extension}`
|
||||
: `avatar-${timestamp}.${extension}`;
|
||||
|
||||
const downloadURL = await saveBufferToAzure({
|
||||
userId,
|
||||
buffer,
|
||||
fileName: 'avatar.png',
|
||||
fileName,
|
||||
basePath,
|
||||
containerName,
|
||||
});
|
||||
const isManual = manual === 'true';
|
||||
const url = `${downloadURL}?manual=${isManual}`;
|
||||
if (isManual) {
|
||||
|
||||
// Only update user record if this is a user avatar (manual === 'true')
|
||||
if (isManual && !agentId) {
|
||||
await updateUser(userId, { avatar: url });
|
||||
}
|
||||
|
||||
return url;
|
||||
} catch (error) {
|
||||
logger.error('[processAzureAvatar] Error uploading profile picture to Azure:', error);
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
const FormData = require('form-data');
|
||||
const { getCodeBaseURL } = require('@librechat/agents');
|
||||
const { createAxiosInstance } = require('~/config');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { createAxiosInstance, logAxiosError } = require('@librechat/api');
|
||||
|
||||
const axios = createAxiosInstance();
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
const path = require('path');
|
||||
const { v4 } = require('uuid');
|
||||
const axios = require('axios');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getCodeBaseURL } = require('@librechat/agents');
|
||||
const {
|
||||
Tools,
|
||||
|
|
@ -12,8 +14,6 @@ const {
|
|||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { convertImage } = require('~/server/services/Files/images/convert');
|
||||
const { createFile, getFiles, updateFile } = require('~/models/File');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Process OpenAI image files, convert to target format, save and return file metadata.
|
||||
|
|
|
|||
|
|
@ -82,22 +82,32 @@ async function prepareImageURL(req, file) {
|
|||
* @param {Buffer} params.buffer - The Buffer containing the avatar image.
|
||||
* @param {string} params.userId - The user ID.
|
||||
* @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false').
|
||||
* @param {string} [params.agentId] - Optional agent ID if this is an agent avatar.
|
||||
* @returns {Promise<string>} - A promise that resolves with the URL of the uploaded avatar.
|
||||
* @throws {Error} - Throws an error if Firebase is not initialized or if there is an error in uploading.
|
||||
*/
|
||||
async function processFirebaseAvatar({ buffer, userId, manual }) {
|
||||
async function processFirebaseAvatar({ buffer, userId, manual, agentId }) {
|
||||
try {
|
||||
const metadata = await sharp(buffer).metadata();
|
||||
const extension = metadata.format === 'gif' ? 'gif' : 'png';
|
||||
const timestamp = new Date().getTime();
|
||||
|
||||
/** Unique filename with timestamp and optional agent ID */
|
||||
const fileName = agentId
|
||||
? `agent-${agentId}-avatar-${timestamp}.${extension}`
|
||||
: `avatar-${timestamp}.${extension}`;
|
||||
|
||||
const downloadURL = await saveBufferToFirebase({
|
||||
userId,
|
||||
buffer,
|
||||
fileName: 'avatar.png',
|
||||
fileName,
|
||||
});
|
||||
|
||||
const isManual = manual === 'true';
|
||||
|
||||
const url = `${downloadURL}?manual=${isManual}`;
|
||||
|
||||
if (isManual) {
|
||||
// Only update user record if this is a user avatar (manual === 'true')
|
||||
if (isManual && !agentId) {
|
||||
await updateUser(userId, { avatar: url });
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -201,6 +201,10 @@ const unlinkFile = async (filepath) => {
|
|||
*/
|
||||
const deleteLocalFile = async (req, file) => {
|
||||
const { publicPath, uploads } = req.app.locals.paths;
|
||||
|
||||
/** Filepath stripped of query parameters (e.g., ?manual=true) */
|
||||
const cleanFilepath = file.filepath.split('?')[0];
|
||||
|
||||
if (file.embedded && process.env.RAG_API_URL) {
|
||||
const jwtToken = req.headers.authorization.split(' ')[1];
|
||||
axios.delete(`${process.env.RAG_API_URL}/documents`, {
|
||||
|
|
@ -213,32 +217,32 @@ const deleteLocalFile = async (req, file) => {
|
|||
});
|
||||
}
|
||||
|
||||
if (file.filepath.startsWith(`/uploads/${req.user.id}`)) {
|
||||
if (cleanFilepath.startsWith(`/uploads/${req.user.id}`)) {
|
||||
const userUploadDir = path.join(uploads, req.user.id);
|
||||
const basePath = file.filepath.split(`/uploads/${req.user.id}/`)[1];
|
||||
const basePath = cleanFilepath.split(`/uploads/${req.user.id}/`)[1];
|
||||
|
||||
if (!basePath) {
|
||||
throw new Error(`Invalid file path: ${file.filepath}`);
|
||||
throw new Error(`Invalid file path: ${cleanFilepath}`);
|
||||
}
|
||||
|
||||
const filepath = path.join(userUploadDir, basePath);
|
||||
|
||||
const rel = path.relative(userUploadDir, filepath);
|
||||
if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) {
|
||||
throw new Error(`Invalid file path: ${file.filepath}`);
|
||||
throw new Error(`Invalid file path: ${cleanFilepath}`);
|
||||
}
|
||||
|
||||
await unlinkFile(filepath);
|
||||
return;
|
||||
}
|
||||
|
||||
const parts = file.filepath.split(path.sep);
|
||||
const parts = cleanFilepath.split(path.sep);
|
||||
const subfolder = parts[1];
|
||||
if (!subfolder && parts[0] === EModelEndpoint.agents) {
|
||||
logger.warn(`Agent File ${file.file_id} is missing filepath, may have been deleted already`);
|
||||
return;
|
||||
}
|
||||
const filepath = path.join(publicPath, file.filepath);
|
||||
const filepath = path.join(publicPath, cleanFilepath);
|
||||
|
||||
if (!isValidPath(req, publicPath, subfolder, filepath)) {
|
||||
throw new Error('Invalid file path');
|
||||
|
|
|
|||
|
|
@ -112,10 +112,11 @@ async function prepareImagesLocal(req, file) {
|
|||
* @param {Buffer} params.buffer - The Buffer containing the avatar image.
|
||||
* @param {string} params.userId - The user ID.
|
||||
* @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false').
|
||||
* @param {string} [params.agentId] - Optional agent ID if this is an agent avatar.
|
||||
* @returns {Promise<string>} - A promise that resolves with the URL of the uploaded avatar.
|
||||
* @throws {Error} - Throws an error if Firebase is not initialized or if there is an error in uploading.
|
||||
*/
|
||||
async function processLocalAvatar({ buffer, userId, manual }) {
|
||||
async function processLocalAvatar({ buffer, userId, manual, agentId }) {
|
||||
const userDir = path.resolve(
|
||||
__dirname,
|
||||
'..',
|
||||
|
|
@ -129,7 +130,14 @@ async function processLocalAvatar({ buffer, userId, manual }) {
|
|||
userId,
|
||||
);
|
||||
|
||||
const fileName = `avatar-${new Date().getTime()}.png`;
|
||||
const metadata = await sharp(buffer).metadata();
|
||||
const extension = metadata.format === 'gif' ? 'gif' : 'png';
|
||||
|
||||
const timestamp = new Date().getTime();
|
||||
/** Unique filename with timestamp and optional agent ID */
|
||||
const fileName = agentId
|
||||
? `agent-${agentId}-avatar-${timestamp}.${extension}`
|
||||
: `avatar-${timestamp}.${extension}`;
|
||||
const urlRoute = `/images/${userId}/${fileName}`;
|
||||
const avatarPath = path.join(userDir, fileName);
|
||||
|
||||
|
|
@ -139,7 +147,8 @@ async function processLocalAvatar({ buffer, userId, manual }) {
|
|||
const isManual = manual === 'true';
|
||||
let url = `${urlRoute}?manual=${isManual}`;
|
||||
|
||||
if (isManual) {
|
||||
// Only update user record if this is a user avatar (manual === 'true')
|
||||
if (isManual && !agentId) {
|
||||
await updateUser(userId, { avatar: url });
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,238 +0,0 @@
|
|||
// ~/server/services/Files/MistralOCR/crud.js
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const FormData = require('form-data');
|
||||
const {
|
||||
FileSources,
|
||||
envVarRegex,
|
||||
extractEnvVariable,
|
||||
extractVariableName,
|
||||
} = require('librechat-data-provider');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { logger, createAxiosInstance } = require('~/config');
|
||||
const { logAxiosError } = require('~/utils/axios');
|
||||
|
||||
const axios = createAxiosInstance();
|
||||
|
||||
/**
|
||||
* Uploads a document to Mistral API using file streaming to avoid loading the entire file into memory
|
||||
*
|
||||
* @param {Object} params Upload parameters
|
||||
* @param {string} params.filePath The path to the file on disk
|
||||
* @param {string} [params.fileName] Optional filename to use (defaults to the name from filePath)
|
||||
* @param {string} params.apiKey Mistral API key
|
||||
* @param {string} [params.baseURL=https://api.mistral.ai/v1] Mistral API base URL
|
||||
* @returns {Promise<Object>} The response from Mistral API
|
||||
*/
|
||||
async function uploadDocumentToMistral({
|
||||
filePath,
|
||||
fileName = '',
|
||||
apiKey,
|
||||
baseURL = 'https://api.mistral.ai/v1',
|
||||
}) {
|
||||
const form = new FormData();
|
||||
form.append('purpose', 'ocr');
|
||||
const actualFileName = fileName || path.basename(filePath);
|
||||
const fileStream = fs.createReadStream(filePath);
|
||||
form.append('file', fileStream, { filename: actualFileName });
|
||||
|
||||
return axios
|
||||
.post(`${baseURL}/files`, form, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
...form.getHeaders(),
|
||||
},
|
||||
maxBodyLength: Infinity,
|
||||
maxContentLength: Infinity,
|
||||
})
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
async function getSignedUrl({
|
||||
apiKey,
|
||||
fileId,
|
||||
expiry = 24,
|
||||
baseURL = 'https://api.mistral.ai/v1',
|
||||
}) {
|
||||
return axios
|
||||
.get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
})
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
logger.error('Error fetching signed URL:', error.message);
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Object} params
|
||||
* @param {string} params.apiKey
|
||||
* @param {string} params.url - The document or image URL
|
||||
* @param {string} [params.documentType='document_url'] - 'document_url' or 'image_url'
|
||||
* @param {string} [params.model]
|
||||
* @param {string} [params.baseURL]
|
||||
* @returns {Promise<OCRResult>}
|
||||
*/
|
||||
async function performOCR({
|
||||
apiKey,
|
||||
url,
|
||||
documentType = 'document_url',
|
||||
model = 'mistral-ocr-latest',
|
||||
baseURL = 'https://api.mistral.ai/v1',
|
||||
}) {
|
||||
const documentKey = documentType === 'image_url' ? 'image_url' : 'document_url';
|
||||
return axios
|
||||
.post(
|
||||
`${baseURL}/ocr`,
|
||||
{
|
||||
model,
|
||||
image_limit: 0,
|
||||
include_image_base64: false,
|
||||
document: {
|
||||
type: documentType,
|
||||
[documentKey]: url,
|
||||
},
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
},
|
||||
)
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
logger.error('Error performing OCR:', error.message);
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Uploads a file to the Mistral OCR API and processes the OCR result.
|
||||
*
|
||||
* @param {Object} params - The params object.
|
||||
* @param {ServerRequest} params.req - The request object from Express. It should have a `user` property with an `id`
|
||||
* representing the user
|
||||
* @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should
|
||||
* have a `mimetype` property that tells us the file type
|
||||
* @param {string} params.file_id - The file ID.
|
||||
* @param {string} [params.entity_id] - The entity ID, not used here but passed for consistency.
|
||||
* @returns {Promise<{ filepath: string, bytes: number }>} - The result object containing the processed `text` and `images` (not currently used),
|
||||
* along with the `filename` and `bytes` properties.
|
||||
*/
|
||||
const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => {
|
||||
try {
|
||||
/** @type {TCustomConfig['ocr']} */
|
||||
const ocrConfig = req.app.locals?.ocr;
|
||||
|
||||
const apiKeyConfig = ocrConfig.apiKey || '';
|
||||
const baseURLConfig = ocrConfig.baseURL || '';
|
||||
|
||||
const isApiKeyEnvVar = envVarRegex.test(apiKeyConfig);
|
||||
const isBaseURLEnvVar = envVarRegex.test(baseURLConfig);
|
||||
|
||||
const isApiKeyEmpty = !apiKeyConfig.trim();
|
||||
const isBaseURLEmpty = !baseURLConfig.trim();
|
||||
|
||||
let apiKey, baseURL;
|
||||
|
||||
if (isApiKeyEnvVar || isBaseURLEnvVar || isApiKeyEmpty || isBaseURLEmpty) {
|
||||
const apiKeyVarName = isApiKeyEnvVar ? extractVariableName(apiKeyConfig) : 'OCR_API_KEY';
|
||||
const baseURLVarName = isBaseURLEnvVar ? extractVariableName(baseURLConfig) : 'OCR_BASEURL';
|
||||
|
||||
const authValues = await loadAuthValues({
|
||||
userId: req.user.id,
|
||||
authFields: [baseURLVarName, apiKeyVarName],
|
||||
optional: new Set([baseURLVarName]),
|
||||
});
|
||||
|
||||
apiKey = authValues[apiKeyVarName];
|
||||
baseURL = authValues[baseURLVarName];
|
||||
} else {
|
||||
apiKey = apiKeyConfig;
|
||||
baseURL = baseURLConfig;
|
||||
}
|
||||
|
||||
const mistralFile = await uploadDocumentToMistral({
|
||||
filePath: file.path,
|
||||
fileName: file.originalname,
|
||||
apiKey,
|
||||
baseURL,
|
||||
});
|
||||
|
||||
const modelConfig = ocrConfig.mistralModel || '';
|
||||
const model = envVarRegex.test(modelConfig)
|
||||
? extractEnvVariable(modelConfig)
|
||||
: modelConfig.trim() || 'mistral-ocr-latest';
|
||||
|
||||
const signedUrlResponse = await getSignedUrl({
|
||||
apiKey,
|
||||
baseURL,
|
||||
fileId: mistralFile.id,
|
||||
});
|
||||
|
||||
const mimetype = (file.mimetype || '').toLowerCase();
|
||||
const originalname = file.originalname || '';
|
||||
const isImage =
|
||||
mimetype.startsWith('image') || /\.(png|jpe?g|gif|bmp|webp|tiff?)$/i.test(originalname);
|
||||
const documentType = isImage ? 'image_url' : 'document_url';
|
||||
|
||||
const ocrResult = await performOCR({
|
||||
apiKey,
|
||||
baseURL,
|
||||
model,
|
||||
url: signedUrlResponse.url,
|
||||
documentType,
|
||||
});
|
||||
|
||||
let aggregatedText = '';
|
||||
const images = [];
|
||||
ocrResult.pages.forEach((page, index) => {
|
||||
if (ocrResult.pages.length > 1) {
|
||||
aggregatedText += `# PAGE ${index + 1}\n`;
|
||||
}
|
||||
|
||||
aggregatedText += page.markdown + '\n\n';
|
||||
|
||||
if (page.images && page.images.length > 0) {
|
||||
page.images.forEach((image) => {
|
||||
if (image.image_base64) {
|
||||
images.push(image.image_base64);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
return {
|
||||
filename: file.originalname,
|
||||
bytes: aggregatedText.length * 4,
|
||||
filepath: FileSources.mistral_ocr,
|
||||
text: aggregatedText,
|
||||
images,
|
||||
};
|
||||
} catch (error) {
|
||||
let message = 'Error uploading document to Mistral OCR API';
|
||||
const detail = error?.response?.data?.detail;
|
||||
if (detail && detail !== '') {
|
||||
message = detail;
|
||||
}
|
||||
|
||||
const responseMessage = error?.response?.data?.message;
|
||||
throw new Error(
|
||||
`${logAxiosError({ error, message })}${responseMessage && responseMessage !== '' ? ` - ${responseMessage}` : ''}`,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
uploadDocumentToMistral,
|
||||
uploadMistralOCR,
|
||||
getSignedUrl,
|
||||
performOCR,
|
||||
};
|
||||
|
|
@ -1,848 +0,0 @@
|
|||
const fs = require('fs');
|
||||
|
||||
const mockAxios = {
|
||||
interceptors: {
|
||||
request: { use: jest.fn(), eject: jest.fn() },
|
||||
response: { use: jest.fn(), eject: jest.fn() },
|
||||
},
|
||||
create: jest.fn().mockReturnValue({
|
||||
defaults: {
|
||||
proxy: null,
|
||||
},
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
post: jest.fn().mockResolvedValue({ data: {} }),
|
||||
put: jest.fn().mockResolvedValue({ data: {} }),
|
||||
delete: jest.fn().mockResolvedValue({ data: {} }),
|
||||
}),
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
post: jest.fn().mockResolvedValue({ data: {} }),
|
||||
put: jest.fn().mockResolvedValue({ data: {} }),
|
||||
delete: jest.fn().mockResolvedValue({ data: {} }),
|
||||
reset: jest.fn().mockImplementation(function () {
|
||||
this.get.mockClear();
|
||||
this.post.mockClear();
|
||||
this.put.mockClear();
|
||||
this.delete.mockClear();
|
||||
this.create.mockClear();
|
||||
}),
|
||||
};
|
||||
|
||||
jest.mock('axios', () => mockAxios);
|
||||
jest.mock('fs');
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
},
|
||||
createAxiosInstance: () => mockAxios,
|
||||
}));
|
||||
jest.mock('~/server/services/Tools/credentials', () => ({
|
||||
loadAuthValues: jest.fn(),
|
||||
}));
|
||||
|
||||
const { uploadDocumentToMistral, uploadMistralOCR, getSignedUrl, performOCR } = require('./crud');
|
||||
|
||||
describe('MistralOCR Service', () => {
|
||||
afterEach(() => {
|
||||
mockAxios.reset();
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('uploadDocumentToMistral', () => {
|
||||
beforeEach(() => {
|
||||
// Create a more complete mock for file streams that FormData can work with
|
||||
const mockReadStream = {
|
||||
on: jest.fn().mockImplementation(function (event, handler) {
|
||||
// Simulate immediate 'end' event to make FormData complete processing
|
||||
if (event === 'end') {
|
||||
handler();
|
||||
}
|
||||
return this;
|
||||
}),
|
||||
pipe: jest.fn().mockImplementation(function () {
|
||||
return this;
|
||||
}),
|
||||
pause: jest.fn(),
|
||||
resume: jest.fn(),
|
||||
emit: jest.fn(),
|
||||
once: jest.fn(),
|
||||
destroy: jest.fn(),
|
||||
};
|
||||
|
||||
fs.createReadStream = jest.fn().mockReturnValue(mockReadStream);
|
||||
|
||||
// Mock FormData's append to avoid actual stream processing
|
||||
jest.mock('form-data', () => {
|
||||
const mockFormData = function () {
|
||||
return {
|
||||
append: jest.fn(),
|
||||
getHeaders: jest
|
||||
.fn()
|
||||
.mockReturnValue({ 'content-type': 'multipart/form-data; boundary=---boundary' }),
|
||||
getBuffer: jest.fn().mockReturnValue(Buffer.from('mock-form-data')),
|
||||
getLength: jest.fn().mockReturnValue(100),
|
||||
};
|
||||
};
|
||||
return mockFormData;
|
||||
});
|
||||
});
|
||||
|
||||
it('should upload a document to Mistral API using file streaming', async () => {
|
||||
const mockResponse = { data: { id: 'file-123', purpose: 'ocr' } };
|
||||
mockAxios.post.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
const result = await uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
// Check that createReadStream was called with the correct file path
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/path/to/test.pdf');
|
||||
|
||||
// Since we're mocking FormData, we'll just check that axios was called correctly
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files',
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: 'Bearer test-api-key',
|
||||
}),
|
||||
maxBodyLength: Infinity,
|
||||
maxContentLength: Infinity,
|
||||
}),
|
||||
);
|
||||
expect(result).toEqual(mockResponse.data);
|
||||
});
|
||||
|
||||
it('should handle errors during document upload', async () => {
|
||||
const errorMessage = 'API error';
|
||||
mockAxios.post.mockRejectedValueOnce(new Error(errorMessage));
|
||||
|
||||
await expect(
|
||||
uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
}),
|
||||
).rejects.toThrow(errorMessage);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getSignedUrl', () => {
|
||||
it('should fetch signed URL from Mistral API', async () => {
|
||||
const mockResponse = { data: { url: 'https://document-url.com' } };
|
||||
mockAxios.get.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
const result = await getSignedUrl({
|
||||
fileId: 'file-123',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.get).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files/file-123/url?expiry=24',
|
||||
{
|
||||
headers: {
|
||||
Authorization: 'Bearer test-api-key',
|
||||
},
|
||||
},
|
||||
);
|
||||
expect(result).toEqual(mockResponse.data);
|
||||
});
|
||||
|
||||
it('should handle errors when fetching signed URL', async () => {
|
||||
const errorMessage = 'API error';
|
||||
mockAxios.get.mockRejectedValueOnce(new Error(errorMessage));
|
||||
|
||||
await expect(
|
||||
getSignedUrl({
|
||||
fileId: 'file-123',
|
||||
apiKey: 'test-api-key',
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
|
||||
const { logger } = require('~/config');
|
||||
expect(logger.error).toHaveBeenCalledWith('Error fetching signed URL:', errorMessage);
|
||||
});
|
||||
});
|
||||
|
||||
describe('performOCR', () => {
|
||||
it('should perform OCR using Mistral API (document_url)', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }],
|
||||
},
|
||||
};
|
||||
mockAxios.post.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
const result = await performOCR({
|
||||
apiKey: 'test-api-key',
|
||||
url: 'https://document-url.com',
|
||||
model: 'mistral-ocr-latest',
|
||||
documentType: 'document_url',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/ocr',
|
||||
{
|
||||
model: 'mistral-ocr-latest',
|
||||
include_image_base64: false,
|
||||
image_limit: 0,
|
||||
document: {
|
||||
type: 'document_url',
|
||||
document_url: 'https://document-url.com',
|
||||
},
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: 'Bearer test-api-key',
|
||||
},
|
||||
},
|
||||
);
|
||||
expect(result).toEqual(mockResponse.data);
|
||||
});
|
||||
|
||||
it('should perform OCR using Mistral API (image_url)', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
pages: [{ markdown: 'Image OCR content' }],
|
||||
},
|
||||
};
|
||||
mockAxios.post.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
const result = await performOCR({
|
||||
apiKey: 'test-api-key',
|
||||
url: 'https://image-url.com/image.png',
|
||||
model: 'mistral-ocr-latest',
|
||||
documentType: 'image_url',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/ocr',
|
||||
{
|
||||
model: 'mistral-ocr-latest',
|
||||
include_image_base64: false,
|
||||
image_limit: 0,
|
||||
document: {
|
||||
type: 'image_url',
|
||||
image_url: 'https://image-url.com/image.png',
|
||||
},
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: 'Bearer test-api-key',
|
||||
},
|
||||
},
|
||||
);
|
||||
expect(result).toEqual(mockResponse.data);
|
||||
});
|
||||
|
||||
it('should handle errors during OCR processing', async () => {
|
||||
const errorMessage = 'OCR processing error';
|
||||
mockAxios.post.mockRejectedValueOnce(new Error(errorMessage));
|
||||
|
||||
await expect(
|
||||
performOCR({
|
||||
apiKey: 'test-api-key',
|
||||
url: 'https://document-url.com',
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
|
||||
const { logger } = require('~/config');
|
||||
expect(logger.error).toHaveBeenCalledWith('Error performing OCR:', errorMessage);
|
||||
});
|
||||
});
|
||||
|
||||
describe('uploadMistralOCR', () => {
|
||||
beforeEach(() => {
|
||||
const mockReadStream = {
|
||||
on: jest.fn().mockImplementation(function (event, handler) {
|
||||
if (event === 'end') {
|
||||
handler();
|
||||
}
|
||||
return this;
|
||||
}),
|
||||
pipe: jest.fn().mockImplementation(function () {
|
||||
return this;
|
||||
}),
|
||||
pause: jest.fn(),
|
||||
resume: jest.fn(),
|
||||
emit: jest.fn(),
|
||||
once: jest.fn(),
|
||||
destroy: jest.fn(),
|
||||
};
|
||||
|
||||
fs.createReadStream = jest.fn().mockReturnValue(mockReadStream);
|
||||
});
|
||||
|
||||
it('should process OCR for a file with standard configuration', async () => {
|
||||
// Setup mocks
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'test-api-key',
|
||||
OCR_BASEURL: 'https://api.mistral.ai/v1',
|
||||
});
|
||||
|
||||
// Mock file upload response
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: { id: 'file-123', purpose: 'ocr' },
|
||||
});
|
||||
|
||||
// Mock signed URL response
|
||||
mockAxios.get.mockResolvedValueOnce({
|
||||
data: { url: 'https://signed-url.com' },
|
||||
});
|
||||
|
||||
// Mock OCR response with text and images
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
pages: [
|
||||
{
|
||||
markdown: 'Page 1 content',
|
||||
images: [{ image_base64: 'base64image1' }],
|
||||
},
|
||||
{
|
||||
markdown: 'Page 2 content',
|
||||
images: [{ image_base64: 'base64image2' }],
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Use environment variable syntax to ensure loadAuthValues is called
|
||||
apiKey: '${OCR_API_KEY}',
|
||||
baseURL: '${OCR_BASEURL}',
|
||||
mistralModel: 'mistral-medium',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
mimetype: 'application/pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Verify OCR result
|
||||
expect(result).toEqual({
|
||||
filename: 'document.pdf',
|
||||
bytes: expect.any(Number),
|
||||
filepath: 'mistral_ocr',
|
||||
text: expect.stringContaining('# PAGE 1'),
|
||||
images: ['base64image1', 'base64image2'],
|
||||
});
|
||||
});
|
||||
|
||||
it('should process OCR for an image file and use image_url type', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'test-api-key',
|
||||
OCR_BASEURL: 'https://api.mistral.ai/v1',
|
||||
});
|
||||
|
||||
// Mock file upload response
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: { id: 'file-456', purpose: 'ocr' },
|
||||
});
|
||||
|
||||
// Mock signed URL response
|
||||
mockAxios.get.mockResolvedValueOnce({
|
||||
data: { url: 'https://signed-url.com/image.png' },
|
||||
});
|
||||
|
||||
// Mock OCR response for image
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
pages: [
|
||||
{
|
||||
markdown: 'Image OCR result',
|
||||
images: [{ image_base64: 'imgbase64' }],
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user456' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
apiKey: '${OCR_API_KEY}',
|
||||
baseURL: '${OCR_BASEURL}',
|
||||
mistralModel: 'mistral-medium',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/image.png',
|
||||
originalname: 'image.png',
|
||||
mimetype: 'image/png',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file456',
|
||||
entity_id: 'entity456',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/image.png');
|
||||
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user456',
|
||||
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Check that the OCR API was called with image_url type
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/ocr',
|
||||
expect.objectContaining({
|
||||
document: expect.objectContaining({
|
||||
type: 'image_url',
|
||||
image_url: 'https://signed-url.com/image.png',
|
||||
}),
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
filename: 'image.png',
|
||||
bytes: expect.any(Number),
|
||||
filepath: 'mistral_ocr',
|
||||
text: expect.stringContaining('Image OCR result'),
|
||||
images: ['imgbase64'],
|
||||
});
|
||||
});
|
||||
|
||||
it('should process variable references in configuration', async () => {
|
||||
// Setup mocks with environment variables
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
CUSTOM_API_KEY: 'custom-api-key',
|
||||
CUSTOM_BASEURL: 'https://custom-api.mistral.ai/v1',
|
||||
});
|
||||
|
||||
// Mock API responses
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: { id: 'file-123', purpose: 'ocr' },
|
||||
});
|
||||
mockAxios.get.mockResolvedValueOnce({
|
||||
data: { url: 'https://signed-url.com' },
|
||||
});
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
pages: [{ markdown: 'Content from custom API' }],
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
apiKey: '${CUSTOM_API_KEY}',
|
||||
baseURL: '${CUSTOM_BASEURL}',
|
||||
mistralModel: '${CUSTOM_MODEL}',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Set environment variable for model
|
||||
process.env.CUSTOM_MODEL = 'mistral-large';
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify that custom environment variables were extracted and used
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['CUSTOM_BASEURL', 'CUSTOM_API_KEY'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Check that mistral-large was used in the OCR API call
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
model: 'mistral-large',
|
||||
}),
|
||||
expect.anything(),
|
||||
);
|
||||
|
||||
expect(result.text).toEqual('Content from custom API\n\n');
|
||||
});
|
||||
|
||||
it('should fall back to default values when variables are not properly formatted', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'default-api-key',
|
||||
OCR_BASEURL: undefined, // Testing optional parameter
|
||||
});
|
||||
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: { id: 'file-123', purpose: 'ocr' },
|
||||
});
|
||||
mockAxios.get.mockResolvedValueOnce({
|
||||
data: { url: 'https://signed-url.com' },
|
||||
});
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
pages: [{ markdown: 'Default API result' }],
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Use environment variable syntax to ensure loadAuthValues is called
|
||||
apiKey: '${INVALID_FORMAT}', // Using valid env var format but with an invalid name
|
||||
baseURL: '${OCR_BASEURL}', // Using valid env var format
|
||||
mistralModel: 'mistral-ocr-latest', // Plain string value
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
};
|
||||
|
||||
await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Should use the default values
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['OCR_BASEURL', 'INVALID_FORMAT'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Should use the default model when not using environment variable format
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
model: 'mistral-ocr-latest',
|
||||
}),
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle API errors during OCR process', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'test-api-key',
|
||||
});
|
||||
|
||||
// Mock file upload to fail
|
||||
mockAxios.post.mockRejectedValueOnce(new Error('Upload failed'));
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
apiKey: 'OCR_API_KEY',
|
||||
baseURL: 'OCR_BASEURL',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
};
|
||||
|
||||
await expect(
|
||||
uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
}),
|
||||
).rejects.toThrow('Error uploading document to Mistral OCR API');
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
});
|
||||
|
||||
it('should handle single page documents without page numbering', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'test-api-key',
|
||||
OCR_BASEURL: 'https://api.mistral.ai/v1', // Make sure this is included
|
||||
});
|
||||
|
||||
// Clear all previous mocks
|
||||
mockAxios.post.mockClear();
|
||||
mockAxios.get.mockClear();
|
||||
|
||||
// 1. First mock: File upload response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
|
||||
);
|
||||
|
||||
// 2. Second mock: Signed URL response
|
||||
mockAxios.get.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
|
||||
);
|
||||
|
||||
// 3. Third mock: OCR response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
data: {
|
||||
pages: [{ markdown: 'Single page content' }],
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
apiKey: 'OCR_API_KEY',
|
||||
baseURL: 'OCR_BASEURL',
|
||||
mistralModel: 'mistral-ocr-latest',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'single-page.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify that single page documents don't include page numbering
|
||||
expect(result.text).not.toContain('# PAGE');
|
||||
expect(result.text).toEqual('Single page content\n\n');
|
||||
});
|
||||
|
||||
it('should use literal values in configuration when provided directly', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
// We'll still mock this but it should not be used for literal values
|
||||
loadAuthValues.mockResolvedValue({});
|
||||
|
||||
// Clear all previous mocks
|
||||
mockAxios.post.mockClear();
|
||||
mockAxios.get.mockClear();
|
||||
|
||||
// 1. First mock: File upload response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
|
||||
);
|
||||
|
||||
// 2. Second mock: Signed URL response
|
||||
mockAxios.get.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
|
||||
);
|
||||
|
||||
// 3. Third mock: OCR response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
data: {
|
||||
pages: [{ markdown: 'Processed with literal config values' }],
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Direct values that should be used as-is, without variable substitution
|
||||
apiKey: 'actual-api-key-value',
|
||||
baseURL: 'https://direct-api-url.mistral.ai/v1',
|
||||
mistralModel: 'mistral-direct-model',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'direct-values.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify the correct URL was used with the direct baseURL value
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://direct-api-url.mistral.ai/v1/files',
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: 'Bearer actual-api-key-value',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
// Check the OCR call was made with the direct model value
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://direct-api-url.mistral.ai/v1/ocr',
|
||||
expect.objectContaining({
|
||||
model: 'mistral-direct-model',
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
// Verify the result
|
||||
expect(result.text).toEqual('Processed with literal config values\n\n');
|
||||
|
||||
// Verify loadAuthValues was never called since we used direct values
|
||||
expect(loadAuthValues).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle empty configuration values and use defaults', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
// Set up the mock values to be returned by loadAuthValues
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'default-from-env-key',
|
||||
OCR_BASEURL: 'https://default-from-env.mistral.ai/v1',
|
||||
});
|
||||
|
||||
// Clear all previous mocks
|
||||
mockAxios.post.mockClear();
|
||||
mockAxios.get.mockClear();
|
||||
|
||||
// 1. First mock: File upload response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
|
||||
);
|
||||
|
||||
// 2. Second mock: Signed URL response
|
||||
mockAxios.get.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
|
||||
);
|
||||
|
||||
// 3. Third mock: OCR response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
data: {
|
||||
pages: [{ markdown: 'Content from default configuration' }],
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Empty string values - should fall back to defaults
|
||||
apiKey: '',
|
||||
baseURL: '',
|
||||
mistralModel: '',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'empty-config.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify loadAuthValues was called with the default variable names
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Verify the API calls used the default values from loadAuthValues
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://default-from-env.mistral.ai/v1/files',
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: 'Bearer default-from-env-key',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
// Verify the OCR model defaulted to mistral-ocr-latest
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://default-from-env.mistral.ai/v1/ocr',
|
||||
expect.objectContaining({
|
||||
model: 'mistral-ocr-latest',
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
// Check result
|
||||
expect(result.text).toEqual('Content from default configuration\n\n');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
const crud = require('./crud');
|
||||
|
||||
module.exports = {
|
||||
...crud,
|
||||
};
|
||||
|
|
@ -94,15 +94,28 @@ async function prepareImageURLS3(req, file) {
|
|||
* @param {Buffer} params.buffer - Avatar image buffer.
|
||||
* @param {string} params.userId - User's unique identifier.
|
||||
* @param {string} params.manual - 'true' or 'false' flag for manual update.
|
||||
* @param {string} [params.agentId] - Optional agent ID if this is an agent avatar.
|
||||
* @param {string} [params.basePath='images'] - Base path in the bucket.
|
||||
* @returns {Promise<string>} Signed URL of the uploaded avatar.
|
||||
*/
|
||||
async function processS3Avatar({ buffer, userId, manual, basePath = defaultBasePath }) {
|
||||
async function processS3Avatar({ buffer, userId, manual, agentId, basePath = defaultBasePath }) {
|
||||
try {
|
||||
const downloadURL = await saveBufferToS3({ userId, buffer, fileName: 'avatar.png', basePath });
|
||||
if (manual === 'true') {
|
||||
const metadata = await sharp(buffer).metadata();
|
||||
const extension = metadata.format === 'gif' ? 'gif' : 'png';
|
||||
const timestamp = new Date().getTime();
|
||||
|
||||
/** Unique filename with timestamp and optional agent ID */
|
||||
const fileName = agentId
|
||||
? `agent-${agentId}-avatar-${timestamp}.${extension}`
|
||||
: `avatar-${timestamp}.${extension}`;
|
||||
|
||||
const downloadURL = await saveBufferToS3({ userId, buffer, fileName, basePath });
|
||||
|
||||
// Only update user record if this is a user avatar (manual === 'true')
|
||||
if (manual === 'true' && !agentId) {
|
||||
await updateUser(userId, { avatar: downloadURL });
|
||||
}
|
||||
|
||||
return downloadURL;
|
||||
} catch (error) {
|
||||
logger.error('[processS3Avatar] Error processing S3 avatar:', error.message);
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
const fs = require('fs');
|
||||
const axios = require('axios');
|
||||
const FormData = require('form-data');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { FileSources } = require('librechat-data-provider');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Deletes a file from the vector database. This function takes a file object, constructs the full path, and
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const axios = require('axios');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const {
|
||||
FileSources,
|
||||
VisionModes,
|
||||
|
|
@ -7,8 +8,6 @@ const {
|
|||
EModelEndpoint,
|
||||
} = require('librechat-data-provider');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Converts a readable stream to a base64 encoded string.
|
||||
|
|
|
|||
|
|
@ -519,7 +519,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
|||
throw new Error('OCR capability is not enabled for Agents');
|
||||
}
|
||||
|
||||
const { handleFileUpload: uploadMistralOCR } = getStrategyFunctions(
|
||||
const { handleFileUpload: uploadOCR } = getStrategyFunctions(
|
||||
req.app.locals?.ocr?.strategy ?? FileSources.mistral_ocr,
|
||||
);
|
||||
const { file_id, temp_file_id } = metadata;
|
||||
|
|
@ -531,7 +531,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
|||
images,
|
||||
filename,
|
||||
filepath: ocrFileURL,
|
||||
} = await uploadMistralOCR({ req, file, file_id, entity_id: agent_id, basePath });
|
||||
} = await uploadOCR({ req, file, loadAuthValues });
|
||||
|
||||
const fileInfo = removeNullishValues({
|
||||
text,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const { FileSources } = require('librechat-data-provider');
|
||||
const { uploadMistralOCR, uploadAzureMistralOCR } = require('@librechat/api');
|
||||
const {
|
||||
getFirebaseURL,
|
||||
prepareImageURL,
|
||||
|
|
@ -46,7 +47,6 @@ const {
|
|||
const { uploadOpenAIFile, deleteOpenAIFile, getOpenAIFileStream } = require('./OpenAI');
|
||||
const { getCodeOutputDownloadStream, uploadCodeEnvFile } = require('./Code');
|
||||
const { uploadVectors, deleteVectors } = require('./VectorDB');
|
||||
const { uploadMistralOCR } = require('./MistralOCR');
|
||||
|
||||
/**
|
||||
* Firebase Storage Strategy Functions
|
||||
|
|
@ -190,6 +190,26 @@ const mistralOCRStrategy = () => ({
|
|||
handleFileUpload: uploadMistralOCR,
|
||||
});
|
||||
|
||||
const azureMistralOCRStrategy = () => ({
|
||||
/** @type {typeof saveFileFromURL | null} */
|
||||
saveURL: null,
|
||||
/** @type {typeof getLocalFileURL | null} */
|
||||
getFileURL: null,
|
||||
/** @type {typeof saveLocalBuffer | null} */
|
||||
saveBuffer: null,
|
||||
/** @type {typeof processLocalAvatar | null} */
|
||||
processAvatar: null,
|
||||
/** @type {typeof uploadLocalImage | null} */
|
||||
handleImageUpload: null,
|
||||
/** @type {typeof prepareImagesLocal | null} */
|
||||
prepareImagePayload: null,
|
||||
/** @type {typeof deleteLocalFile | null} */
|
||||
deleteFile: null,
|
||||
/** @type {typeof getLocalFileStream | null} */
|
||||
getDownloadStream: null,
|
||||
handleFileUpload: uploadAzureMistralOCR,
|
||||
});
|
||||
|
||||
// Strategy Selector
|
||||
const getStrategyFunctions = (fileSource) => {
|
||||
if (fileSource === FileSources.firebase) {
|
||||
|
|
@ -210,6 +230,8 @@ const getStrategyFunctions = (fileSource) => {
|
|||
return codeOutputStrategy();
|
||||
} else if (fileSource === FileSources.mistral_ocr) {
|
||||
return mistralOCRStrategy();
|
||||
} else if (fileSource === FileSources.azure_mistral_ocr) {
|
||||
return azureMistralOCRStrategy();
|
||||
} else {
|
||||
throw new Error('Invalid file source');
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const { z } = require('zod');
|
||||
const { tool } = require('@langchain/core/tools');
|
||||
const { normalizeServerName } = require('librechat-mcp');
|
||||
const { normalizeServerName } = require('@librechat/api');
|
||||
const { Constants: AgentConstants, Providers } = require('@librechat/agents');
|
||||
const {
|
||||
Constants,
|
||||
|
|
@ -50,9 +50,10 @@ async function createMCPTool({ req, toolKey, provider: _provider }) {
|
|||
|
||||
/** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise<unknown>} */
|
||||
const _call = async (toolArguments, config) => {
|
||||
const userId = config?.configurable?.user?.id || config?.configurable?.user_id;
|
||||
try {
|
||||
const derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined;
|
||||
const mcpManager = getMCPManager(config?.configurable?.user_id);
|
||||
const mcpManager = getMCPManager(userId);
|
||||
const provider = (config?.metadata?.provider || _provider)?.toLowerCase();
|
||||
const result = await mcpManager.callTool({
|
||||
serverName,
|
||||
|
|
@ -60,8 +61,8 @@ async function createMCPTool({ req, toolKey, provider: _provider }) {
|
|||
provider,
|
||||
toolArguments,
|
||||
options: {
|
||||
userId: config?.configurable?.user_id,
|
||||
signal: derivedSignal,
|
||||
user: config?.configurable?.user,
|
||||
},
|
||||
});
|
||||
|
||||
|
|
@ -74,7 +75,7 @@ async function createMCPTool({ req, toolKey, provider: _provider }) {
|
|||
return result;
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`[MCP][User: ${config?.configurable?.user_id}][${serverName}] Error calling "${toolName}" MCP tool:`,
|
||||
`[MCP][User: ${userId}][${serverName}] Error calling "${toolName}" MCP tool:`,
|
||||
error,
|
||||
);
|
||||
throw new Error(
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
const axios = require('axios');
|
||||
const { Providers } = require('@librechat/agents');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider');
|
||||
const { inputSchema, logAxiosError, extractBaseURL, processModelData } = require('~/utils');
|
||||
const { inputSchema, extractBaseURL, processModelData } = require('~/utils');
|
||||
const { OllamaClient } = require('~/app/clients/OllamaClient');
|
||||
const { isUserProvided } = require('~/server/utils');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Splits a string by commas and trims each resulting value.
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const axios = require('axios');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const {
|
||||
fetchModels,
|
||||
|
|
@ -28,7 +28,8 @@ jest.mock('~/cache/getLogStores', () =>
|
|||
set: jest.fn().mockResolvedValue(true),
|
||||
})),
|
||||
);
|
||||
jest.mock('~/config', () => ({
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
...jest.requireActual('@librechat/data-schemas'),
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const axios = require('axios');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
|
||||
/**
|
||||
* @typedef {Object} RetrieveOptions
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
const axios = require('axios');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { TokenExchangeMethodEnum } = require('librechat-data-provider');
|
||||
const { handleOAuthToken } = require('~/models/Token');
|
||||
const { decryptV2 } = require('~/server/utils/crypto');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Processes the access tokens and stores them in the database.
|
||||
|
|
@ -49,6 +50,7 @@ async function processAccessTokens(tokenData, { userId, identifier }) {
|
|||
* @param {string} fields.client_url - The URL of the OAuth provider.
|
||||
* @param {string} fields.identifier - The identifier for the token.
|
||||
* @param {string} fields.refresh_token - The refresh token to use.
|
||||
* @param {string} fields.token_exchange_method - The token exchange method ('default_post' or 'basic_auth_header').
|
||||
* @param {string} fields.encrypted_oauth_client_id - The client ID for the OAuth provider.
|
||||
* @param {string} fields.encrypted_oauth_client_secret - The client secret for the OAuth provider.
|
||||
* @returns {Promise<{
|
||||
|
|
@ -63,26 +65,36 @@ const refreshAccessToken = async ({
|
|||
client_url,
|
||||
identifier,
|
||||
refresh_token,
|
||||
token_exchange_method,
|
||||
encrypted_oauth_client_id,
|
||||
encrypted_oauth_client_secret,
|
||||
}) => {
|
||||
try {
|
||||
const oauth_client_id = await decryptV2(encrypted_oauth_client_id);
|
||||
const oauth_client_secret = await decryptV2(encrypted_oauth_client_secret);
|
||||
|
||||
const headers = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
};
|
||||
|
||||
const params = new URLSearchParams({
|
||||
client_id: oauth_client_id,
|
||||
client_secret: oauth_client_secret,
|
||||
grant_type: 'refresh_token',
|
||||
refresh_token,
|
||||
});
|
||||
|
||||
if (token_exchange_method === TokenExchangeMethodEnum.BasicAuthHeader) {
|
||||
const basicAuth = Buffer.from(`${oauth_client_id}:${oauth_client_secret}`).toString('base64');
|
||||
headers['Authorization'] = `Basic ${basicAuth}`;
|
||||
} else {
|
||||
params.append('client_id', oauth_client_id);
|
||||
params.append('client_secret', oauth_client_secret);
|
||||
}
|
||||
|
||||
const response = await axios({
|
||||
method: 'POST',
|
||||
url: client_url,
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
},
|
||||
headers,
|
||||
data: params.toString(),
|
||||
});
|
||||
await processAccessTokens(response.data, {
|
||||
|
|
@ -110,6 +122,7 @@ const refreshAccessToken = async ({
|
|||
* @param {string} fields.identifier - The identifier for the token.
|
||||
* @param {string} fields.client_url - The URL of the OAuth provider.
|
||||
* @param {string} fields.redirect_uri - The redirect URI for the OAuth provider.
|
||||
* @param {string} fields.token_exchange_method - The token exchange method ('default_post' or 'basic_auth_header').
|
||||
* @param {string} fields.encrypted_oauth_client_id - The client ID for the OAuth provider.
|
||||
* @param {string} fields.encrypted_oauth_client_secret - The client secret for the OAuth provider.
|
||||
* @returns {Promise<{
|
||||
|
|
@ -125,27 +138,37 @@ const getAccessToken = async ({
|
|||
identifier,
|
||||
client_url,
|
||||
redirect_uri,
|
||||
token_exchange_method,
|
||||
encrypted_oauth_client_id,
|
||||
encrypted_oauth_client_secret,
|
||||
}) => {
|
||||
const oauth_client_id = await decryptV2(encrypted_oauth_client_id);
|
||||
const oauth_client_secret = await decryptV2(encrypted_oauth_client_secret);
|
||||
|
||||
const headers = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
};
|
||||
|
||||
const params = new URLSearchParams({
|
||||
code,
|
||||
client_id: oauth_client_id,
|
||||
client_secret: oauth_client_secret,
|
||||
grant_type: 'authorization_code',
|
||||
redirect_uri,
|
||||
});
|
||||
|
||||
if (token_exchange_method === TokenExchangeMethodEnum.BasicAuthHeader) {
|
||||
const basicAuth = Buffer.from(`${oauth_client_id}:${oauth_client_secret}`).toString('base64');
|
||||
headers['Authorization'] = `Basic ${basicAuth}`;
|
||||
} else {
|
||||
params.append('client_id', oauth_client_id);
|
||||
params.append('client_secret', oauth_client_secret);
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await axios({
|
||||
method: 'POST',
|
||||
url: client_url,
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
},
|
||||
headers,
|
||||
data: params.toString(),
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1,64 +0,0 @@
|
|||
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
class Tokenizer {
|
||||
constructor() {
|
||||
this.tokenizersCache = {};
|
||||
this.tokenizerCallsCount = 0;
|
||||
}
|
||||
|
||||
getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
|
||||
let tokenizer;
|
||||
if (this.tokenizersCache[encoding]) {
|
||||
tokenizer = this.tokenizersCache[encoding];
|
||||
} else {
|
||||
if (isModelName) {
|
||||
tokenizer = encodingForModel(encoding, extendSpecialTokens);
|
||||
} else {
|
||||
tokenizer = getEncoding(encoding, extendSpecialTokens);
|
||||
}
|
||||
this.tokenizersCache[encoding] = tokenizer;
|
||||
}
|
||||
return tokenizer;
|
||||
}
|
||||
|
||||
freeAndResetAllEncoders() {
|
||||
try {
|
||||
Object.keys(this.tokenizersCache).forEach((key) => {
|
||||
if (this.tokenizersCache[key]) {
|
||||
this.tokenizersCache[key].free();
|
||||
delete this.tokenizersCache[key];
|
||||
}
|
||||
});
|
||||
this.tokenizerCallsCount = 1;
|
||||
} catch (error) {
|
||||
logger.error('[Tokenizer] Free and reset encoders error', error);
|
||||
}
|
||||
}
|
||||
|
||||
resetTokenizersIfNecessary() {
|
||||
if (this.tokenizerCallsCount >= 25) {
|
||||
if (this.options?.debug) {
|
||||
logger.debug('[Tokenizer] freeAndResetAllEncoders: reached 25 encodings, resetting...');
|
||||
}
|
||||
this.freeAndResetAllEncoders();
|
||||
}
|
||||
this.tokenizerCallsCount++;
|
||||
}
|
||||
|
||||
getTokenCount(text, encoding = 'cl100k_base') {
|
||||
this.resetTokenizersIfNecessary();
|
||||
try {
|
||||
const tokenizer = this.getTokenizer(encoding);
|
||||
return tokenizer.encode(text, 'all').length;
|
||||
} catch (error) {
|
||||
this.freeAndResetAllEncoders();
|
||||
const tokenizer = this.getTokenizer(encoding);
|
||||
return tokenizer.encode(text, 'all').length;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const TokenizerSingleton = new Tokenizer();
|
||||
|
||||
module.exports = TokenizerSingleton;
|
||||
|
|
@ -1,136 +0,0 @@
|
|||
/**
|
||||
* @file Tokenizer.spec.cjs
|
||||
*
|
||||
* Tests the real TokenizerSingleton (no mocking of `tiktoken`).
|
||||
* Make sure to install `tiktoken` and have it configured properly.
|
||||
*/
|
||||
|
||||
const Tokenizer = require('./Tokenizer'); // <-- Adjust path to your singleton file
|
||||
const { logger } = require('~/config');
|
||||
|
||||
describe('Tokenizer', () => {
|
||||
it('should be a singleton (same instance)', () => {
|
||||
const AnotherTokenizer = require('./Tokenizer'); // same path
|
||||
expect(Tokenizer).toBe(AnotherTokenizer);
|
||||
});
|
||||
|
||||
describe('getTokenizer', () => {
|
||||
it('should create an encoder for an explicit model name (e.g., "gpt-4")', () => {
|
||||
// The real `encoding_for_model` will be called internally
|
||||
// as soon as we pass isModelName = true.
|
||||
const tokenizer = Tokenizer.getTokenizer('gpt-4', true);
|
||||
|
||||
// Basic sanity checks
|
||||
expect(tokenizer).toBeDefined();
|
||||
// You can optionally check certain properties from `tiktoken` if they exist
|
||||
// e.g., expect(typeof tokenizer.encode).toBe('function');
|
||||
});
|
||||
|
||||
it('should create an encoder for a known encoding (e.g., "cl100k_base")', () => {
|
||||
// The real `get_encoding` will be called internally
|
||||
// as soon as we pass isModelName = false.
|
||||
const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
|
||||
expect(tokenizer).toBeDefined();
|
||||
// e.g., expect(typeof tokenizer.encode).toBe('function');
|
||||
});
|
||||
|
||||
it('should return cached tokenizer if previously fetched', () => {
|
||||
const tokenizer1 = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
const tokenizer2 = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
// Should be the exact same instance from the cache
|
||||
expect(tokenizer1).toBe(tokenizer2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('freeAndResetAllEncoders', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should free all encoders and reset tokenizerCallsCount to 1', () => {
|
||||
// By creating two different encodings, we populate the cache
|
||||
Tokenizer.getTokenizer('cl100k_base', false);
|
||||
Tokenizer.getTokenizer('r50k_base', false);
|
||||
|
||||
// Now free them
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
// The internal cache is cleared
|
||||
expect(Tokenizer.tokenizersCache['cl100k_base']).toBeUndefined();
|
||||
expect(Tokenizer.tokenizersCache['r50k_base']).toBeUndefined();
|
||||
|
||||
// tokenizerCallsCount is reset to 1
|
||||
expect(Tokenizer.tokenizerCallsCount).toBe(1);
|
||||
});
|
||||
|
||||
it('should catch and log errors if freeing fails', () => {
|
||||
// Mock logger.error before the test
|
||||
const mockLoggerError = jest.spyOn(logger, 'error');
|
||||
|
||||
// Set up a problematic tokenizer in the cache
|
||||
Tokenizer.tokenizersCache['cl100k_base'] = {
|
||||
free() {
|
||||
throw new Error('Intentional free error');
|
||||
},
|
||||
};
|
||||
|
||||
// Should not throw uncaught errors
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
// Verify logger.error was called with correct arguments
|
||||
expect(mockLoggerError).toHaveBeenCalledWith(
|
||||
'[Tokenizer] Free and reset encoders error',
|
||||
expect.any(Error),
|
||||
);
|
||||
|
||||
// Clean up
|
||||
mockLoggerError.mockRestore();
|
||||
Tokenizer.tokenizersCache = {};
|
||||
});
|
||||
});
|
||||
|
||||
describe('getTokenCount', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
});
|
||||
|
||||
it('should return the number of tokens in the given text', () => {
|
||||
const text = 'Hello, world!';
|
||||
const count = Tokenizer.getTokenCount(text, 'cl100k_base');
|
||||
expect(count).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should reset encoders if an error is thrown', () => {
|
||||
// We can simulate an error by temporarily overriding the selected tokenizer’s `encode` method.
|
||||
const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
const originalEncode = tokenizer.encode;
|
||||
tokenizer.encode = () => {
|
||||
throw new Error('Forced error');
|
||||
};
|
||||
|
||||
// Despite the forced error, the code should catch and reset, then re-encode
|
||||
const count = Tokenizer.getTokenCount('Hello again', 'cl100k_base');
|
||||
expect(count).toBeGreaterThan(0);
|
||||
|
||||
// Restore the original encode
|
||||
tokenizer.encode = originalEncode;
|
||||
});
|
||||
|
||||
it('should reset tokenizers after 25 calls', () => {
|
||||
// Spy on freeAndResetAllEncoders
|
||||
const resetSpy = jest.spyOn(Tokenizer, 'freeAndResetAllEncoders');
|
||||
|
||||
// Make 24 calls; should NOT reset yet
|
||||
for (let i = 0; i < 24; i++) {
|
||||
Tokenizer.getTokenCount('test text', 'cl100k_base');
|
||||
}
|
||||
expect(resetSpy).not.toHaveBeenCalled();
|
||||
|
||||
// 25th call triggers the reset
|
||||
Tokenizer.getTokenCount('the 25th call!', 'cl100k_base');
|
||||
expect(resetSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -500,6 +500,8 @@ async function processRequiredActions(client, requiredActions) {
|
|||
async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) {
|
||||
if (!agent.tools || agent.tools.length === 0) {
|
||||
return {};
|
||||
} else if (agent.tools && agent.tools.length === 1 && agent.tools[0] === AgentCapabilities.ocr) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const endpointsConfig = await getEndpointsConfig(req);
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
const { EModelEndpoint, agentsEndpointSChema } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* Sets up the Agents configuration from the config (`librechat.yaml`) file.
|
||||
* @param {TCustomConfig} config - The loaded custom configuration.
|
||||
* @returns {Partial<TAgentsEndpoint>} The Agents endpoint configuration.
|
||||
*/
|
||||
function agentsConfigSetup(config) {
|
||||
const agentsConfig = config.endpoints[EModelEndpoint.agents];
|
||||
const parsedConfig = agentsEndpointSChema.parse(agentsConfig);
|
||||
return parsedConfig;
|
||||
}
|
||||
|
||||
module.exports = { agentsConfigSetup };
|
||||
|
|
@ -2,6 +2,7 @@ const {
|
|||
SystemRoles,
|
||||
Permissions,
|
||||
PermissionTypes,
|
||||
isMemoryEnabled,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { updateAccessPermissions } = require('~/models/Role');
|
||||
|
|
@ -20,6 +21,14 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol
|
|||
const hasModelSpecs = config?.modelSpecs?.list?.length > 0;
|
||||
const includesAddedEndpoints = config?.modelSpecs?.addedEndpoints?.length > 0;
|
||||
|
||||
const memoryConfig = config?.memory;
|
||||
const memoryEnabled = isMemoryEnabled(memoryConfig);
|
||||
/** Only disable memories if memory config is present but disabled/invalid */
|
||||
const shouldDisableMemories = memoryConfig && !memoryEnabled;
|
||||
/** Check if personalization is enabled (defaults to true if memory is configured and enabled) */
|
||||
const isPersonalizationEnabled =
|
||||
memoryConfig && memoryEnabled && memoryConfig.personalize !== false;
|
||||
|
||||
/** @type {TCustomConfig['interface']} */
|
||||
const loadedInterface = removeNullishValues({
|
||||
endpointsMenu:
|
||||
|
|
@ -33,6 +42,7 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol
|
|||
privacyPolicy: interfaceConfig?.privacyPolicy ?? defaults.privacyPolicy,
|
||||
termsOfService: interfaceConfig?.termsOfService ?? defaults.termsOfService,
|
||||
bookmarks: interfaceConfig?.bookmarks ?? defaults.bookmarks,
|
||||
memories: shouldDisableMemories ? false : (interfaceConfig?.memories ?? defaults.memories),
|
||||
prompts: interfaceConfig?.prompts ?? defaults.prompts,
|
||||
multiConvo: interfaceConfig?.multiConvo ?? defaults.multiConvo,
|
||||
agents: interfaceConfig?.agents ?? defaults.agents,
|
||||
|
|
@ -45,6 +55,10 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol
|
|||
await updateAccessPermissions(roleName, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks },
|
||||
[PermissionTypes.MEMORIES]: {
|
||||
[Permissions.USE]: loadedInterface.memories,
|
||||
[Permissions.OPT_OUT]: isPersonalizationEnabled,
|
||||
},
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat },
|
||||
|
|
@ -54,6 +68,10 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol
|
|||
await updateAccessPermissions(SystemRoles.ADMIN, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks },
|
||||
[PermissionTypes.MEMORIES]: {
|
||||
[Permissions.USE]: loadedInterface.memories,
|
||||
[Permissions.OPT_OUT]: isPersonalizationEnabled,
|
||||
},
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat },
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: true,
|
||||
bookmarks: true,
|
||||
memories: true,
|
||||
multiConvo: true,
|
||||
agents: true,
|
||||
temporaryChat: true,
|
||||
|
|
@ -26,6 +27,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
|
||||
|
|
@ -39,6 +41,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: false,
|
||||
bookmarks: false,
|
||||
memories: false,
|
||||
multiConvo: false,
|
||||
agents: false,
|
||||
temporaryChat: false,
|
||||
|
|
@ -53,6 +56,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: false },
|
||||
|
|
@ -70,6 +74,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
|
|
@ -83,6 +88,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: undefined,
|
||||
bookmarks: undefined,
|
||||
memories: undefined,
|
||||
multiConvo: undefined,
|
||||
agents: undefined,
|
||||
temporaryChat: undefined,
|
||||
|
|
@ -97,6 +103,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
|
|
@ -110,6 +117,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: true,
|
||||
bookmarks: false,
|
||||
memories: true,
|
||||
multiConvo: undefined,
|
||||
agents: true,
|
||||
temporaryChat: undefined,
|
||||
|
|
@ -124,6 +132,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
|
|
@ -138,6 +147,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: true,
|
||||
bookmarks: true,
|
||||
memories: true,
|
||||
multiConvo: true,
|
||||
agents: true,
|
||||
temporaryChat: true,
|
||||
|
|
@ -151,6 +161,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
|
||||
|
|
@ -168,6 +179,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
|
|
@ -185,6 +197,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
|
|
@ -202,6 +215,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
|
|
@ -215,6 +229,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: true,
|
||||
bookmarks: false,
|
||||
memories: true,
|
||||
multiConvo: true,
|
||||
agents: false,
|
||||
temporaryChat: true,
|
||||
|
|
@ -228,6 +243,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
|
||||
|
|
@ -242,6 +258,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: true,
|
||||
bookmarks: true,
|
||||
memories: false,
|
||||
multiConvo: false,
|
||||
agents: undefined,
|
||||
temporaryChat: undefined,
|
||||
|
|
@ -255,6 +272,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
|
|
@ -268,6 +286,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: true,
|
||||
bookmarks: false,
|
||||
memories: true,
|
||||
multiConvo: true,
|
||||
agents: false,
|
||||
temporaryChat: true,
|
||||
|
|
@ -281,6 +300,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
const path = require('path');
|
||||
const crypto = require('crypto');
|
||||
const {
|
||||
Capabilities,
|
||||
EModelEndpoint,
|
||||
|
|
@ -218,38 +216,6 @@ function normalizeEndpointName(name = '') {
|
|||
return name.toLowerCase() === Providers.OLLAMA ? Providers.OLLAMA : name;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitize a filename by removing any directory components, replacing non-alphanumeric characters
|
||||
* @param {string} inputName
|
||||
* @returns {string}
|
||||
*/
|
||||
function sanitizeFilename(inputName) {
|
||||
// Remove any directory components
|
||||
let name = path.basename(inputName);
|
||||
|
||||
// Replace any non-alphanumeric characters except for '.' and '-'
|
||||
name = name.replace(/[^a-zA-Z0-9.-]/g, '_');
|
||||
|
||||
// Ensure the name doesn't start with a dot (hidden file in Unix-like systems)
|
||||
if (name.startsWith('.') || name === '') {
|
||||
name = '_' + name;
|
||||
}
|
||||
|
||||
// Limit the length of the filename
|
||||
const MAX_LENGTH = 255;
|
||||
if (name.length > MAX_LENGTH) {
|
||||
const ext = path.extname(name);
|
||||
const nameWithoutExt = path.basename(name, ext);
|
||||
name =
|
||||
nameWithoutExt.slice(0, MAX_LENGTH - ext.length - 7) +
|
||||
'-' +
|
||||
crypto.randomBytes(3).toString('hex') +
|
||||
ext;
|
||||
}
|
||||
|
||||
return name;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
isEnabled,
|
||||
handleText,
|
||||
|
|
@ -260,6 +226,5 @@ module.exports = {
|
|||
generateConfig,
|
||||
addSpaceIfNeeded,
|
||||
createOnProgress,
|
||||
sanitizeFilename,
|
||||
normalizeEndpointName,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,103 +0,0 @@
|
|||
const { isEnabled, sanitizeFilename } = require('./handleText');
|
||||
|
||||
describe('isEnabled', () => {
|
||||
test('should return true when input is "true"', () => {
|
||||
expect(isEnabled('true')).toBe(true);
|
||||
});
|
||||
|
||||
test('should return true when input is "TRUE"', () => {
|
||||
expect(isEnabled('TRUE')).toBe(true);
|
||||
});
|
||||
|
||||
test('should return true when input is true', () => {
|
||||
expect(isEnabled(true)).toBe(true);
|
||||
});
|
||||
|
||||
test('should return false when input is "false"', () => {
|
||||
expect(isEnabled('false')).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is false', () => {
|
||||
expect(isEnabled(false)).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is null', () => {
|
||||
expect(isEnabled(null)).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is undefined', () => {
|
||||
expect(isEnabled()).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is an empty string', () => {
|
||||
expect(isEnabled('')).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is a whitespace string', () => {
|
||||
expect(isEnabled(' ')).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is a number', () => {
|
||||
expect(isEnabled(123)).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is an object', () => {
|
||||
expect(isEnabled({})).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is an array', () => {
|
||||
expect(isEnabled([])).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
jest.mock('crypto', () => {
|
||||
const actualModule = jest.requireActual('crypto');
|
||||
return {
|
||||
...actualModule,
|
||||
randomBytes: jest.fn().mockReturnValue(Buffer.from('abc123', 'hex')),
|
||||
};
|
||||
});
|
||||
|
||||
describe('sanitizeFilename', () => {
|
||||
test('removes directory components (1/2)', () => {
|
||||
expect(sanitizeFilename('/path/to/file.txt')).toBe('file.txt');
|
||||
});
|
||||
|
||||
test('removes directory components (2/2)', () => {
|
||||
expect(sanitizeFilename('../../../../file.txt')).toBe('file.txt');
|
||||
});
|
||||
|
||||
test('replaces non-alphanumeric characters', () => {
|
||||
expect(sanitizeFilename('file name@#$.txt')).toBe('file_name___.txt');
|
||||
});
|
||||
|
||||
test('preserves dots and hyphens', () => {
|
||||
expect(sanitizeFilename('file-name.with.dots.txt')).toBe('file-name.with.dots.txt');
|
||||
});
|
||||
|
||||
test('prepends underscore to filenames starting with a dot', () => {
|
||||
expect(sanitizeFilename('.hiddenfile')).toBe('_.hiddenfile');
|
||||
});
|
||||
|
||||
test('truncates long filenames', () => {
|
||||
const longName = 'a'.repeat(300) + '.txt';
|
||||
const result = sanitizeFilename(longName);
|
||||
expect(result.length).toBe(255);
|
||||
expect(result).toMatch(/^a+-abc123\.txt$/);
|
||||
});
|
||||
|
||||
test('handles filenames with no extension', () => {
|
||||
const longName = 'a'.repeat(300);
|
||||
const result = sanitizeFilename(longName);
|
||||
expect(result.length).toBe(255);
|
||||
expect(result).toMatch(/^a+-abc123$/);
|
||||
});
|
||||
|
||||
test('handles empty input', () => {
|
||||
expect(sanitizeFilename('')).toBe('_');
|
||||
});
|
||||
|
||||
test('handles input with only special characters', () => {
|
||||
expect(sanitizeFilename('@#$%^&*')).toBe('_______');
|
||||
});
|
||||
});
|
||||
|
|
@ -1,395 +0,0 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { Strategy: AppleStrategy } = require('passport-apple');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
|
||||
const { User } = require('~/db/models');
|
||||
|
||||
const findUser = jest.fn();
|
||||
jest.mock('jsonwebtoken');
|
||||
jest.mock('@librechat/data-schemas', () => {
|
||||
const actualModule = jest.requireActual('@librechat/data-schemas');
|
||||
return {
|
||||
...actualModule,
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
createMethods: jest.fn(() => {
|
||||
return { findUser };
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('../../packages/auth/src/strategies/helpers', () => {
|
||||
const actualModule = jest.requireActual('../../packages/auth/src/strategies/helpers');
|
||||
return {
|
||||
...actualModule,
|
||||
createSocialUser: jest.fn(),
|
||||
handleExistingUser: jest.fn(),
|
||||
};
|
||||
});
|
||||
jest.mock('~/server/utils', () => ({
|
||||
isEnabled: jest.fn(),
|
||||
}));
|
||||
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const {
|
||||
createSocialUser,
|
||||
handleExistingUser,
|
||||
} = require('../../packages/auth/src/strategies/helpers');
|
||||
const { socialLogin } = require('../../packages/auth/src/strategies/socialLogin');
|
||||
describe('Apple Login Strategy', () => {
|
||||
let mongoServer;
|
||||
let appleStrategyInstance;
|
||||
const OLD_ENV = process.env;
|
||||
let getProfileDetails;
|
||||
|
||||
// Start and stop in-memory MongoDB
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
process.env = OLD_ENV;
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
// Reset environment variables
|
||||
process.env = { ...OLD_ENV };
|
||||
process.env.APPLE_CLIENT_ID = 'fake_client_id';
|
||||
process.env.APPLE_TEAM_ID = 'fake_team_id';
|
||||
process.env.APPLE_CALLBACK_URL = '/auth/apple/callback';
|
||||
process.env.DOMAIN_SERVER = 'https://example.com';
|
||||
process.env.APPLE_KEY_ID = 'fake_key_id';
|
||||
process.env.APPLE_PRIVATE_KEY_PATH = '/path/to/fake/private/key';
|
||||
process.env.ALLOW_SOCIAL_REGISTRATION = 'true';
|
||||
|
||||
// Clear mocks and database
|
||||
jest.clearAllMocks();
|
||||
await User.deleteMany({});
|
||||
|
||||
// Define getProfileDetails within the test scope
|
||||
getProfileDetails = ({ idToken, profile }) => {
|
||||
if (!idToken) {
|
||||
logger.error('idToken is missing');
|
||||
throw new Error('idToken is missing');
|
||||
}
|
||||
|
||||
const decoded = jwt.decode(idToken);
|
||||
if (!decoded) {
|
||||
logger.error('Failed to decode idToken');
|
||||
throw new Error('idToken is invalid');
|
||||
}
|
||||
|
||||
console.log('Decoded token:', decoded);
|
||||
|
||||
logger.debug(`Decoded Apple JWT: ${JSON.stringify(decoded, null, 2)}`);
|
||||
|
||||
return {
|
||||
email: decoded.email,
|
||||
id: decoded.sub,
|
||||
avatarUrl: null, // Apple does not provide an avatar URL
|
||||
username: decoded.email ? decoded.email.split('@')[0].toLowerCase() : `user_${decoded.sub}`,
|
||||
name: decoded.name
|
||||
? `${decoded.name.firstName} ${decoded.name.lastName}`
|
||||
: profile.displayName || null,
|
||||
emailVerified: true, // Apple verifies the email
|
||||
};
|
||||
};
|
||||
|
||||
// Mock isEnabled based on environment variable
|
||||
isEnabled.mockImplementation((flag) => {
|
||||
if (flag === 'true') {
|
||||
return true;
|
||||
}
|
||||
if (flag === 'false') {
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
// Initialize the strategy with the mocked getProfileDetails
|
||||
const appleLogin = socialLogin('apple', getProfileDetails);
|
||||
|
||||
appleStrategyInstance = new AppleStrategy(
|
||||
{
|
||||
clientID: process.env.APPLE_CLIENT_ID,
|
||||
teamID: process.env.APPLE_TEAM_ID,
|
||||
callbackURL: `${process.env.DOMAIN_SERVER}${process.env.APPLE_CALLBACK_URL}`,
|
||||
keyID: process.env.APPLE_KEY_ID,
|
||||
privateKeyLocation: process.env.APPLE_PRIVATE_KEY_PATH,
|
||||
passReqToCallback: false,
|
||||
},
|
||||
appleLogin,
|
||||
);
|
||||
});
|
||||
|
||||
const mockProfile = {
|
||||
displayName: 'John Doe',
|
||||
};
|
||||
|
||||
describe('getProfileDetails', () => {
|
||||
it('should throw an error if idToken is missing', () => {
|
||||
expect(() => {
|
||||
getProfileDetails({ idToken: null, profile: mockProfile });
|
||||
}).toThrow('idToken is missing');
|
||||
expect(logger.error).toHaveBeenCalledWith('idToken is missing');
|
||||
});
|
||||
|
||||
it('should throw an error if idToken cannot be decoded', () => {
|
||||
jwt.decode.mockReturnValue(null);
|
||||
expect(() => {
|
||||
getProfileDetails({ idToken: 'invalid_id_token', profile: mockProfile });
|
||||
}).toThrow('idToken is invalid');
|
||||
expect(logger.error).toHaveBeenCalledWith('Failed to decode idToken');
|
||||
});
|
||||
|
||||
it('should extract user details correctly from idToken', () => {
|
||||
const fakeDecodedToken = {
|
||||
email: 'john.doe@example.com',
|
||||
sub: 'apple-sub-1234',
|
||||
name: {
|
||||
firstName: 'John',
|
||||
lastName: 'Doe',
|
||||
},
|
||||
};
|
||||
|
||||
jwt.decode.mockReturnValue(fakeDecodedToken);
|
||||
|
||||
const profileDetails = getProfileDetails({
|
||||
idToken: 'fake_id_token',
|
||||
profile: mockProfile,
|
||||
});
|
||||
|
||||
expect(jwt.decode).toHaveBeenCalledWith('fake_id_token');
|
||||
expect(logger.debug).toHaveBeenCalledWith(expect.stringContaining('Decoded Apple JWT'));
|
||||
expect(profileDetails).toEqual({
|
||||
email: 'john.doe@example.com',
|
||||
id: 'apple-sub-1234',
|
||||
avatarUrl: null,
|
||||
username: 'john.doe',
|
||||
name: 'John Doe',
|
||||
emailVerified: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle missing email and use sub for username', () => {
|
||||
const fakeDecodedToken = {
|
||||
sub: 'apple-sub-5678',
|
||||
};
|
||||
|
||||
jwt.decode.mockReturnValue(fakeDecodedToken);
|
||||
|
||||
const profileDetails = getProfileDetails({
|
||||
idToken: 'fake_id_token',
|
||||
profile: mockProfile,
|
||||
});
|
||||
|
||||
expect(profileDetails).toEqual({
|
||||
email: undefined,
|
||||
id: 'apple-sub-5678',
|
||||
avatarUrl: null,
|
||||
username: 'user_apple-sub-5678',
|
||||
name: 'John Doe',
|
||||
emailVerified: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Strategy verify callback', () => {
|
||||
const tokenset = {
|
||||
id_token: 'fake_id_token',
|
||||
};
|
||||
|
||||
const decodedToken = {
|
||||
email: 'jane.doe@example.com',
|
||||
sub: 'apple-sub-9012',
|
||||
name: {
|
||||
firstName: 'Jane',
|
||||
lastName: 'Doe',
|
||||
},
|
||||
};
|
||||
|
||||
const fakeAccessToken = 'fake_access_token';
|
||||
const fakeRefreshToken = 'fake_refresh_token';
|
||||
|
||||
beforeEach(async () => {
|
||||
jwt.decode.mockReturnValue(decodedToken);
|
||||
findUser.mockResolvedValue(null);
|
||||
|
||||
const { initAuth } = require('../../packages/auth/src/initAuth');
|
||||
const saveBufferMock = jest.fn().mockResolvedValue('/fake/path/to/avatar.png');
|
||||
await initAuth(mongoose, { enabled: false }, saveBufferMock); // mongoose: {}, fake balance config, dummy saveBuffer
|
||||
});
|
||||
|
||||
it('should create a new user if one does not exist and registration is allowed', async () => {
|
||||
// Mock findUser to return null (user does not exist)
|
||||
findUser.mockResolvedValue(null);
|
||||
|
||||
// Mock createSocialUser to create a user
|
||||
createSocialUser.mockImplementation(async (userData) => {
|
||||
const user = new User(userData);
|
||||
await user.save();
|
||||
return user;
|
||||
});
|
||||
|
||||
const mockVerifyCallback = jest.fn();
|
||||
|
||||
// Invoke the verify callback with correct arguments
|
||||
await new Promise((resolve) => {
|
||||
appleStrategyInstance._verify(
|
||||
fakeAccessToken,
|
||||
fakeRefreshToken,
|
||||
tokenset.id_token,
|
||||
mockProfile,
|
||||
(err, user) => {
|
||||
mockVerifyCallback(err, user);
|
||||
resolve();
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
expect(mockVerifyCallback).toHaveBeenCalledWith(
|
||||
null,
|
||||
expect.objectContaining({ email: 'jane.doe@example.com' }),
|
||||
);
|
||||
const user = mockVerifyCallback.mock.calls[0][1];
|
||||
expect(user.email).toBe('jane.doe@example.com');
|
||||
expect(user.username).toBe('jane.doe');
|
||||
expect(user.name).toBe('Jane Doe');
|
||||
expect(user.provider).toBe('apple');
|
||||
});
|
||||
|
||||
it('should handle existing user and update avatarUrl', async () => {
|
||||
// Create an existing user without saving to database
|
||||
const existingUser = new User({
|
||||
email: 'jane.doe@example.com',
|
||||
username: 'jane.doe',
|
||||
name: 'Jane Doe',
|
||||
provider: 'apple',
|
||||
providerId: 'apple-sub-9012',
|
||||
avatarUrl: 'old_avatar.png',
|
||||
});
|
||||
|
||||
// Mock findUser to return the existing user
|
||||
findUser.mockResolvedValue(existingUser);
|
||||
|
||||
// Mock handleExistingUser to update avatarUrl without saving to database
|
||||
handleExistingUser.mockImplementation(async (user, avatarUrl) => {
|
||||
user.avatarUrl = avatarUrl;
|
||||
// Don't call save() to avoid database operations
|
||||
return user;
|
||||
});
|
||||
|
||||
const mockVerifyCallback = jest.fn();
|
||||
|
||||
// Invoke the verify callback with correct arguments
|
||||
await new Promise((resolve) => {
|
||||
appleStrategyInstance._verify(
|
||||
fakeAccessToken,
|
||||
fakeRefreshToken,
|
||||
tokenset.id_token,
|
||||
mockProfile,
|
||||
(err, user) => {
|
||||
mockVerifyCallback(err, user);
|
||||
resolve();
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
expect(mockVerifyCallback).toHaveBeenCalledWith(null, existingUser);
|
||||
expect(existingUser.avatarUrl).toBeNull(); // As per getProfileDetails
|
||||
expect(handleExistingUser).toHaveBeenCalledWith(existingUser, null);
|
||||
});
|
||||
|
||||
it('should handle missing idToken gracefully', async () => {
|
||||
const mockVerifyCallback = jest.fn();
|
||||
|
||||
// Invoke the verify callback with missing id_token
|
||||
await new Promise((resolve) => {
|
||||
appleStrategyInstance._verify(
|
||||
fakeAccessToken,
|
||||
fakeRefreshToken,
|
||||
null, // idToken is missing
|
||||
mockProfile,
|
||||
(err, user) => {
|
||||
mockVerifyCallback(err, user);
|
||||
resolve();
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
expect(mockVerifyCallback).toHaveBeenCalledWith(expect.any(Error), undefined);
|
||||
expect(mockVerifyCallback.mock.calls[0][0].message).toBe('idToken is missing');
|
||||
// Ensure createSocialUser and handleExistingUser were not called
|
||||
expect(createSocialUser).not.toHaveBeenCalled();
|
||||
expect(handleExistingUser).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle decoding errors gracefully', async () => {
|
||||
// Simulate decoding failure by returning null
|
||||
jwt.decode.mockReturnValue(null);
|
||||
|
||||
const mockVerifyCallback = jest.fn();
|
||||
|
||||
// Invoke the verify callback with correct arguments
|
||||
await new Promise((resolve) => {
|
||||
appleStrategyInstance._verify(
|
||||
fakeAccessToken,
|
||||
fakeRefreshToken,
|
||||
tokenset.id_token,
|
||||
mockProfile,
|
||||
(err, user) => {
|
||||
mockVerifyCallback(err, user);
|
||||
resolve();
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
expect(mockVerifyCallback).toHaveBeenCalledWith(expect.any(Error), undefined);
|
||||
expect(mockVerifyCallback.mock.calls[0][0].message).toBe('idToken is invalid');
|
||||
// Ensure createSocialUser and handleExistingUser were not called
|
||||
expect(createSocialUser).not.toHaveBeenCalled();
|
||||
expect(handleExistingUser).not.toHaveBeenCalled();
|
||||
// Ensure logger.error was called
|
||||
expect(logger.error).toHaveBeenCalledWith('Failed to decode idToken');
|
||||
});
|
||||
|
||||
it('should handle errors during user creation', async () => {
|
||||
// Mock findUser to return null (user does not exist)
|
||||
findUser.mockResolvedValue(null);
|
||||
|
||||
// Mock createSocialUser to throw an error
|
||||
createSocialUser.mockImplementation(() => {
|
||||
throw new Error('Database error');
|
||||
});
|
||||
|
||||
const mockVerifyCallback = jest.fn();
|
||||
|
||||
// Invoke the verify callback with correct arguments
|
||||
await new Promise((resolve) => {
|
||||
appleStrategyInstance._verify(
|
||||
fakeAccessToken,
|
||||
fakeRefreshToken,
|
||||
tokenset.id_token,
|
||||
mockProfile,
|
||||
(err, user) => {
|
||||
mockVerifyCallback(err, user);
|
||||
resolve();
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
expect(mockVerifyCallback).toHaveBeenCalledWith(expect.any(Error), undefined);
|
||||
expect(mockVerifyCallback.mock.calls[0][0].message).toBe('Database error');
|
||||
// Ensure logger.error was called
|
||||
expect(logger.error).toHaveBeenCalledWith('[appleLogin]', expect.any(Error));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,364 +0,0 @@
|
|||
const passport = require('passport');
|
||||
const mongoose = require('mongoose');
|
||||
// --- Mocks ---
|
||||
jest.mock('jsonwebtoken/decode');
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getBalanceConfig: jest.fn(() => ({
|
||||
enabled: false,
|
||||
})),
|
||||
}));
|
||||
|
||||
const mockedMethods = {
|
||||
findUser: jest.fn(),
|
||||
createUser: jest.fn(),
|
||||
updateUser: jest.fn(),
|
||||
};
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => {
|
||||
const actual = jest.requireActual('@librechat/data-schemas');
|
||||
return {
|
||||
...actual,
|
||||
createMethods: jest.fn(() => mockedMethods),
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('~/server/utils/crypto', () => ({
|
||||
hashToken: jest.fn().mockResolvedValue('hashed-token'),
|
||||
}));
|
||||
jest.mock('~/server/utils', () => ({
|
||||
isEnabled: jest.fn(() => false),
|
||||
}));
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
},
|
||||
}));
|
||||
jest.mock('~/cache/getLogStores', () =>
|
||||
jest.fn(() => ({
|
||||
get: jest.fn(),
|
||||
set: jest.fn(),
|
||||
})),
|
||||
);
|
||||
|
||||
// Mock the openid-client module and all its dependencies
|
||||
jest.mock('openid-client', () => {
|
||||
const actual = jest.requireActual('openid-client');
|
||||
return {
|
||||
...actual,
|
||||
discovery: jest.fn().mockResolvedValue({
|
||||
clientId: 'fake_client_id',
|
||||
clientSecret: 'fake_client_secret',
|
||||
issuer: 'https://fake-issuer.com',
|
||||
// Add any other properties needed by the implementation
|
||||
}),
|
||||
fetchUserInfo: jest.fn().mockImplementation((config, accessToken, sub) => {
|
||||
// Only return additional properties, but don't override any claims
|
||||
return Promise.resolve({
|
||||
preferred_username: 'preferred_username',
|
||||
});
|
||||
}),
|
||||
customFetch: Symbol('customFetch'),
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('openid-client/passport', () => {
|
||||
let verifyCallback;
|
||||
const mockConstructor = jest.fn((options, verify) => {
|
||||
verifyCallback = verify;
|
||||
return {
|
||||
name: 'openid',
|
||||
options,
|
||||
verify,
|
||||
};
|
||||
});
|
||||
|
||||
return {
|
||||
Strategy: mockConstructor,
|
||||
__getVerifyCallback: () => verifyCallback,
|
||||
};
|
||||
});
|
||||
|
||||
// Mock passport
|
||||
jest.mock('passport', () => ({
|
||||
use: jest.fn(),
|
||||
}));
|
||||
|
||||
const jwtDecode = require('jsonwebtoken/decode');
|
||||
|
||||
describe('setupOpenId', () => {
|
||||
// Store a reference to the verify callback once it's set up
|
||||
let verifyCallback;
|
||||
|
||||
// Helper to wrap the verify callback in a promise
|
||||
const validate = (tokenset) =>
|
||||
new Promise((resolve, reject) => {
|
||||
verifyCallback(tokenset, (err, user, details) => {
|
||||
if (err) {
|
||||
reject(err);
|
||||
} else {
|
||||
resolve({ user, details });
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
const tokenset = {
|
||||
id_token: 'fake_id_token',
|
||||
access_token: 'fake_access_token',
|
||||
claims: () => ({
|
||||
sub: '1234',
|
||||
email: 'test@example.com',
|
||||
email_verified: true,
|
||||
given_name: 'First',
|
||||
family_name: 'Last',
|
||||
name: 'My Full',
|
||||
username: 'flast',
|
||||
picture: 'https://example.com/avatar.png',
|
||||
}),
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
// Clear previous mock calls and reset implementations
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Reset environment variables needed by the strategy
|
||||
process.env.OPENID_ISSUER = 'https://fake-issuer.com';
|
||||
process.env.OPENID_CLIENT_ID = 'fake_client_id';
|
||||
process.env.OPENID_CLIENT_SECRET = 'fake_client_secret';
|
||||
process.env.DOMAIN_SERVER = 'https://example.com';
|
||||
process.env.OPENID_CALLBACK_URL = '/callback';
|
||||
process.env.OPENID_SCOPE = 'openid profile email';
|
||||
process.env.OPENID_REQUIRED_ROLE = 'requiredRole';
|
||||
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles';
|
||||
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id';
|
||||
delete process.env.OPENID_USERNAME_CLAIM;
|
||||
delete process.env.OPENID_NAME_CLAIM;
|
||||
delete process.env.PROXY;
|
||||
delete process.env.OPENID_USE_PKCE;
|
||||
|
||||
// Default jwtDecode mock returns a token that includes the required role.
|
||||
jwtDecode.mockReturnValue({
|
||||
roles: ['requiredRole'],
|
||||
});
|
||||
|
||||
// By default, assume that no user is found, so createUser will be called
|
||||
mockedMethods.findUser.mockResolvedValue(null);
|
||||
mockedMethods.createUser.mockImplementation(async (userData) => {
|
||||
// simulate created user with an _id property
|
||||
return { _id: 'newUserId', ...userData };
|
||||
});
|
||||
mockedMethods.updateUser.mockImplementation(async (id, userData) => {
|
||||
return { _id: id, ...userData };
|
||||
});
|
||||
|
||||
// For image download, simulate a successful response
|
||||
global.fetch = jest.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
arrayBuffer: jest.fn().mockResolvedValue(Buffer.from('fake image')),
|
||||
});
|
||||
// const { initAuth, setupOpenId } = require('@librechat/auth');
|
||||
const { setupOpenId } = require('../../packages/auth/src/strategies/openidStrategy');
|
||||
const { initAuth } = require('../../packages/auth/src/initAuth');
|
||||
const saveBufferMock = jest.fn().mockResolvedValue('/fake/path/to/avatar.png');
|
||||
await initAuth(mongoose, { enabled: false }, saveBufferMock); // mongoose: {}, fake balance config, dummy saveBuffer
|
||||
|
||||
const openidLogin = await setupOpenId({});
|
||||
|
||||
// Simulate the app's `passport.use(...)`
|
||||
passport.use('openid', openidLogin);
|
||||
|
||||
verifyCallback = require('openid-client/passport').__getVerifyCallback();
|
||||
});
|
||||
|
||||
it('should create a new user with correct username when username claim exists', async () => {
|
||||
// Arrange – our userinfo already has username 'flast'
|
||||
const userinfo = tokenset.claims();
|
||||
|
||||
// Act
|
||||
const { user } = await validate(tokenset);
|
||||
|
||||
// Assert
|
||||
expect(user.username).toBe(userinfo.username);
|
||||
expect(mockedMethods.createUser).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: 'openid',
|
||||
openidId: userinfo.sub,
|
||||
username: userinfo.username,
|
||||
email: userinfo.email,
|
||||
name: `${userinfo.given_name} ${userinfo.family_name}`,
|
||||
}),
|
||||
{ enabled: false },
|
||||
true,
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
it('should use given_name as username when username claim is missing', async () => {
|
||||
// Arrange – remove username from userinfo
|
||||
const userinfo = { ...tokenset.claims() };
|
||||
delete userinfo.username;
|
||||
// Expect the username to be the given name (unchanged case)
|
||||
const expectUsername = userinfo.given_name;
|
||||
|
||||
// Act
|
||||
const { user } = await validate({ ...tokenset, claims: () => userinfo });
|
||||
|
||||
// Assert
|
||||
expect(user.username).toBe(expectUsername);
|
||||
expect(mockedMethods.createUser).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ username: expectUsername }),
|
||||
{ enabled: false },
|
||||
true,
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
it('should use email as username when username and given_name are missing', async () => {
|
||||
// Arrange – remove username and given_name
|
||||
const userinfo = { ...tokenset.claims() };
|
||||
delete userinfo.username;
|
||||
delete userinfo.given_name;
|
||||
const expectUsername = userinfo.email;
|
||||
|
||||
// Act
|
||||
const { user } = await validate({ ...tokenset, claims: () => userinfo });
|
||||
|
||||
// Assert
|
||||
expect(user.username).toBe(expectUsername);
|
||||
expect(mockedMethods.createUser).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ username: expectUsername }),
|
||||
{ enabled: false },
|
||||
true,
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
it('should override username with OPENID_USERNAME_CLAIM when set', async () => {
|
||||
// Arrange – set OPENID_USERNAME_CLAIM so that the sub claim is used
|
||||
process.env.OPENID_USERNAME_CLAIM = 'sub';
|
||||
const userinfo = tokenset.claims();
|
||||
|
||||
// Act
|
||||
const { user } = await validate(tokenset);
|
||||
|
||||
// Assert – username should equal the sub (converted as-is)
|
||||
expect(user.username).toBe(userinfo.sub);
|
||||
expect(mockedMethods.createUser).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ username: userinfo.sub }),
|
||||
{ enabled: false },
|
||||
true,
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
it('should set the full name correctly when given_name and family_name exist', async () => {
|
||||
// Arrange
|
||||
const userinfo = tokenset.claims();
|
||||
const expectedFullName = `${userinfo.given_name} ${userinfo.family_name}`;
|
||||
|
||||
// Act
|
||||
const { user } = await validate(tokenset);
|
||||
|
||||
// Assert
|
||||
expect(user.name).toBe(expectedFullName);
|
||||
});
|
||||
|
||||
it('should override full name with OPENID_NAME_CLAIM when set', async () => {
|
||||
// Arrange – use the name claim as the full name
|
||||
process.env.OPENID_NAME_CLAIM = 'name';
|
||||
const userinfo = { ...tokenset.claims(), name: 'Custom Name' };
|
||||
|
||||
// Act
|
||||
const { user } = await validate({ ...tokenset, claims: () => userinfo });
|
||||
|
||||
// Assert
|
||||
expect(user.name).toBe('Custom Name');
|
||||
});
|
||||
|
||||
it('should update an existing user on login', async () => {
|
||||
// Arrange – simulate that a user already exists
|
||||
const existingUser = {
|
||||
_id: 'existingUserId',
|
||||
provider: 'local',
|
||||
email: tokenset.claims().email,
|
||||
openidId: '',
|
||||
username: '',
|
||||
name: '',
|
||||
};
|
||||
mockedMethods.findUser.mockImplementation(async (query) => {
|
||||
if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) {
|
||||
return existingUser;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const userinfo = tokenset.claims();
|
||||
|
||||
// Act
|
||||
await validate(tokenset);
|
||||
|
||||
// Assert – updateUser should be called and the user object updated
|
||||
expect(mockedMethods.updateUser).toHaveBeenCalledWith(
|
||||
existingUser._id,
|
||||
expect.objectContaining({
|
||||
provider: 'openid',
|
||||
openidId: userinfo.sub,
|
||||
username: userinfo.username,
|
||||
name: `${userinfo.given_name} ${userinfo.family_name}`,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should enforce the required role and reject login if missing', async () => {
|
||||
// Arrange – simulate a token without the required role.
|
||||
jwtDecode.mockReturnValue({
|
||||
roles: ['SomeOtherRole'],
|
||||
});
|
||||
|
||||
// Act
|
||||
const { user, details } = await validate(tokenset);
|
||||
|
||||
// Assert – verify that the strategy rejects login
|
||||
expect(user).toBe(false);
|
||||
expect(details.message).toBe('You must have the "requiredRole" role to log in.');
|
||||
});
|
||||
|
||||
it('should attempt to download and save the avatar if picture is provided', async () => {
|
||||
// Act
|
||||
const { user } = await validate(tokenset);
|
||||
|
||||
// Assert – verify that download was attempted and the avatar field was set via updateUser
|
||||
expect(global.fetch).toHaveBeenCalled();
|
||||
|
||||
// Our mock getStrategyFunctions.saveBuffer returns '/fake/path/to/avatar.png'
|
||||
expect(user.avatar).toBe('/fake/path/to/avatar.png');
|
||||
});
|
||||
|
||||
it('should not attempt to download avatar if picture is not provided', async () => {
|
||||
// Arrange – remove picture
|
||||
const userinfo = { ...tokenset.claims() };
|
||||
delete userinfo.picture;
|
||||
|
||||
// Act
|
||||
await validate({ ...tokenset, claims: () => userinfo });
|
||||
|
||||
// Assert – fetch should not be called and avatar should remain undefined or empty
|
||||
expect(global.fetch).not.toHaveBeenCalled();
|
||||
// Depending on your implementation, user.avatar may be undefined or an empty string.
|
||||
});
|
||||
|
||||
it('should default to usePKCE false when OPENID_USE_PKCE is not defined', async () => {
|
||||
const OpenIDStrategy = require('openid-client/passport').Strategy;
|
||||
|
||||
delete process.env.OPENID_USE_PKCE;
|
||||
const { setupOpenId } = require('../../packages/auth/src/strategies/openidStrategy');
|
||||
await setupOpenId({});
|
||||
|
||||
const callOptions = OpenIDStrategy.mock.calls[OpenIDStrategy.mock.calls.length - 1][0];
|
||||
expect(callOptions.usePKCE).toBe(false);
|
||||
expect(callOptions.params?.code_challenge_method).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
|
@ -1,438 +0,0 @@
|
|||
const passport = require('passport');
|
||||
const mongoose = require('mongoose');
|
||||
|
||||
// --- Mocks ---
|
||||
jest.mock('fs', () => ({
|
||||
existsSync: jest.fn(),
|
||||
statSync: jest.fn(),
|
||||
readFileSync: jest.fn(),
|
||||
}));
|
||||
jest.mock('path', () => ({
|
||||
isAbsolute: jest.fn(),
|
||||
basename: jest.fn(),
|
||||
dirname: jest.fn(),
|
||||
join: jest.fn(),
|
||||
normalize: jest.fn(),
|
||||
}));
|
||||
jest.mock('@node-saml/passport-saml');
|
||||
|
||||
const mockedMethods = {
|
||||
findUser: jest.fn(),
|
||||
createUser: jest.fn(),
|
||||
updateUser: jest.fn(),
|
||||
};
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => {
|
||||
const actual = jest.requireActual('@librechat/data-schemas');
|
||||
return {
|
||||
...actual,
|
||||
createMethods: jest.fn(() => mockedMethods),
|
||||
};
|
||||
});
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
config: {
|
||||
registration: {
|
||||
socialLogins: ['saml'],
|
||||
},
|
||||
},
|
||||
getBalanceConfig: jest.fn().mockResolvedValue({
|
||||
tokenCredits: 1000,
|
||||
startingBalance: 1000,
|
||||
}),
|
||||
}));
|
||||
jest.mock('~/server/services/Config/EndpointService', () => ({
|
||||
config: {},
|
||||
}));
|
||||
jest.mock('~/server/utils', () => ({
|
||||
isEnabled: jest.fn(() => false),
|
||||
isUserProvided: jest.fn(() => false),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/utils/crypto', () => ({
|
||||
hashToken: jest.fn().mockResolvedValue('hashed-token'),
|
||||
}));
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
const path = require('path');
|
||||
const fs = require('fs');
|
||||
// To capture the verify callback from the strategy, we grab it from the mock constructor
|
||||
describe('getCertificateContent', () => {
|
||||
const { getCertificateContent } = require('../../packages/auth/src/strategies/samlStrategy');
|
||||
// const { getCertificateContent } = require('@librechat/auth');
|
||||
const certWithHeader = `-----BEGIN CERTIFICATE-----
|
||||
MIIDazCCAlOgAwIBAgIUKhXaFJGJJPx466rlwYORIsqCq7MwDQYJKoZIhvcNAQEL
|
||||
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
|
||||
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yNTAzMDQwODUxNTJaFw0yNjAz
|
||||
MDQwODUxNTJaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw
|
||||
HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB
|
||||
AQUAA4IBDwAwggEKAoIBAQCWP09NZg0xaRiLpNygCVgV3M+4RFW2S0c5X/fg/uFT
|
||||
O5MfaVYzG5GxzhXzWRB8RtNPsxX/nlbPsoUroeHbz+SABkOsNEv6JuKRH4VXRH34
|
||||
VzjazVkPAwj+N4WqsC/Wo4EGGpKIGeGi8Zed4yvMqoTyE3mrS19fY0nMHT62wUwS
|
||||
GMm2pAQdAQePZ9WY7A5XOA1IoxW2Zh2Oxaf1p59epBkZDhoxSMu8GoSkvK27Km4A
|
||||
4UXftzdg/wHNPrNirmcYouioHdmrOtYxPjrhUBQ74AmE1/QK45B6wEgirKH1A1AW
|
||||
6C+ApLwpBMvy9+8Gbyvc8G18W3CjdEVKmAeWb9JUedSXAgMBAAGjUzBRMB0GA1Ud
|
||||
DgQWBBRxpaqBx8VDLLc8IkHATujj8IOs6jAfBgNVHSMEGDAWgBRxpaqBx8VDLLc8
|
||||
IkHATujj8IOs6jAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQBc
|
||||
Puk6i+yowwGccB3LhfxZ+Fz6s6/Lfx6bP/Hy4NYOxmx2/awGBgyfp1tmotjaS9Cf
|
||||
FWd67LuEru4TYtz12RNMDBF5ypcEfibvb3I8O6igOSQX/Jl5D2pMChesZxhmCift
|
||||
Qp09T41MA8PmHf1G9oMG0A3ZnjKDG5ebaJNRFImJhMHsgh/TP7V3uZy7YHTgopKX
|
||||
Hv63V3Uo3Oihav29Q7urwmf7Ly7X7J2WE86/w3vRHi5dhaWWqEqxmnAXl+H+sG4V
|
||||
meeVRI332bg1Nuy8KnnX8v3ZeJzMBkAhzvSr6Ri96R0/Un/oEFwVC5jDTq8sXVn6
|
||||
u7wlOSk+oFzDIO/UILIA
|
||||
-----END CERTIFICATE-----`;
|
||||
|
||||
const certWithoutHeader = certWithHeader
|
||||
.replace(/-----BEGIN CERTIFICATE-----/g, '')
|
||||
.replace(/-----END CERTIFICATE-----/g, '')
|
||||
.replace(/\s+/g, '');
|
||||
|
||||
it('should throw an error if SAML_CERT is not set', () => {
|
||||
process.env.SAML_CERT;
|
||||
expect(() => getCertificateContent(process.env.SAML_CERT)).toThrow(
|
||||
'Invalid input: SAML_CERT must be a string.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if SAML_CERT is empty', () => {
|
||||
process.env.SAML_CERT = '';
|
||||
expect(() => getCertificateContent(process.env.SAML_CERT)).toThrow(
|
||||
'Invalid cert: SAML_CERT must be a valid file path or certificate string.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should load cert from an environment variable if it is a single-line string(with header)', () => {
|
||||
process.env.SAML_CERT = certWithHeader;
|
||||
|
||||
const actual = getCertificateContent(process.env.SAML_CERT);
|
||||
expect(actual).toBe(certWithHeader);
|
||||
});
|
||||
|
||||
it('should load cert from an environment variable if it is a single-line string(with no header)', () => {
|
||||
process.env.SAML_CERT = certWithoutHeader;
|
||||
|
||||
const actual = getCertificateContent(process.env.SAML_CERT);
|
||||
expect(actual).toBe(certWithoutHeader);
|
||||
});
|
||||
|
||||
it('should throw an error if SAML_CERT is a single-line string (with header, no newline characters)', () => {
|
||||
process.env.SAML_CERT = certWithHeader.replace(/\n/g, '');
|
||||
expect(() => getCertificateContent(process.env.SAML_CERT)).toThrow(
|
||||
'Invalid cert: SAML_CERT must be a valid file path or certificate string.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should load cert from a relative file path if SAML_CERT is valid', () => {
|
||||
process.env.SAML_CERT = 'test.pem';
|
||||
const resolvedPath = '/absolute/path/to/test.pem';
|
||||
|
||||
path.isAbsolute.mockReturnValue(false);
|
||||
path.join.mockReturnValue(resolvedPath);
|
||||
path.normalize.mockReturnValue(resolvedPath);
|
||||
|
||||
fs.existsSync.mockReturnValue(true);
|
||||
fs.statSync.mockReturnValue({ isFile: () => true });
|
||||
fs.readFileSync.mockReturnValue(certWithHeader);
|
||||
|
||||
const actual = getCertificateContent(process.env.SAML_CERT);
|
||||
console.log(actual);
|
||||
expect(actual).toBe(certWithHeader);
|
||||
});
|
||||
|
||||
it('should load cert from an absolute file path if SAML_CERT is valid', () => {
|
||||
process.env.SAML_CERT = '/absolute/path/to/test.pem';
|
||||
|
||||
path.isAbsolute.mockReturnValue(true);
|
||||
path.normalize.mockReturnValue(process.env.SAML_CERT);
|
||||
|
||||
fs.existsSync.mockReturnValue(true);
|
||||
fs.statSync.mockReturnValue({ isFile: () => true });
|
||||
fs.readFileSync.mockReturnValue(certWithHeader);
|
||||
|
||||
const actual = getCertificateContent(process.env.SAML_CERT);
|
||||
expect(actual).toBe(certWithHeader);
|
||||
});
|
||||
|
||||
it('should throw an error if the file does not exist', () => {
|
||||
process.env.SAML_CERT = 'missing.pem';
|
||||
const resolvedPath = '/absolute/path/to/missing.pem';
|
||||
|
||||
path.isAbsolute.mockReturnValue(false);
|
||||
path.join.mockReturnValue(resolvedPath);
|
||||
path.normalize.mockReturnValue(resolvedPath);
|
||||
|
||||
fs.existsSync.mockReturnValue(false);
|
||||
|
||||
expect(() => getCertificateContent(process.env.SAML_CERT)).toThrow(
|
||||
'Invalid cert: SAML_CERT must be a valid file path or certificate string.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if the file is not readable', () => {
|
||||
process.env.SAML_CERT = 'unreadable.pem';
|
||||
const resolvedPath = '/absolute/path/to/unreadable.pem';
|
||||
|
||||
path.isAbsolute.mockReturnValue(false);
|
||||
path.join.mockReturnValue(resolvedPath);
|
||||
path.normalize.mockReturnValue(resolvedPath);
|
||||
|
||||
fs.existsSync.mockReturnValue(true);
|
||||
fs.statSync.mockReturnValue({ isFile: () => true });
|
||||
fs.readFileSync.mockImplementation(() => {
|
||||
throw new Error('Permission denied');
|
||||
});
|
||||
|
||||
expect(() => getCertificateContent(process.env.SAML_CERT)).toThrow(
|
||||
'Error reading certificate file: Permission denied',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('setupSaml', () => {
|
||||
let verifyCallback;
|
||||
|
||||
// Helper to wrap the verify callback in a promise
|
||||
const validate = (profile) =>
|
||||
new Promise((resolve, reject) => {
|
||||
verifyCallback(profile, (err, user, details) => {
|
||||
if (err) {
|
||||
reject(err);
|
||||
} else {
|
||||
resolve({ user, details });
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
const baseProfile = {
|
||||
nameID: 'saml-1234',
|
||||
email: 'test@example.com',
|
||||
given_name: 'First',
|
||||
family_name: 'Last',
|
||||
name: 'My Full Name',
|
||||
username: 'flast',
|
||||
picture: 'https://example.com/avatar.png',
|
||||
custom_name: 'custom',
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Configure mocks
|
||||
mockedMethods.findUser.mockResolvedValue(null);
|
||||
mockedMethods.createUser.mockImplementation(async (userData) => ({
|
||||
_id: 'mock-user-id',
|
||||
...userData,
|
||||
}));
|
||||
mockedMethods.updateUser.mockImplementation(async (id, userData) => ({
|
||||
_id: id,
|
||||
...userData,
|
||||
}));
|
||||
|
||||
const cert = `
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDazCCAlOgAwIBAgIUKhXaFJGJJPx466rlwYORIsqCq7MwDQYJKoZIhvcNAQEL
|
||||
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
|
||||
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yNTAzMDQwODUxNTJaFw0yNjAz
|
||||
MDQwODUxNTJaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw
|
||||
HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB
|
||||
AQUAA4IBDwAwggEKAoIBAQCWP09NZg0xaRiLpNygCVgV3M+4RFW2S0c5X/fg/uFT
|
||||
O5MfaVYzG5GxzhXzWRB8RtNPsxX/nlbPsoUroeHbz+SABkOsNEv6JuKRH4VXRH34
|
||||
VzjazVkPAwj+N4WqsC/Wo4EGGpKIGeGi8Zed4yvMqoTyE3mrS19fY0nMHT62wUwS
|
||||
GMm2pAQdAQePZ9WY7A5XOA1IoxW2Zh2Oxaf1p59epBkZDhoxSMu8GoSkvK27Km4A
|
||||
4UXftzdg/wHNPrNirmcYouioHdmrOtYxPjrhUBQ74AmE1/QK45B6wEgirKH1A1AW
|
||||
6C+ApLwpBMvy9+8Gbyvc8G18W3CjdEVKmAeWb9JUedSXAgMBAAGjUzBRMB0GA1Ud
|
||||
DgQWBBRxpaqBx8VDLLc8IkHATujj8IOs6jAfBgNVHSMEGDAWgBRxpaqBx8VDLLc8
|
||||
IkHATujj8IOs6jAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQBc
|
||||
Puk6i+yowwGccB3LhfxZ+Fz6s6/Lfx6bP/Hy4NYOxmx2/awGBgyfp1tmotjaS9Cf
|
||||
FWd67LuEru4TYtz12RNMDBF5ypcEfibvb3I8O6igOSQX/Jl5D2pMChesZxhmCift
|
||||
Qp09T41MA8PmHf1G9oMG0A3ZnjKDG5ebaJNRFImJhMHsgh/TP7V3uZy7YHTgopKX
|
||||
Hv63V3Uo3Oihav29Q7urwmf7Ly7X7J2WE86/w3vRHi5dhaWWqEqxmnAXl+H+sG4V
|
||||
meeVRI332bg1Nuy8KnnX8v3ZeJzMBkAhzvSr6Ri96R0/Un/oEFwVC5jDTq8sXVn6
|
||||
u7wlOSk+oFzDIO/UILIA
|
||||
-----END CERTIFICATE-----
|
||||
`;
|
||||
|
||||
// Reset environment variables
|
||||
process.env.SAML_ENTRY_POINT = 'https://example.com/saml';
|
||||
process.env.SAML_ISSUER = 'saml-issuer';
|
||||
process.env.SAML_CERT = cert;
|
||||
process.env.SAML_CALLBACK_URL = '/oauth/saml/callback';
|
||||
delete process.env.SAML_EMAIL_CLAIM;
|
||||
delete process.env.SAML_USERNAME_CLAIM;
|
||||
delete process.env.SAML_GIVEN_NAME_CLAIM;
|
||||
delete process.env.SAML_FAMILY_NAME_CLAIM;
|
||||
delete process.env.SAML_PICTURE_CLAIM;
|
||||
delete process.env.SAML_NAME_CLAIM;
|
||||
|
||||
// For image download, simulate a successful response
|
||||
global.fetch = jest.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
arrayBuffer: jest.fn().mockResolvedValue(Buffer.from('fake image')),
|
||||
});
|
||||
|
||||
const { samlLogin } = require('../../packages/auth/src/strategies/samlStrategy');
|
||||
const { initAuth } = require('../../packages/auth/src/initAuth');
|
||||
const saveBufferMock = jest.fn().mockResolvedValue('/fake/path/to/avatar.png');
|
||||
await initAuth(mongoose, { enabled: false }, saveBufferMock);
|
||||
|
||||
// Simulate the app's `passport.use(...)`
|
||||
const SamlStrategy = samlLogin();
|
||||
passport.use('saml', SamlStrategy);
|
||||
|
||||
console.log('SamlStrategy', SamlStrategy);
|
||||
verifyCallback = SamlStrategy._signonVerify;
|
||||
console.log('----', verifyCallback);
|
||||
});
|
||||
|
||||
it('should create a new user with correct username when username claim exists', async () => {
|
||||
const profile = { ...baseProfile };
|
||||
const { user } = await validate(profile);
|
||||
|
||||
expect(user.username).toBe(profile.username);
|
||||
expect(user.provider).toBe('saml');
|
||||
expect(user.samlId).toBe(profile.nameID);
|
||||
expect(user.email).toBe(profile.email);
|
||||
expect(user.name).toBe(`${profile.given_name} ${profile.family_name}`);
|
||||
});
|
||||
|
||||
it('should use given_name as username when username claim is missing', async () => {
|
||||
const profile = { ...baseProfile };
|
||||
delete profile.username;
|
||||
const expectUsername = profile.given_name;
|
||||
|
||||
const { user } = await validate(profile);
|
||||
|
||||
expect(user.username).toBe(expectUsername);
|
||||
expect(user.provider).toBe('saml');
|
||||
});
|
||||
|
||||
it('should use email as username when username and given_name are missing', async () => {
|
||||
const profile = { ...baseProfile };
|
||||
delete profile.username;
|
||||
delete profile.given_name;
|
||||
const expectUsername = profile.email;
|
||||
|
||||
const { user } = await validate(profile);
|
||||
|
||||
expect(user.username).toBe(expectUsername);
|
||||
expect(user.provider).toBe('saml');
|
||||
});
|
||||
|
||||
it('should override username with SAML_USERNAME_CLAIM when set', async () => {
|
||||
process.env.SAML_USERNAME_CLAIM = 'nameID';
|
||||
const profile = { ...baseProfile };
|
||||
|
||||
const { user } = await validate(profile);
|
||||
|
||||
expect(user.username).toBe(profile.nameID);
|
||||
expect(user.provider).toBe('saml');
|
||||
});
|
||||
|
||||
it('should set the full name correctly when given_name and family_name exist', async () => {
|
||||
const profile = { ...baseProfile };
|
||||
const expectedFullName = `${profile.given_name} ${profile.family_name}`;
|
||||
|
||||
const { user } = await validate(profile);
|
||||
|
||||
expect(user.name).toBe(expectedFullName);
|
||||
});
|
||||
|
||||
it('should set the full name correctly when given_name exist', async () => {
|
||||
const profile = { ...baseProfile };
|
||||
delete profile.family_name;
|
||||
const expectedFullName = profile.given_name;
|
||||
|
||||
const { user } = await validate(profile);
|
||||
|
||||
expect(user.name).toBe(expectedFullName);
|
||||
});
|
||||
|
||||
it('should set the full name correctly when family_name exist', async () => {
|
||||
const profile = { ...baseProfile };
|
||||
delete profile.given_name;
|
||||
const expectedFullName = profile.family_name;
|
||||
|
||||
const { user } = await validate(profile);
|
||||
|
||||
expect(user.name).toBe(expectedFullName);
|
||||
});
|
||||
|
||||
it('should set the full name correctly when username exist', async () => {
|
||||
const profile = { ...baseProfile };
|
||||
delete profile.family_name;
|
||||
delete profile.given_name;
|
||||
const expectedFullName = profile.username;
|
||||
|
||||
const { user } = await validate(profile);
|
||||
|
||||
expect(user.name).toBe(expectedFullName);
|
||||
});
|
||||
|
||||
it('should set the full name correctly when email only exist', async () => {
|
||||
const profile = { ...baseProfile };
|
||||
delete profile.family_name;
|
||||
delete profile.given_name;
|
||||
delete profile.username;
|
||||
const expectedFullName = profile.email;
|
||||
|
||||
const { user } = await validate(profile);
|
||||
|
||||
expect(user.name).toBe(expectedFullName);
|
||||
});
|
||||
|
||||
it('should set the full name correctly with SAML_NAME_CLAIM when set', async () => {
|
||||
process.env.SAML_NAME_CLAIM = 'custom_name';
|
||||
const profile = { ...baseProfile };
|
||||
const expectedFullName = profile.custom_name;
|
||||
|
||||
const { user } = await validate(profile);
|
||||
|
||||
expect(user.name).toBe(expectedFullName);
|
||||
});
|
||||
|
||||
it('should update an existing user on login', async () => {
|
||||
// Set up findUser to return an existing user
|
||||
const { findUser } = require('~/models');
|
||||
const existingUser = {
|
||||
_id: 'existing-user-id',
|
||||
provider: 'local',
|
||||
email: baseProfile.email,
|
||||
samlId: '',
|
||||
username: 'oldusername',
|
||||
name: 'Old Name',
|
||||
};
|
||||
findUser.mockResolvedValue(existingUser);
|
||||
|
||||
const profile = { ...baseProfile };
|
||||
const { user } = await validate(profile);
|
||||
|
||||
expect(user.provider).toBe('saml');
|
||||
expect(user.samlId).toBe(baseProfile.nameID);
|
||||
expect(user.username).toBe(baseProfile.username);
|
||||
expect(user.name).toBe(`${baseProfile.given_name} ${baseProfile.family_name}`);
|
||||
expect(user.email).toBe(baseProfile.email);
|
||||
});
|
||||
|
||||
it('should attempt to download and save the avatar if picture is provided', async () => {
|
||||
const profile = { ...baseProfile };
|
||||
|
||||
const { user } = await validate(profile);
|
||||
|
||||
expect(global.fetch).toHaveBeenCalled();
|
||||
expect(user.avatar).toBe('/fake/path/to/avatar.png');
|
||||
});
|
||||
|
||||
it('should not attempt to download avatar if picture is not provided', async () => {
|
||||
const profile = { ...baseProfile };
|
||||
delete profile.picture;
|
||||
|
||||
await validate(profile);
|
||||
|
||||
expect(global.fetch).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
|
@ -1,456 +0,0 @@
|
|||
// file deepcode ignore NoHardcodedPasswords: No hard-coded passwords in tests
|
||||
const { errorsToString } = require('librechat-data-provider');
|
||||
const { loginSchema, registerSchema } = require('@librechat/auth');
|
||||
|
||||
describe('Zod Schemas', () => {
|
||||
describe('loginSchema', () => {
|
||||
it('should validate a correct login object', () => {
|
||||
const result = loginSchema.safeParse({
|
||||
email: 'test@example.com',
|
||||
password: 'password123',
|
||||
});
|
||||
|
||||
expect(result.success).toBe(true);
|
||||
});
|
||||
|
||||
it('should invalidate an incorrect email', () => {
|
||||
const result = loginSchema.safeParse({
|
||||
email: 'testexample.com',
|
||||
password: 'password123',
|
||||
});
|
||||
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should invalidate a short password', () => {
|
||||
const result = loginSchema.safeParse({
|
||||
email: 'test@example.com',
|
||||
password: 'pass',
|
||||
});
|
||||
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle email with unusual characters', () => {
|
||||
const emails = ['test+alias@example.com', 'test@subdomain.example.co.uk'];
|
||||
emails.forEach((email) => {
|
||||
const result = loginSchema.safeParse({
|
||||
email,
|
||||
password: 'password123',
|
||||
});
|
||||
expect(result.success).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
it('should invalidate email without a domain', () => {
|
||||
const result = loginSchema.safeParse({
|
||||
email: 'test@.com',
|
||||
password: 'password123',
|
||||
});
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should invalidate password with only spaces', () => {
|
||||
const result = loginSchema.safeParse({
|
||||
email: 'test@example.com',
|
||||
password: ' ',
|
||||
});
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should invalidate password that is too long', () => {
|
||||
const result = loginSchema.safeParse({
|
||||
email: 'test@example.com',
|
||||
password: 'a'.repeat(129),
|
||||
});
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should invalidate empty email or password', () => {
|
||||
const result = loginSchema.safeParse({
|
||||
email: '',
|
||||
password: '',
|
||||
});
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('registerSchema', () => {
|
||||
it('should validate a correct register object', () => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: 'John Doe',
|
||||
username: 'john_doe',
|
||||
email: 'john@example.com',
|
||||
password: 'password123',
|
||||
confirm_password: 'password123',
|
||||
});
|
||||
|
||||
expect(result.success).toBe(true);
|
||||
});
|
||||
|
||||
it('should allow the username to be omitted', () => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: 'John Doe',
|
||||
email: 'john@example.com',
|
||||
password: 'password123',
|
||||
confirm_password: 'password123',
|
||||
});
|
||||
|
||||
expect(result.success).toBe(true);
|
||||
});
|
||||
|
||||
it('should invalidate a short name', () => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: 'Jo',
|
||||
username: 'john_doe',
|
||||
email: 'john@example.com',
|
||||
password: 'password123',
|
||||
confirm_password: 'password123',
|
||||
});
|
||||
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle empty username by transforming to null', () => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: 'John Doe',
|
||||
username: '',
|
||||
email: 'john@example.com',
|
||||
password: 'password123',
|
||||
confirm_password: 'password123',
|
||||
});
|
||||
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.data.username).toBe(null);
|
||||
});
|
||||
|
||||
it('should handle name with special characters', () => {
|
||||
const names = ['Jöhn Dœ', 'John <Doe>'];
|
||||
names.forEach((name) => {
|
||||
const result = registerSchema.safeParse({
|
||||
name,
|
||||
username: 'john_doe',
|
||||
email: 'john@example.com',
|
||||
password: 'password123',
|
||||
confirm_password: 'password123',
|
||||
});
|
||||
expect(result.success).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle username with special characters', () => {
|
||||
const usernames = ['john.doe@', 'john..doe'];
|
||||
usernames.forEach((username) => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: 'John Doe',
|
||||
username,
|
||||
email: 'john@example.com',
|
||||
password: 'password123',
|
||||
confirm_password: 'password123',
|
||||
});
|
||||
expect(result.success).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
it('should invalidate mismatched password and confirm_password', () => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: 'John Doe',
|
||||
username: 'john_doe',
|
||||
email: 'john@example.com',
|
||||
password: 'password123',
|
||||
confirm_password: 'password124',
|
||||
});
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle email without a TLD', () => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: 'John Doe',
|
||||
username: 'john_doe',
|
||||
email: 'john@domain',
|
||||
password: 'password123',
|
||||
confirm_password: 'password123',
|
||||
});
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle email with multiple @ symbols', () => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: 'John Doe',
|
||||
username: 'john_doe',
|
||||
email: 'john@domain@com',
|
||||
password: 'password123',
|
||||
confirm_password: 'password123',
|
||||
});
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle name that is too long', () => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: 'a'.repeat(81),
|
||||
username: 'john_doe',
|
||||
email: 'john@example.com',
|
||||
password: 'password123',
|
||||
confirm_password: 'password123',
|
||||
});
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle username that is too long', () => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: 'John Doe',
|
||||
username: 'a'.repeat(81),
|
||||
email: 'john@example.com',
|
||||
password: 'password123',
|
||||
confirm_password: 'password123',
|
||||
});
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle password or confirm_password that is too long', () => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: 'John Doe',
|
||||
username: 'john_doe',
|
||||
email: 'john@example.com',
|
||||
password: 'a'.repeat(129),
|
||||
confirm_password: 'a'.repeat(129),
|
||||
});
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle password or confirm_password that is just spaces', () => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: 'John Doe',
|
||||
username: 'john_doe',
|
||||
email: 'john@example.com',
|
||||
password: ' ',
|
||||
confirm_password: ' ',
|
||||
});
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle null values for fields', () => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: null,
|
||||
username: null,
|
||||
email: null,
|
||||
password: null,
|
||||
confirm_password: null,
|
||||
});
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle undefined values for fields', () => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: undefined,
|
||||
username: undefined,
|
||||
email: undefined,
|
||||
password: undefined,
|
||||
confirm_password: undefined,
|
||||
});
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle extra fields not defined in the schema', () => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: 'John Doe',
|
||||
username: 'john_doe',
|
||||
email: 'john@example.com',
|
||||
password: 'password123',
|
||||
confirm_password: 'password123',
|
||||
extraField: "I shouldn't be here",
|
||||
});
|
||||
expect(result.success).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle username with special characters from various languages', () => {
|
||||
const usernames = [
|
||||
// General
|
||||
'éèäöü',
|
||||
|
||||
// German
|
||||
'Jöhn.Döe@',
|
||||
'Jöhn_Ü',
|
||||
'Jöhnß',
|
||||
|
||||
// French
|
||||
'Jéan-Piérre',
|
||||
'Élève',
|
||||
'Fiançée',
|
||||
'Mère',
|
||||
|
||||
// Spanish
|
||||
'Niño',
|
||||
'Señor',
|
||||
'Muñoz',
|
||||
|
||||
// Portuguese
|
||||
'João',
|
||||
'Coração',
|
||||
'Pão',
|
||||
|
||||
// Italian
|
||||
'Pietro',
|
||||
'Bambino',
|
||||
'Forlì',
|
||||
|
||||
// Romanian
|
||||
'Mâncare',
|
||||
'Școală',
|
||||
'Țară',
|
||||
|
||||
// Catalan
|
||||
'Niç',
|
||||
'Màquina',
|
||||
'Çap',
|
||||
|
||||
// Swedish
|
||||
'Fjärran',
|
||||
'Skål',
|
||||
'Öland',
|
||||
|
||||
// Norwegian
|
||||
'Blåbær',
|
||||
'Fjord',
|
||||
'Årstid',
|
||||
|
||||
// Danish
|
||||
'Flød',
|
||||
'Søster',
|
||||
'Århus',
|
||||
|
||||
// Icelandic
|
||||
'Þór',
|
||||
'Ætt',
|
||||
'Öx',
|
||||
|
||||
// Turkish
|
||||
'Şehir',
|
||||
'Çocuk',
|
||||
'Gözlük',
|
||||
|
||||
// Polish
|
||||
'Łódź',
|
||||
'Część',
|
||||
'Świat',
|
||||
|
||||
// Czech
|
||||
'Čaj',
|
||||
'Řeka',
|
||||
'Život',
|
||||
|
||||
// Slovak
|
||||
'Kočka',
|
||||
'Ľudia',
|
||||
'Žaba',
|
||||
|
||||
// Croatian
|
||||
'Čovjek',
|
||||
'Šuma',
|
||||
'Žaba',
|
||||
|
||||
// Hungarian
|
||||
'Tűz',
|
||||
'Ősz',
|
||||
'Ünnep',
|
||||
|
||||
// Finnish
|
||||
'Mäki',
|
||||
'Yö',
|
||||
'Äiti',
|
||||
|
||||
// Estonian
|
||||
'Tänav',
|
||||
'Öö',
|
||||
'Ülikool',
|
||||
|
||||
// Latvian
|
||||
'Ēka',
|
||||
'Ūdens',
|
||||
'Čempions',
|
||||
|
||||
// Lithuanian
|
||||
'Ūsas',
|
||||
'Ąžuolas',
|
||||
'Čia',
|
||||
|
||||
// Dutch
|
||||
'Maïs',
|
||||
'Geërfd',
|
||||
'Coördinatie',
|
||||
];
|
||||
|
||||
const failingUsernames = usernames.reduce((acc, username) => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: 'John Doe',
|
||||
username,
|
||||
email: 'john@example.com',
|
||||
password: 'password123',
|
||||
confirm_password: 'password123',
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
acc.push({ username, error: result.error });
|
||||
}
|
||||
|
||||
return acc;
|
||||
}, []);
|
||||
|
||||
if (failingUsernames.length > 0) {
|
||||
console.log('Failing Usernames:', failingUsernames);
|
||||
}
|
||||
expect(failingUsernames).toEqual([]);
|
||||
});
|
||||
|
||||
it('should reject invalid usernames', () => {
|
||||
const invalidUsernames = [
|
||||
'john{doe}', // Contains `{` and `}`
|
||||
'j', // Only one character
|
||||
'a'.repeat(81), // More than 80 characters
|
||||
"' OR '1'='1'; --", // SQL Injection
|
||||
'{$ne: null}', // MongoDB Injection
|
||||
'<script>alert("XSS")</script>', // Basic XSS
|
||||
'"><script>alert("XSS")</script>', // XSS breaking out of an attribute
|
||||
'"><img src=x onerror=alert("XSS")>', // XSS using an image tag
|
||||
];
|
||||
|
||||
const passingUsernames = [];
|
||||
const failingUsernames = invalidUsernames.reduce((acc, username) => {
|
||||
const result = registerSchema.safeParse({
|
||||
name: 'John Doe',
|
||||
username,
|
||||
email: 'john@example.com',
|
||||
password: 'password123',
|
||||
confirm_password: 'password123',
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
acc.push({ username, error: result.error });
|
||||
}
|
||||
|
||||
if (result.success) {
|
||||
passingUsernames.push({ username });
|
||||
}
|
||||
|
||||
return acc;
|
||||
}, []);
|
||||
|
||||
expect(failingUsernames.length).toEqual(invalidUsernames.length); // They should match since all invalidUsernames should fail.
|
||||
});
|
||||
});
|
||||
|
||||
describe('errorsToString', () => {
|
||||
it('should convert errors to string', () => {
|
||||
const { error } = registerSchema.safeParse({
|
||||
name: 'Jo',
|
||||
username: 'john_doe',
|
||||
email: 'john@example.com',
|
||||
password: 'password123',
|
||||
confirm_password: 'password123',
|
||||
});
|
||||
|
||||
const result = errorsToString(error.errors);
|
||||
expect(result).toBe('name: String must contain at least 3 character(s)');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -41,10 +41,7 @@ jest.mock('winston-daily-rotate-file', () => {
|
|||
});
|
||||
|
||||
jest.mock('~/config', () => {
|
||||
const actualModule = jest.requireActual('~/config');
|
||||
return {
|
||||
sendEvent: actualModule.sendEvent,
|
||||
createAxiosInstance: actualModule.createAxiosInstance,
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
|
|
|
|||
|
|
@ -1073,7 +1073,7 @@
|
|||
|
||||
/**
|
||||
* @exports MCPServers
|
||||
* @typedef {import('librechat-mcp').MCPServers} MCPServers
|
||||
* @typedef {import('@librechat/api').MCPServers} MCPServers
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
|
|
@ -1085,31 +1085,31 @@
|
|||
|
||||
/**
|
||||
* @exports MCPManager
|
||||
* @typedef {import('librechat-mcp').MCPManager} MCPManager
|
||||
* @typedef {import('@librechat/api').MCPManager} MCPManager
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports FlowStateManager
|
||||
* @typedef {import('librechat-mcp').FlowStateManager} FlowStateManager
|
||||
* @typedef {import('@librechat/api').FlowStateManager} FlowStateManager
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports LCAvailableTools
|
||||
* @typedef {import('librechat-mcp').LCAvailableTools} LCAvailableTools
|
||||
* @typedef {import('@librechat/api').LCAvailableTools} LCAvailableTools
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports LCTool
|
||||
* @typedef {import('librechat-mcp').LCTool} LCTool
|
||||
* @typedef {import('@librechat/api').LCTool} LCTool
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports FormattedContent
|
||||
* @typedef {import('librechat-mcp').FormattedContent} FormattedContent
|
||||
* @typedef {import('@librechat/api').FormattedContent} FormattedContent
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
|
|
@ -1232,7 +1232,7 @@
|
|||
* @typedef {Object} AgentClientOptions
|
||||
* @property {Agent} agent - The agent configuration object
|
||||
* @property {string} endpoint - The endpoint identifier for the agent
|
||||
* @property {Object} req - The request object
|
||||
* @property {ServerRequest} req - The request object
|
||||
* @property {string} [name] - The username
|
||||
* @property {string} [modelLabel] - The label for the model being used
|
||||
* @property {number} [maxContextTokens] - Maximum number of tokens allowed in context
|
||||
|
|
|
|||
|
|
@ -1,46 +0,0 @@
|
|||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Logs Axios errors based on the error object and a custom message.
|
||||
*
|
||||
* @param {Object} options - The options object.
|
||||
* @param {string} options.message - The custom message to be logged.
|
||||
* @param {import('axios').AxiosError} options.error - The Axios error object.
|
||||
* @returns {string} The log message.
|
||||
*/
|
||||
const logAxiosError = ({ message, error }) => {
|
||||
let logMessage = message;
|
||||
try {
|
||||
const stack = error.stack || 'No stack trace available';
|
||||
|
||||
if (error.response?.status) {
|
||||
const { status, headers, data } = error.response;
|
||||
logMessage = `${message} The server responded with status ${status}: ${error.message}`;
|
||||
logger.error(logMessage, {
|
||||
status,
|
||||
headers,
|
||||
data,
|
||||
stack,
|
||||
});
|
||||
} else if (error.request) {
|
||||
const { method, url } = error.config || {};
|
||||
logMessage = `${message} No response received for ${method ? method.toUpperCase() : ''} ${url || ''}: ${error.message}`;
|
||||
logger.error(logMessage, {
|
||||
requestInfo: { method, url },
|
||||
stack,
|
||||
});
|
||||
} else if (error?.message?.includes("Cannot read properties of undefined (reading 'status')")) {
|
||||
logMessage = `${message} It appears the request timed out or was unsuccessful: ${error.message}`;
|
||||
logger.error(logMessage, { stack });
|
||||
} else {
|
||||
logMessage = `${message} An error occurred while setting up the request: ${error.message}`;
|
||||
logger.error(logMessage, { stack });
|
||||
}
|
||||
} catch (err) {
|
||||
logMessage = `Error in logAxiosError: ${err.message}`;
|
||||
logger.error(logMessage, { stack: err.stack || 'No stack trace available' });
|
||||
}
|
||||
return logMessage;
|
||||
};
|
||||
|
||||
module.exports = { logAxiosError };
|
||||
|
|
@ -1,105 +0,0 @@
|
|||
const { isEnabled } = require('~/server/utils');
|
||||
|
||||
/**
|
||||
* Sanitizes the model name to be used in the URL by removing or replacing disallowed characters.
|
||||
* @param {string} modelName - The model name to be sanitized.
|
||||
* @returns {string} The sanitized model name.
|
||||
*/
|
||||
const sanitizeModelName = (modelName) => {
|
||||
// Replace periods with empty strings and other disallowed characters as needed.
|
||||
return modelName.replace(/\./g, '');
|
||||
};
|
||||
|
||||
/**
|
||||
* Generates the Azure OpenAI API endpoint URL.
|
||||
* @param {Object} params - The parameters object.
|
||||
* @param {string} params.azureOpenAIApiInstanceName - The Azure OpenAI API instance name.
|
||||
* @param {string} params.azureOpenAIApiDeploymentName - The Azure OpenAI API deployment name.
|
||||
* @returns {string} The complete endpoint URL for the Azure OpenAI API.
|
||||
*/
|
||||
const genAzureEndpoint = ({ azureOpenAIApiInstanceName, azureOpenAIApiDeploymentName }) => {
|
||||
return `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${azureOpenAIApiDeploymentName}`;
|
||||
};
|
||||
|
||||
/**
|
||||
* Generates the Azure OpenAI API chat completion endpoint URL with the API version.
|
||||
* If both deploymentName and modelName are provided, modelName takes precedence.
|
||||
* @param {Object} AzureConfig - The Azure configuration object.
|
||||
* @param {string} AzureConfig.azureOpenAIApiInstanceName - The Azure OpenAI API instance name.
|
||||
* @param {string} [AzureConfig.azureOpenAIApiDeploymentName] - The Azure OpenAI API deployment name (optional).
|
||||
* @param {string} AzureConfig.azureOpenAIApiVersion - The Azure OpenAI API version.
|
||||
* @param {string} [modelName] - The model name to be included in the deployment name (optional).
|
||||
* @param {Object} [client] - The API Client class for optionally setting properties (optional).
|
||||
* @returns {string} The complete chat completion endpoint URL for the Azure OpenAI API.
|
||||
* @throws {Error} If neither azureOpenAIApiDeploymentName nor modelName is provided.
|
||||
*/
|
||||
const genAzureChatCompletion = (
|
||||
{ azureOpenAIApiInstanceName, azureOpenAIApiDeploymentName, azureOpenAIApiVersion },
|
||||
modelName,
|
||||
client,
|
||||
) => {
|
||||
// Determine the deployment segment of the URL based on provided modelName or azureOpenAIApiDeploymentName
|
||||
let deploymentSegment;
|
||||
if (isEnabled(process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME) && modelName) {
|
||||
const sanitizedModelName = sanitizeModelName(modelName);
|
||||
deploymentSegment = `${sanitizedModelName}`;
|
||||
client &&
|
||||
typeof client === 'object' &&
|
||||
(client.azure.azureOpenAIApiDeploymentName = sanitizedModelName);
|
||||
} else if (azureOpenAIApiDeploymentName) {
|
||||
deploymentSegment = azureOpenAIApiDeploymentName;
|
||||
} else if (!process.env.AZURE_OPENAI_BASEURL) {
|
||||
throw new Error(
|
||||
'Either a model name with the `AZURE_USE_MODEL_AS_DEPLOYMENT_NAME` setting or a deployment name must be provided if `AZURE_OPENAI_BASEURL` is omitted.',
|
||||
);
|
||||
}
|
||||
|
||||
return `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${deploymentSegment}/chat/completions?api-version=${azureOpenAIApiVersion}`;
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves the Azure OpenAI API credentials from environment variables.
|
||||
* @returns {AzureOptions} An object containing the Azure OpenAI API credentials.
|
||||
*/
|
||||
const getAzureCredentials = () => {
|
||||
return {
|
||||
azureOpenAIApiKey: process.env.AZURE_API_KEY ?? process.env.AZURE_OPENAI_API_KEY,
|
||||
azureOpenAIApiInstanceName: process.env.AZURE_OPENAI_API_INSTANCE_NAME,
|
||||
azureOpenAIApiDeploymentName: process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME,
|
||||
azureOpenAIApiVersion: process.env.AZURE_OPENAI_API_VERSION,
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Constructs a URL by replacing placeholders in the baseURL with values from the azure object.
|
||||
* It specifically looks for '${INSTANCE_NAME}' and '${DEPLOYMENT_NAME}' within the baseURL and replaces
|
||||
* them with 'azureOpenAIApiInstanceName' and 'azureOpenAIApiDeploymentName' from the azure object.
|
||||
* If the respective azure property is not provided, the placeholder is replaced with an empty string.
|
||||
*
|
||||
* @param {Object} params - The parameters object.
|
||||
* @param {string} params.baseURL - The baseURL to inspect for replacement placeholders.
|
||||
* @param {AzureOptions} params.azureOptions - The azure options object containing the instance and deployment names.
|
||||
* @returns {string} The complete baseURL with credentials injected for the Azure OpenAI API.
|
||||
*/
|
||||
function constructAzureURL({ baseURL, azureOptions }) {
|
||||
let finalURL = baseURL;
|
||||
|
||||
// Replace INSTANCE_NAME and DEPLOYMENT_NAME placeholders with actual values if available
|
||||
if (azureOptions) {
|
||||
finalURL = finalURL.replace('${INSTANCE_NAME}', azureOptions.azureOpenAIApiInstanceName ?? '');
|
||||
finalURL = finalURL.replace(
|
||||
'${DEPLOYMENT_NAME}',
|
||||
azureOptions.azureOpenAIApiDeploymentName ?? '',
|
||||
);
|
||||
}
|
||||
|
||||
return finalURL;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
sanitizeModelName,
|
||||
genAzureEndpoint,
|
||||
genAzureChatCompletion,
|
||||
getAzureCredentials,
|
||||
constructAzureURL,
|
||||
};
|
||||
|
|
@ -1,268 +0,0 @@
|
|||
const {
|
||||
sanitizeModelName,
|
||||
genAzureEndpoint,
|
||||
genAzureChatCompletion,
|
||||
getAzureCredentials,
|
||||
constructAzureURL,
|
||||
} = require('./azureUtils');
|
||||
|
||||
describe('sanitizeModelName', () => {
|
||||
test('removes periods from the model name', () => {
|
||||
const sanitized = sanitizeModelName('model.name');
|
||||
expect(sanitized).toBe('modelname');
|
||||
});
|
||||
|
||||
test('leaves model name unchanged if no periods are present', () => {
|
||||
const sanitized = sanitizeModelName('modelname');
|
||||
expect(sanitized).toBe('modelname');
|
||||
});
|
||||
});
|
||||
|
||||
describe('genAzureEndpoint', () => {
|
||||
test('generates correct endpoint URL', () => {
|
||||
const url = genAzureEndpoint({
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiDeploymentName: 'deploymentName',
|
||||
});
|
||||
expect(url).toBe('https://instanceName.openai.azure.com/openai/deployments/deploymentName');
|
||||
});
|
||||
});
|
||||
|
||||
describe('genAzureChatCompletion', () => {
|
||||
// Test with both deployment name and model name provided
|
||||
test('prefers model name over deployment name when both are provided and feature enabled', () => {
|
||||
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
|
||||
const url = genAzureChatCompletion(
|
||||
{
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiDeploymentName: 'deploymentName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
},
|
||||
'modelName',
|
||||
);
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/modelName/chat/completions?api-version=v1',
|
||||
);
|
||||
});
|
||||
|
||||
// Test with only deployment name provided
|
||||
test('uses deployment name when model name is not provided', () => {
|
||||
const url = genAzureChatCompletion({
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiDeploymentName: 'deploymentName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
});
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
|
||||
);
|
||||
});
|
||||
|
||||
// Test with only model name provided
|
||||
test('uses model name when deployment name is not provided and feature enabled', () => {
|
||||
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
|
||||
const url = genAzureChatCompletion(
|
||||
{
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
},
|
||||
'modelName',
|
||||
);
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/modelName/chat/completions?api-version=v1',
|
||||
);
|
||||
});
|
||||
|
||||
// Test with neither deployment name nor model name provided
|
||||
test('throws error if neither deployment name nor model name is provided', () => {
|
||||
expect(() => {
|
||||
genAzureChatCompletion({
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
});
|
||||
}).toThrow(
|
||||
'Either a model name with the `AZURE_USE_MODEL_AS_DEPLOYMENT_NAME` setting or a deployment name must be provided if `AZURE_OPENAI_BASEURL` is omitted.',
|
||||
);
|
||||
});
|
||||
|
||||
// Test with feature disabled but model name provided
|
||||
test('ignores model name and uses deployment name when feature is disabled', () => {
|
||||
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'false';
|
||||
const url = genAzureChatCompletion(
|
||||
{
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiDeploymentName: 'deploymentName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
},
|
||||
'modelName',
|
||||
);
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
|
||||
);
|
||||
});
|
||||
|
||||
// Test with sanitized model name
|
||||
test('sanitizes model name when used in URL', () => {
|
||||
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
|
||||
const url = genAzureChatCompletion(
|
||||
{
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
},
|
||||
'model.name',
|
||||
);
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/modelname/chat/completions?api-version=v1',
|
||||
);
|
||||
});
|
||||
|
||||
// Test with client parameter and model name
|
||||
test('updates client with sanitized model name when provided and feature enabled', () => {
|
||||
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
|
||||
const clientMock = { azure: {} };
|
||||
const url = genAzureChatCompletion(
|
||||
{
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
},
|
||||
'model.name',
|
||||
clientMock,
|
||||
);
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/modelname/chat/completions?api-version=v1',
|
||||
);
|
||||
expect(clientMock.azure.azureOpenAIApiDeploymentName).toBe('modelname');
|
||||
});
|
||||
|
||||
// Test with client parameter but without model name
|
||||
test('does not update client when model name is not provided', () => {
|
||||
const clientMock = { azure: {} };
|
||||
const url = genAzureChatCompletion(
|
||||
{
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiDeploymentName: 'deploymentName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
},
|
||||
undefined,
|
||||
clientMock,
|
||||
);
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
|
||||
);
|
||||
expect(clientMock.azure.azureOpenAIApiDeploymentName).toBeUndefined();
|
||||
});
|
||||
|
||||
// Test with client parameter and deployment name when feature is disabled
|
||||
test('does not update client when feature is disabled', () => {
|
||||
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'false';
|
||||
const clientMock = { azure: {} };
|
||||
const url = genAzureChatCompletion(
|
||||
{
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiDeploymentName: 'deploymentName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
},
|
||||
'modelName',
|
||||
clientMock,
|
||||
);
|
||||
expect(url).toBe(
|
||||
'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
|
||||
);
|
||||
expect(clientMock.azure.azureOpenAIApiDeploymentName).toBeUndefined();
|
||||
});
|
||||
|
||||
// Reset environment variable after tests
|
||||
afterEach(() => {
|
||||
delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME;
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAzureCredentials', () => {
|
||||
beforeEach(() => {
|
||||
process.env.AZURE_API_KEY = 'testApiKey';
|
||||
process.env.AZURE_OPENAI_API_INSTANCE_NAME = 'instanceName';
|
||||
process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'deploymentName';
|
||||
process.env.AZURE_OPENAI_API_VERSION = 'v1';
|
||||
});
|
||||
|
||||
test('retrieves Azure OpenAI API credentials from environment variables', () => {
|
||||
const credentials = getAzureCredentials();
|
||||
expect(credentials).toEqual({
|
||||
azureOpenAIApiKey: 'testApiKey',
|
||||
azureOpenAIApiInstanceName: 'instanceName',
|
||||
azureOpenAIApiDeploymentName: 'deploymentName',
|
||||
azureOpenAIApiVersion: 'v1',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('constructAzureURL', () => {
|
||||
test('replaces both placeholders when both properties are provided', () => {
|
||||
const url = constructAzureURL({
|
||||
baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
|
||||
azureOptions: {
|
||||
azureOpenAIApiInstanceName: 'instance1',
|
||||
azureOpenAIApiDeploymentName: 'deployment1',
|
||||
},
|
||||
});
|
||||
expect(url).toBe('https://example.com/instance1/deployment1');
|
||||
});
|
||||
|
||||
test('replaces only INSTANCE_NAME when only azureOpenAIApiInstanceName is provided', () => {
|
||||
const url = constructAzureURL({
|
||||
baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
|
||||
azureOptions: {
|
||||
azureOpenAIApiInstanceName: 'instance2',
|
||||
},
|
||||
});
|
||||
expect(url).toBe('https://example.com/instance2/');
|
||||
});
|
||||
|
||||
test('replaces only DEPLOYMENT_NAME when only azureOpenAIApiDeploymentName is provided', () => {
|
||||
const url = constructAzureURL({
|
||||
baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
|
||||
azureOptions: {
|
||||
azureOpenAIApiDeploymentName: 'deployment2',
|
||||
},
|
||||
});
|
||||
expect(url).toBe('https://example.com//deployment2');
|
||||
});
|
||||
|
||||
test('does not replace any placeholders when azure object is empty', () => {
|
||||
const url = constructAzureURL({
|
||||
baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
|
||||
azureOptions: {},
|
||||
});
|
||||
expect(url).toBe('https://example.com//');
|
||||
});
|
||||
|
||||
test('returns baseURL as is when `azureOptions` object is not provided', () => {
|
||||
const url = constructAzureURL({
|
||||
baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
|
||||
});
|
||||
expect(url).toBe('https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}');
|
||||
});
|
||||
|
||||
test('returns baseURL as is when no placeholders are set', () => {
|
||||
const url = constructAzureURL({
|
||||
baseURL: 'https://example.com/my_custom_instance/my_deployment',
|
||||
azureOptions: {
|
||||
azureOpenAIApiInstanceName: 'instance1',
|
||||
azureOpenAIApiDeploymentName: 'deployment1',
|
||||
},
|
||||
});
|
||||
expect(url).toBe('https://example.com/my_custom_instance/my_deployment');
|
||||
});
|
||||
|
||||
test('returns regular Azure OpenAI baseURL with placeholders set', () => {
|
||||
const baseURL =
|
||||
'https://${INSTANCE_NAME}.openai.azure.com/openai/deployments/${DEPLOYMENT_NAME}';
|
||||
const url = constructAzureURL({
|
||||
baseURL,
|
||||
azureOptions: {
|
||||
azureOpenAIApiInstanceName: 'instance1',
|
||||
azureOpenAIApiDeploymentName: 'deployment1',
|
||||
},
|
||||
});
|
||||
expect(url).toBe('https://instance1.openai.azure.com/openai/deployments/deployment1');
|
||||
});
|
||||
});
|
||||
|
|
@ -1,7 +1,5 @@
|
|||
const loadYaml = require('./loadYaml');
|
||||
const axiosHelpers = require('./axios');
|
||||
const tokenHelpers = require('./tokens');
|
||||
const azureUtils = require('./azureUtils');
|
||||
const deriveBaseURL = require('./deriveBaseURL');
|
||||
const extractBaseURL = require('./extractBaseURL');
|
||||
const findMessageContent = require('./findMessageContent');
|
||||
|
|
@ -10,8 +8,6 @@ module.exports = {
|
|||
loadYaml,
|
||||
deriveBaseURL,
|
||||
extractBaseURL,
|
||||
...azureUtils,
|
||||
...axiosHelpers,
|
||||
...tokenHelpers,
|
||||
findMessageContent,
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue