mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-03-03 06:40:20 +01:00
* chore: move database model methods to /packages/data-schemas * chore: add TypeScript ESLint rule to warn on unused variables * refactor: model imports to streamline access - Consolidated model imports across various files to improve code organization and reduce redundancy. - Updated imports for models such as Assistant, Message, Conversation, and others to a unified import path. - Adjusted middleware and service files to reflect the new import structure, ensuring functionality remains intact. - Enhanced test files to align with the new import paths, maintaining test coverage and integrity. * chore: migrate database models to packages/data-schemas and refactor all direct Mongoose Model usage outside of data-schemas * test: update agent model mocks in unit tests - Added `getAgent` mock to `client.test.js` to enhance test coverage for agent-related functionality. - Removed redundant `getAgent` and `getAgents` mocks from `openai.spec.js` and `responses.unit.spec.js` to streamline test setup and reduce duplication. - Ensured consistency in agent mock implementations across test files. * fix: update types in data-schemas * refactor: enhance type definitions in transaction and spending methods - Updated type definitions in `checkBalance.ts` to use specific request and response types. - Refined `spendTokens.ts` to utilize a new `SpendTxData` interface for better clarity and type safety. - Improved transaction handling in `transaction.ts` by introducing `TransactionResult` and `TxData` interfaces, ensuring consistent data structures across methods. - Adjusted unit tests in `transaction.spec.ts` to accommodate new type definitions and enhance robustness. * refactor: streamline model imports and enhance code organization - Consolidated model imports across various controllers and services to a unified import path, improving code clarity and reducing redundancy. - Updated multiple files to reflect the new import structure, ensuring all functionalities remain intact. 
- Enhanced overall code organization by removing duplicate import statements and optimizing the usage of model methods. * feat: implement loadAddedAgent and refactor agent loading logic - Introduced `loadAddedAgent` function to handle loading agents from added conversations, supporting multi-convo parallel execution. - Created a new `load.ts` file to encapsulate agent loading functionalities, including `loadEphemeralAgent` and `loadAgent`. - Updated the `index.ts` file to export the new `load` module instead of the deprecated `loadAgent`. - Enhanced type definitions and improved error handling in the agent loading process. - Adjusted unit tests to reflect changes in the agent loading structure and ensure comprehensive coverage. * refactor: enhance balance handling with new update interface - Introduced `IBalanceUpdate` interface to streamline balance update operations across the codebase. - Updated `upsertBalanceFields` method signatures in `balance.ts`, `transaction.ts`, and related tests to utilize the new interface for improved type safety. - Adjusted type imports in `balance.spec.ts` to include `IBalanceUpdate`, ensuring consistency in balance management functionalities. - Enhanced overall code clarity and maintainability by refining type definitions related to balance operations. * feat: add unit tests for loadAgent functionality and enhance agent loading logic - Introduced comprehensive unit tests for the `loadAgent` function, covering various scenarios including null and empty agent IDs, loading of ephemeral agents, and permission checks. - Enhanced the `initializeClient` function by moving `getConvoFiles` to the correct position in the database method exports, ensuring proper functionality. - Improved test coverage for agent loading, including handling of non-existent agents and user permissions. 
* chore: reorder memory method exports for consistency - Moved `deleteAllUserMemories` to the correct position in the exported memory methods, ensuring a consistent and logical order of method exports in `memory.ts`.
470 lines
15 KiB
JavaScript
470 lines
15 KiB
JavaScript
const path = require('path');
|
|
const sharp = require('sharp');
|
|
const { v4 } = require('uuid');
|
|
const { ProxyAgent } = require('undici');
|
|
const { GoogleGenAI } = require('@google/genai');
|
|
const { tool } = require('@langchain/core/tools');
|
|
const { logger } = require('@librechat/data-schemas');
|
|
const { ContentTypes, EImageOutputType } = require('librechat-data-provider');
|
|
const {
|
|
geminiToolkit,
|
|
loadServiceKey,
|
|
getBalanceConfig,
|
|
getTransactionsConfig,
|
|
} = require('@librechat/api');
|
|
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
|
const { spendTokens, getFiles } = require('~/models');
|
|
|
|
/**
 * Configure proxy support for Google APIs.
 * Wraps globalThis.fetch so that only requests to googleapis.com URLs are
 * routed through a proxy dispatcher. This is necessary because the
 * @google/genai SDK doesn't support custom fetch or httpOptions.dispatcher.
 */
if (process.env.PROXY) {
  const baseFetch = globalThis.fetch;
  const dispatcher = new ProxyAgent(process.env.PROXY);

  globalThis.fetch = function (url, options = {}) {
    const target = url.toString();
    // Attach the proxy dispatcher only for Google API traffic; all other
    // requests keep their original options untouched.
    const finalOptions = target.includes('googleapis.com')
      ? { ...options, dispatcher }
      : options;
    return baseFetch.call(this, url, finalOptions);
  };
}
|
|
|
|
/**
 * Get the default service key file path (consistent with main Google endpoint).
 * Prefers the GOOGLE_SERVICE_KEY_FILE env var, otherwise falls back to
 * <cwd>/api/data/auth.json.
 * @returns {string} - The default path to the service key file
 */
function getDefaultServiceKeyPath() {
  const envPath = process.env.GOOGLE_SERVICE_KEY_FILE;
  if (envPath) {
    return envPath;
  }
  return path.join(process.cwd(), 'api', 'data', 'auth.json');
}
|
|
|
|
// Guidance text appended to every successful tool response; instructs the
// model not to re-describe generated images or mention download links, since
// the UI already renders the images and handles downloads.
const displayMessage =
  "Gemini displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.";
|
|
|
|
/**
 * Replaces unwanted characters from the input string: newlines (CRLF, CR, LF)
 * become single spaces, double quotes are stripped, and the result is trimmed.
 * @param {string} inputString - The input string to process
 * @returns {string} - The processed string ('' for null/undefined input)
 */
function replaceUnwantedChars(inputString) {
  if (inputString == null) {
    return '';
  }
  const cleaned = inputString
    .replace(/\r\n|\r|\n/g, ' ')
    .replace(/"/g, '')
    .trim();
  return cleaned || '';
}
|
|
|
|
/**
 * Convert image buffer to target format if needed.
 * Uses sharp to inspect the current format and re-encode only when the image
 * is not already in the requested format.
 * @param {Buffer} inputBuffer - The input image buffer
 * @param {string} targetFormat - The target format (png, jpeg, webp)
 * @returns {Promise<{buffer: Buffer, format: string}>} - Converted buffer and format
 */
async function convertImageFormat(inputBuffer, targetFormat) {
  const { format: currentFormat } = await sharp(inputBuffer).metadata();

  // Normalize format names (jpg -> jpeg)
  const normalizedTarget = targetFormat === 'jpg' ? 'jpeg' : targetFormat.toLowerCase();
  const normalizedCurrent = currentFormat === 'jpg' ? 'jpeg' : currentFormat;

  // Already in the target format — return the original buffer untouched.
  if (normalizedCurrent === normalizedTarget) {
    return { buffer: inputBuffer, format: normalizedTarget };
  }

  // Re-encode to the target format.
  const convertedBuffer = await sharp(inputBuffer).toFormat(normalizedTarget).toBuffer();
  return { buffer: convertedBuffer, format: normalizedTarget };
}
|
|
|
|
/**
 * Initialize Gemini client (supports both Gemini API and Vertex AI).
 * Priority: API key (from options, resolved by loadAuthValues) > Vertex AI service account.
 * @param {Object} options - Initialization options
 * @param {string} [options.GEMINI_API_KEY] - Gemini API key (resolved by loadAuthValues)
 * @param {string} [options.GOOGLE_KEY] - Google API key (resolved by loadAuthValues)
 * @returns {Promise<GoogleGenAI>} - The initialized client
 * @throws {Error} When no API key is available and the service account file is missing or invalid
 */
async function initializeGeminiClient(options = {}) {
  // API keys take precedence; GEMINI_API_KEY is checked before GOOGLE_KEY.
  const apiKeyCandidates = [
    ['GEMINI_API_KEY', options.GEMINI_API_KEY],
    ['GOOGLE_KEY', options.GOOGLE_KEY],
  ];
  for (const [label, apiKey] of apiKeyCandidates) {
    if (apiKey) {
      logger.debug(`[GeminiImageGen] Using Gemini API with ${label}`);
      return new GoogleGenAI({ apiKey });
    }
  }

  // No API key — fall back to a Vertex AI service account.
  logger.debug('[GeminiImageGen] Using Vertex AI with service account');
  const credentialsPath = getDefaultServiceKeyPath();
  const serviceKey = await loadServiceKey(credentialsPath);

  if (!serviceKey?.project_id) {
    throw new Error(
      'Gemini Image Generation requires one of: user-provided API key, GEMINI_API_KEY or GOOGLE_KEY env var, or a valid Google service account. ' +
        `Service account file not found or invalid at: ${credentialsPath}`,
    );
  }

  return new GoogleGenAI({
    vertexai: true,
    project: serviceKey.project_id,
    location: process.env.GOOGLE_LOC || process.env.GOOGLE_CLOUD_LOCATION || 'global',
    googleAuthOptions: { credentials: serviceKey },
  });
}
|
|
|
|
/**
 * Convert image files to Gemini inline data format.
 * Resolves each id in `image_ids` from the files attached to the request or,
 * when missing there, from the database; downloads each image through its
 * storage strategy and base64-encodes it as a Gemini inlineData part.
 * Unresolvable or failing images are skipped (errors are logged, not thrown).
 * @param {Object} params - Parameters
 * @returns {Promise<Array>} - Array of inline data objects
 */
async function convertImagesToInlineData({ imageFiles, image_ids, req, fileStrategy }) {
  if (!image_ids || image_ids.length === 0) {
    return [];
  }

  // Files already attached to the request, keyed by id (shallow copies).
  const knownFiles = new Map(imageFiles.map((f) => [f.file_id, { ...f }]));
  // Result slots, one per requested id, preserving the requested order.
  const orderedFiles = new Array(image_ids.length);
  const missingIds = [];
  const missingSlot = new Map();

  image_ids.forEach((id, index) => {
    const file = knownFiles.get(id);
    if (file) {
      orderedFiles[index] = file;
    } else {
      missingIds.push(id);
      missingSlot.set(id, index);
    }
  });

  // Fetch any ids not present on the request from the database; restrict to
  // records with height/width so only image files match.
  if (missingIds.length && req?.user?.id) {
    const fetchedFiles = await getFiles(
      {
        user: req.user.id,
        file_id: { $in: missingIds },
        height: { $exists: true },
        width: { $exists: true },
      },
      {},
      {},
    );

    for (const file of fetchedFiles) {
      knownFiles.set(file.file_id, file);
      orderedFiles[missingSlot.get(file.file_id)] = file;
    }
  }

  // Cache the download-stream function per storage source.
  const streamMethods = new Map();
  const inlineDataArray = [];

  for (const imageFile of orderedFiles) {
    if (!imageFile) {
      continue;
    }

    try {
      const source = imageFile.source || fileStrategy;
      if (!source) {
        continue;
      }

      let getDownloadStream = streamMethods.get(source);
      if (!getDownloadStream) {
        ({ getDownloadStream } = getStrategyFunctions(source));
        streamMethods.set(source, getDownloadStream);
      }
      if (!getDownloadStream) {
        continue;
      }

      const stream = await getDownloadStream(req, imageFile.filepath);
      if (!stream) {
        continue;
      }

      // Drain the stream into one buffer, then base64-encode it.
      const chunks = [];
      for await (const chunk of stream) {
        chunks.push(chunk);
      }
      const buffer = Buffer.concat(chunks);

      inlineDataArray.push({
        inlineData: {
          mimeType: imageFile.type || 'image/png',
          data: buffer.toString('base64'),
        },
      });
    } catch (error) {
      logger.error('[GeminiImageGen] Error processing image:', imageFile.file_id, error);
    }
  }

  return inlineDataArray;
}
|
|
|
|
/**
 * Check for safety blocks in API response.
 * Inspects the first candidate's finish reason and safety ratings.
 * @param {Object} response - The API response
 * @returns {Object|null} - Safety block info ({ reason, message[, category] }) or null
 */
function checkForSafetyBlock(response) {
  const candidates = response?.candidates;
  if (!candidates?.length) {
    return { reason: 'NO_CANDIDATES', message: 'No candidates returned' };
  }

  const [candidate] = candidates;

  // Finish-reason based blocks take precedence over individual ratings.
  switch (candidate.finishReason) {
    case 'SAFETY':
    case 'PROHIBITED_CONTENT':
      return { reason: candidate.finishReason, message: 'Content blocked by safety filters' };
    case 'RECITATION':
      return { reason: candidate.finishReason, message: 'Content blocked due to recitation concerns' };
    default:
      break;
  }

  // Fall back to scanning per-category safety ratings for a block.
  const blockedRating = (candidate.safetyRatings ?? []).find(
    (rating) => rating.probability === 'HIGH' || rating.blocked === true,
  );
  if (blockedRating) {
    return {
      reason: 'SAFETY_RATING',
      message: `Blocked due to ${blockedRating.category}`,
      category: blockedRating.category,
    };
  }

  return null;
}
|
|
|
|
/**
 * Record token usage for balance tracking.
 * Best-effort accounting: reads token counts from the Gemini usage metadata
 * and debits them via spendTokens. Errors are logged, never thrown.
 * @param {Object} params - Parameters
 * @param {Object} params.usageMetadata - The usage metadata from API response
 * @param {Object} params.req - The request object
 * @param {string} params.userId - The user ID
 * @param {string} params.conversationId - The conversation ID
 * @param {string} params.model - The model name
 */
async function recordTokenUsage({ usageMetadata, req, userId, conversationId, model }) {
  if (!usageMetadata) {
    logger.debug('[GeminiImageGen] No usage metadata available for balance tracking');
    return;
  }

  const appConfig = req?.config;
  const balance = getBalanceConfig(appConfig);
  const transactions = getTransactionsConfig(appConfig);

  // Skip if neither balance nor transactions are enabled
  if (!balance?.enabled && transactions?.enabled === false) {
    return;
  }

  // The API reports counts in either snake_case or camelCase depending on surface.
  const readCount = (snake, camel) => usageMetadata[snake] || usageMetadata[camel] || 0;
  const promptTokens = readCount('prompt_token_count', 'promptTokenCount');
  const completionTokens = readCount('candidates_token_count', 'candidatesTokenCount');

  if (promptTokens === 0 && completionTokens === 0) {
    logger.debug('[GeminiImageGen] No tokens to record');
    return;
  }

  logger.debug('[GeminiImageGen] Recording token usage:', {
    promptTokens,
    completionTokens,
    model,
    conversationId,
  });

  try {
    const txMetadata = {
      user: userId,
      model,
      conversationId,
      context: 'image_generation',
      balance,
      transactions,
    };
    await spendTokens(txMetadata, { promptTokens, completionTokens });
  } catch (error) {
    logger.error('[GeminiImageGen] Error recording token usage:', error);
  }
}
|
|
|
|
/**
 * Creates Gemini Image Generation tool
 * @param {Object} fields - Configuration fields
 * @param {boolean} [fields.override=false] - When true, skips the agent-only guard
 * @param {boolean} [fields.isAgent] - Whether the caller is an agent context
 * @param {Object} [fields.req] - The request object (used for file lookup and config)
 * @param {Array} [fields.imageFiles] - Files already attached to the request
 * @param {string} [fields.userId] - The user ID, used for token accounting
 * @param {string} [fields.fileStrategy] - Fallback storage strategy for context images
 * @param {string} [fields.GEMINI_API_KEY] - Gemini API key
 * @param {string} [fields.GOOGLE_KEY] - Google API key
 * @param {string} [fields.imageOutputType] - Desired output image format (defaults to PNG)
 * @returns {ReturnType<tool>} - The image generation tool
 * @throws {Error} When invoked outside an agent context without override
 */
function createGeminiImageTool(fields = {}) {
  const override = fields.override ?? false;

  // Guard: this tool only runs in agent contexts unless explicitly overridden.
  if (!override && !fields.isAgent) {
    throw new Error('This tool is only available for agents.');
  }

  const { req, imageFiles = [], userId, fileStrategy, GEMINI_API_KEY, GOOGLE_KEY } = fields;

  const imageOutputType = fields.imageOutputType || EImageOutputType.PNG;

  // Tool handler uses the 'content_and_artifact' shape: it returns a
  // [content, artifact] pair, where error paths carry an empty artifact.
  const geminiImageGenTool = tool(
    async ({ prompt, image_ids, aspectRatio, imageSize }, runnableConfig) => {
      if (!prompt) {
        throw new Error('Missing required field: prompt');
      }

      logger.debug('[GeminiImageGen] Generating image', { aspectRatio, imageSize });

      // Initialize the client per invocation (API key or Vertex AI fallback);
      // failures are reported to the model as text rather than thrown.
      let ai;
      try {
        ai = await initializeGeminiClient({
          GEMINI_API_KEY,
          GOOGLE_KEY,
        });
      } catch (error) {
        logger.error('[GeminiImageGen] Failed to initialize client:', error);
        return [
          [{ type: ContentTypes.TEXT, text: `Failed to initialize Gemini: ${error.message}` }],
          { content: [], file_ids: [] },
        ];
      }

      // Request contents: sanitized prompt first, then any context images.
      const contents = [{ text: replaceUnwantedChars(prompt) }];

      if (image_ids?.length > 0) {
        const contextImages = await convertImagesToInlineData({
          imageFiles,
          image_ids,
          req,
          fileStrategy,
        });
        contents.push(...contextImages);
        logger.debug('[GeminiImageGen] Added', contextImages.length, 'context images');
      }

      let apiResponse;
      const geminiModel = process.env.GEMINI_IMAGE_MODEL || 'gemini-2.5-flash-image';
      const config = {
        responseModalities: ['TEXT', 'IMAGE'],
      };

      // gemini-2.5-flash-image models do not accept the imageSize option;
      // aspectRatio is passed through whenever supplied.
      const supportsImageSize = !geminiModel.includes('gemini-2.5-flash-image');
      if (aspectRatio || (imageSize && supportsImageSize)) {
        config.imageConfig = {};
        if (aspectRatio) {
          config.imageConfig.aspectRatio = aspectRatio;
        }
        if (imageSize && supportsImageSize) {
          config.imageConfig.imageSize = imageSize;
        }
      }

      // Wire up cancellation: derive a signal from the runnable config and
      // attach a one-shot listener for debug logging on abort.
      let derivedSignal = null;
      let abortHandler = null;

      if (runnableConfig?.signal) {
        derivedSignal = AbortSignal.any([runnableConfig.signal]);
        abortHandler = () => logger.debug('[GeminiImageGen] Image generation aborted');
        derivedSignal.addEventListener('abort', abortHandler, { once: true });
        config.abortSignal = derivedSignal;
      }

      try {
        apiResponse = await ai.models.generateContent({
          model: geminiModel,
          contents,
          config,
        });
      } catch (error) {
        logger.error('[GeminiImageGen] API error:', error);
        return [
          [{ type: ContentTypes.TEXT, text: `Image generation failed: ${error.message}` }],
          { content: [], file_ids: [] },
        ];
      } finally {
        // Always detach the abort listener, including on the error path.
        if (abortHandler && derivedSignal) {
          derivedSignal.removeEventListener('abort', abortHandler);
        }
      }

      // Surface safety blocks to the model as a plain-text explanation.
      const safetyBlock = checkForSafetyBlock(apiResponse);
      if (safetyBlock) {
        logger.warn('[GeminiImageGen] Safety block:', safetyBlock);
        const errorMsg = 'Image blocked by content safety filters. Please try different content.';
        return [[{ type: ContentTypes.TEXT, text: errorMsg }], { content: [], file_ids: [] }];
      }

      // Pull the first inline-data part (base64 image) from the first candidate.
      const rawImageData = apiResponse.candidates?.[0]?.content?.parts?.find((p) => p.inlineData)
        ?.inlineData?.data;

      if (!rawImageData) {
        logger.warn('[GeminiImageGen] No image data in response');
        return [
          [{ type: ContentTypes.TEXT, text: 'No image was generated. Please try again.' }],
          { content: [], file_ids: [] },
        ];
      }

      // Re-encode to the configured output format and build a data URL.
      const rawBuffer = Buffer.from(rawImageData, 'base64');
      const { buffer: convertedBuffer, format: outputFormat } = await convertImageFormat(
        rawBuffer,
        imageOutputType,
      );
      const imageData = convertedBuffer.toString('base64');
      const mimeType = outputFormat === 'jpeg' ? 'image/jpeg' : `image/${outputFormat}`;

      const dataUrl = `data:${mimeType};base64,${imageData}`;
      const file_ids = [v4()];
      const content = [
        {
          type: ContentTypes.IMAGE_URL,
          image_url: { url: dataUrl },
        },
      ];

      // Text part tells the model the generated (and any referenced) image ids.
      const textResponse = [
        {
          type: ContentTypes.TEXT,
          text:
            displayMessage +
            `\n\ngenerated_image_id: "${file_ids[0]}"` +
            (image_ids?.length > 0 ? `\nreferenced_image_ids: ["${image_ids.join('", "')}"]` : ''),
        },
      ];

      // Fire-and-forget token accounting; must not delay or fail the response.
      const conversationId = runnableConfig?.configurable?.thread_id;
      recordTokenUsage({
        usageMetadata: apiResponse.usageMetadata,
        req,
        userId,
        conversationId,
        model: geminiModel,
      }).catch((error) => {
        logger.error('[GeminiImageGen] Failed to record token usage:', error);
      });

      return [textResponse, { content, file_ids }];
    },
    {
      ...geminiToolkit.gemini_image_gen,
      responseFormat: 'content_and_artifact',
    },
  );

  return geminiImageGenTool;
}
|
|
|
|
// Export both for compatibility: the factory is the default export and is
// also attached as a named property, so both `require(...)` and
// `require(...).createGeminiImageTool` resolve to the same function.
module.exports = createGeminiImageTool;
module.exports.createGeminiImageTool = createGeminiImageTool;
|