mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-16 16:30:15 +01:00
🧑💼 feat: Add Agent Model Validation (#8995)
* fix: Update logger import to use data-schemas module * feat: agent model validation * fix: Remove invalid error messages from translation file
This commit is contained in:
parent
8cefa566da
commit
c5ca621efd
9 changed files with 189 additions and 54 deletions
|
|
@ -5,6 +5,7 @@ const { logger } = require('~/config');
|
|||
|
||||
/**
|
||||
* @param {ServerRequest} req
|
||||
* @returns {Promise<TModelsConfig>} The models config.
|
||||
*/
|
||||
const getModelsConfig = async (req) => {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ const validateModel = async (req, res, next) => {
|
|||
return next();
|
||||
}
|
||||
|
||||
const { ILLEGAL_MODEL_REQ_SCORE: score = 5 } = process.env ?? {};
|
||||
const { ILLEGAL_MODEL_REQ_SCORE: score = 1 } = process.env ?? {};
|
||||
|
||||
const type = ViolationTypes.ILLEGAL_MODEL_REQUEST;
|
||||
const errorMessage = {
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { useAzurePlugins } = require('~/server/services/Config/EndpointService').config;
|
||||
const {
|
||||
getAnthropicModels,
|
||||
getBedrockModels,
|
||||
getOpenAIModels,
|
||||
getGoogleModels,
|
||||
getBedrockModels,
|
||||
getAnthropicModels,
|
||||
} = require('~/server/services/ModelService');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Loads the default models for the application.
|
||||
|
|
@ -16,58 +15,42 @@ const { logger } = require('~/config');
|
|||
*/
|
||||
async function loadDefaultModels(req) {
|
||||
try {
|
||||
const [
|
||||
openAI,
|
||||
anthropic,
|
||||
azureOpenAI,
|
||||
gptPlugins,
|
||||
assistants,
|
||||
azureAssistants,
|
||||
google,
|
||||
bedrock,
|
||||
] = await Promise.all([
|
||||
getOpenAIModels({ user: req.user.id }).catch((error) => {
|
||||
logger.error('Error fetching OpenAI models:', error);
|
||||
return [];
|
||||
}),
|
||||
getAnthropicModels({ user: req.user.id }).catch((error) => {
|
||||
logger.error('Error fetching Anthropic models:', error);
|
||||
return [];
|
||||
}),
|
||||
getOpenAIModels({ user: req.user.id, azure: true }).catch((error) => {
|
||||
logger.error('Error fetching Azure OpenAI models:', error);
|
||||
return [];
|
||||
}),
|
||||
getOpenAIModels({ user: req.user.id, azure: useAzurePlugins, plugins: true }).catch(
|
||||
(error) => {
|
||||
logger.error('Error fetching Plugin models:', error);
|
||||
const [openAI, anthropic, azureOpenAI, assistants, azureAssistants, google, bedrock] =
|
||||
await Promise.all([
|
||||
getOpenAIModels({ user: req.user.id }).catch((error) => {
|
||||
logger.error('Error fetching OpenAI models:', error);
|
||||
return [];
|
||||
},
|
||||
),
|
||||
getOpenAIModels({ assistants: true }).catch((error) => {
|
||||
logger.error('Error fetching OpenAI Assistants API models:', error);
|
||||
return [];
|
||||
}),
|
||||
getOpenAIModels({ azureAssistants: true }).catch((error) => {
|
||||
logger.error('Error fetching Azure OpenAI Assistants API models:', error);
|
||||
return [];
|
||||
}),
|
||||
Promise.resolve(getGoogleModels()).catch((error) => {
|
||||
logger.error('Error getting Google models:', error);
|
||||
return [];
|
||||
}),
|
||||
Promise.resolve(getBedrockModels()).catch((error) => {
|
||||
logger.error('Error getting Bedrock models:', error);
|
||||
return [];
|
||||
}),
|
||||
]);
|
||||
}),
|
||||
getAnthropicModels({ user: req.user.id }).catch((error) => {
|
||||
logger.error('Error fetching Anthropic models:', error);
|
||||
return [];
|
||||
}),
|
||||
getOpenAIModels({ user: req.user.id, azure: true }).catch((error) => {
|
||||
logger.error('Error fetching Azure OpenAI models:', error);
|
||||
return [];
|
||||
}),
|
||||
getOpenAIModels({ assistants: true }).catch((error) => {
|
||||
logger.error('Error fetching OpenAI Assistants API models:', error);
|
||||
return [];
|
||||
}),
|
||||
getOpenAIModels({ azureAssistants: true }).catch((error) => {
|
||||
logger.error('Error fetching Azure OpenAI Assistants API models:', error);
|
||||
return [];
|
||||
}),
|
||||
Promise.resolve(getGoogleModels()).catch((error) => {
|
||||
logger.error('Error getting Google models:', error);
|
||||
return [];
|
||||
}),
|
||||
Promise.resolve(getBedrockModels()).catch((error) => {
|
||||
logger.error('Error getting Bedrock models:', error);
|
||||
return [];
|
||||
}),
|
||||
]);
|
||||
|
||||
return {
|
||||
[EModelEndpoint.openAI]: openAI,
|
||||
[EModelEndpoint.agents]: openAI,
|
||||
[EModelEndpoint.google]: google,
|
||||
[EModelEndpoint.anthropic]: anthropic,
|
||||
[EModelEndpoint.gptPlugins]: gptPlugins,
|
||||
[EModelEndpoint.azureOpenAI]: azureOpenAI,
|
||||
[EModelEndpoint.assistants]: assistants,
|
||||
[EModelEndpoint.azureAssistants]: azureAssistants,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isAgentsEndpoint, removeNullishValues, Constants } = require('librechat-data-provider');
|
||||
const { loadAgent } = require('~/models/Agent');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const buildOptions = (req, endpoint, parsedBody, endpointType) => {
|
||||
const { spec, iconURL, agent_id, instructions, ...model_parameters } = parsedBody;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { validateAgentModel } = require('@librechat/api');
|
||||
const { createContentAggregator } = require('@librechat/agents');
|
||||
const {
|
||||
Constants,
|
||||
|
|
@ -11,10 +12,12 @@ const {
|
|||
getDefaultHandlers,
|
||||
} = require('~/server/controllers/agents/callbacks');
|
||||
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
||||
const { getModelsConfig } = require('~/server/controllers/ModelController');
|
||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||
const { loadAgentTools } = require('~/server/services/ToolService');
|
||||
const AgentClient = require('~/server/controllers/agents/client');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { logViolation } = require('~/cache');
|
||||
|
||||
function createToolLoader() {
|
||||
/**
|
||||
|
|
@ -72,6 +75,19 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
throw new Error('Agent not found');
|
||||
}
|
||||
|
||||
const modelsConfig = await getModelsConfig(req);
|
||||
const validationResult = await validateAgentModel({
|
||||
req,
|
||||
res,
|
||||
modelsConfig,
|
||||
logViolation,
|
||||
agent: primaryAgent,
|
||||
});
|
||||
|
||||
if (!validationResult.isValid) {
|
||||
throw new Error(validationResult.error?.message);
|
||||
}
|
||||
|
||||
const agentConfigs = new Map();
|
||||
/** @type {Set<string>} */
|
||||
const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders);
|
||||
|
|
@ -101,6 +117,19 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
if (!agent) {
|
||||
throw new Error(`Agent ${agentId} not found`);
|
||||
}
|
||||
|
||||
const validationResult = await validateAgentModel({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
modelsConfig,
|
||||
logViolation,
|
||||
});
|
||||
|
||||
if (!validationResult.isValid) {
|
||||
throw new Error(validationResult.error?.message);
|
||||
}
|
||||
|
||||
const config = await initializeAgent({
|
||||
req,
|
||||
res,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
// file deepcode ignore HardcodedNonCryptoSecret: No hardcoded secrets
|
||||
import { ViolationTypes, ErrorTypes, alternateName } from 'librechat-data-provider';
|
||||
import type { TOpenAIMessage } from 'librechat-data-provider';
|
||||
import type { LocalizeFunction } from '~/common';
|
||||
import { formatJSON, extractJson, isJson } from '~/utils/json';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
|
@ -25,7 +24,7 @@ type TTokenBalance = {
|
|||
prev_count: number;
|
||||
violation_count: number;
|
||||
date: Date;
|
||||
generations?: TOpenAIMessage[];
|
||||
generations?: unknown[];
|
||||
};
|
||||
|
||||
type TExpiredKey = {
|
||||
|
|
@ -44,6 +43,17 @@ const errorMessages = {
|
|||
[ErrorTypes.NO_BASE_URL]: 'com_error_no_base_url',
|
||||
[ErrorTypes.INVALID_ACTION]: `com_error_${ErrorTypes.INVALID_ACTION}`,
|
||||
[ErrorTypes.INVALID_REQUEST]: `com_error_${ErrorTypes.INVALID_REQUEST}`,
|
||||
[ErrorTypes.MISSING_MODEL]: (json: TGenericError, localize: LocalizeFunction) => {
|
||||
const { info: endpoint } = json;
|
||||
const provider = (alternateName[endpoint ?? ''] as string | undefined) ?? endpoint ?? 'unknown';
|
||||
return localize('com_error_missing_model', { 0: provider });
|
||||
},
|
||||
[ErrorTypes.MODELS_NOT_LOADED]: 'com_error_models_not_loaded',
|
||||
[ErrorTypes.ENDPOINT_MODELS_NOT_LOADED]: (json: TGenericError, localize: LocalizeFunction) => {
|
||||
const { info: endpoint } = json;
|
||||
const provider = (alternateName[endpoint ?? ''] as string | undefined) ?? endpoint ?? 'unknown';
|
||||
return localize('com_error_endpoint_models_not_loaded', { 0: provider });
|
||||
},
|
||||
[ErrorTypes.NO_SYSTEM_MESSAGES]: `com_error_${ErrorTypes.NO_SYSTEM_MESSAGES}`,
|
||||
[ErrorTypes.EXPIRED_USER_KEY]: (json: TExpiredKey, localize: LocalizeFunction) => {
|
||||
const { expiredAt, endpoint } = json;
|
||||
|
|
@ -65,6 +75,12 @@ const errorMessages = {
|
|||
[ErrorTypes.GOOGLE_TOOL_CONFLICT]: 'com_error_google_tool_conflict',
|
||||
[ViolationTypes.BAN]:
|
||||
'Your account has been temporarily banned due to violations of our service.',
|
||||
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: (json: TGenericError, localize: LocalizeFunction) => {
|
||||
const { info } = json;
|
||||
const [endpoint, model = 'unknown'] = info?.split('|') ?? [];
|
||||
const provider = (alternateName[endpoint ?? ''] as string | undefined) ?? endpoint ?? 'unknown';
|
||||
return localize('com_error_illegal_model_request', { 0: model, 1: provider });
|
||||
},
|
||||
invalid_api_key:
|
||||
'Invalid API key. Please check your API key and try again. You can do this by clicking on the model logo in the left corner of the textbox and selecting "Set Token" for the current selected endpoint. Thank you for your understanding.',
|
||||
insufficient_quota:
|
||||
|
|
|
|||
|
|
@ -297,10 +297,14 @@
|
|||
"com_error_files_upload_canceled": "The file upload request was canceled. Note: the file upload may still be processing and will need to be manually deleted.",
|
||||
"com_error_files_validation": "An error occurred while validating the file.",
|
||||
"com_error_google_tool_conflict": "Usage of built-in Google tools are not supported with external tools. Please disable either the built-in tools or the external tools.",
|
||||
"com_error_endpoint_models_not_loaded": "Models for {{0}} could not be loaded. Please refresh the page and try again.",
|
||||
"com_error_heic_conversion": "Failed to convert HEIC image to JPEG. Please try converting the image manually or use a different format.",
|
||||
"com_error_illegal_model_request": "The model \"{{0}}\" is not available for {{1}}. Please select a different model.",
|
||||
"com_error_input_length": "The latest message token count is too long, exceeding the token limit, or your token limit parameters are misconfigured, adversely affecting the context window. More info: {{0}}. Please shorten your message, adjust the max context size from the conversation parameters, or fork the conversation to continue.",
|
||||
"com_error_invalid_agent_provider": "The \"{{0}}\" provider is not available for use with Agents. Please go to your agent's settings and select a currently available provider.",
|
||||
"com_error_invalid_user_key": "Invalid key provided. Please provide a valid key and try again.",
|
||||
"com_error_missing_model": "No model selected for {{0}}. Please select a model and try again.",
|
||||
"com_error_models_not_loaded": "Models configuration could not be loaded. Please refresh the page and try again.",
|
||||
"com_error_moderation": "It appears that the content submitted has been flagged by our moderation system for not aligning with our community guidelines. We're unable to proceed with this specific topic. If you have any other questions or topics you'd like to explore, please edit your message, or create a new conversation.",
|
||||
"com_error_no_base_url": "No base URL found. Please provide one and try again.",
|
||||
"com_error_no_user_key": "No key found. Please provide a key and try again.",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
import { z } from 'zod';
|
||||
import { ViolationTypes, ErrorTypes } from 'librechat-data-provider';
|
||||
import type { Agent, TModelsConfig } from 'librechat-data-provider';
|
||||
import type { Request, Response } from 'express';
|
||||
|
||||
/** Avatar schema shared between create and update */
|
||||
export const agentAvatarSchema = z.object({
|
||||
|
|
@ -59,3 +62,90 @@ export const agentUpdateSchema = agentBaseSchema.extend({
|
|||
removeProjectIds: z.array(z.string()).optional(),
|
||||
isCollaborative: z.boolean().optional(),
|
||||
});
|
||||
|
||||
interface ValidateAgentModelParams {
|
||||
req: Request;
|
||||
res: Response;
|
||||
agent: Agent;
|
||||
modelsConfig: TModelsConfig;
|
||||
logViolation: (
|
||||
req: Request,
|
||||
res: Response,
|
||||
type: string,
|
||||
errorMessage: Record<string, unknown>,
|
||||
score?: number | string,
|
||||
) => Promise<void>;
|
||||
}
|
||||
|
||||
interface ValidateAgentModelResult {
|
||||
isValid: boolean;
|
||||
error?: {
|
||||
message: string;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates an agent's model against the available models configuration.
|
||||
* This is a non-middleware version of validateModel that can be used
|
||||
* in service initialization flows.
|
||||
*
|
||||
* @param params - Validation parameters
|
||||
* @returns Object indicating whether the model is valid and any error details
|
||||
*/
|
||||
export async function validateAgentModel(
|
||||
params: ValidateAgentModelParams,
|
||||
): Promise<ValidateAgentModelResult> {
|
||||
const { req, res, agent, modelsConfig, logViolation } = params;
|
||||
const { model, provider: endpoint } = agent;
|
||||
|
||||
if (!model) {
|
||||
return {
|
||||
isValid: false,
|
||||
error: {
|
||||
message: `{ "type": "${ErrorTypes.MISSING_MODEL}", "info": "${endpoint}" }`,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
if (!modelsConfig) {
|
||||
return {
|
||||
isValid: false,
|
||||
error: {
|
||||
message: `{ "type": "${ErrorTypes.MODELS_NOT_LOADED}" }`,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const availableModels = modelsConfig[endpoint];
|
||||
if (!availableModels) {
|
||||
return {
|
||||
isValid: false,
|
||||
error: {
|
||||
message: `{ "type": "${ErrorTypes.ENDPOINT_MODELS_NOT_LOADED}", "info": "${endpoint}" }`,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const validModel = !!availableModels.find((availableModel) => availableModel === model);
|
||||
|
||||
if (validModel) {
|
||||
return { isValid: true };
|
||||
}
|
||||
|
||||
const { ILLEGAL_MODEL_REQ_SCORE: score = 1 } = process.env ?? {};
|
||||
const type = ViolationTypes.ILLEGAL_MODEL_REQUEST;
|
||||
const errorMessage = {
|
||||
type,
|
||||
model,
|
||||
endpoint,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, score);
|
||||
|
||||
return {
|
||||
isValid: false,
|
||||
error: {
|
||||
message: `{ "type": "${ViolationTypes.ILLEGAL_MODEL_REQUEST}", "info": "${endpoint}|${model}" }`,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1347,6 +1347,18 @@ export enum ErrorTypes {
|
|||
* Invalid Agent Provider (excluded by Admin)
|
||||
*/
|
||||
INVALID_AGENT_PROVIDER = 'invalid_agent_provider',
|
||||
/**
|
||||
* Missing model selection
|
||||
*/
|
||||
MISSING_MODEL = 'missing_model',
|
||||
/**
|
||||
* Models configuration not loaded
|
||||
*/
|
||||
MODELS_NOT_LOADED = 'models_not_loaded',
|
||||
/**
|
||||
* Endpoint models not loaded
|
||||
*/
|
||||
ENDPOINT_MODELS_NOT_LOADED = 'endpoint_models_not_loaded',
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue