🗃️ feat: General File Support for OpenAI, Azure, Custom, Anthropic and Google (RAG) (#2143)

* refactor: re-purpose `resendImages` as `resendFiles`

* refactor: re-purpose `resendImages` as `resendFiles`

* feat: upload general files

* feat: embed file during upload

* feat: delete file embeddings on file deletion

* chore(fileConfig): add epub+zip type

* feat(encodeAndFormat): handle non-image files

* feat(createContextHandlers): build context prompt from file attachments and successful RAG

* fix: prevent non-temp files as well as embedded files to be deleted on new conversation

* fix: remove temp_file_id on usage, prevent non-temp files as well as embedded files to be deleted on new conversation

* fix: prevent non-temp files as well as embedded files to be deleted on new conversation

* feat(OpenAI/Anthropic/Google): basic RAG support

* fix: delete `resendFiles` only when true (Default)

* refactor(RAG): update endpoints and pass JWT

* fix(resendFiles): default values

* fix(context/processFile): query unique ids only

* feat: rag-api.yaml

* feat: file upload improved ux for longer uploads

* chore: await embed call and catch embedding errors

* refactor: store augmentedPrompt in Client

* refactor(processFileUpload): throw error if not assistant file upload

* fix(useFileHandling): handle markdown empty mimetype issue

* chore: necessary compose file changes
This commit is contained in:
Danny Avila 2024-03-19 20:54:30 -04:00 committed by GitHub
parent af347cccde
commit f7761df52c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
38 changed files with 683 additions and 261 deletions

View file

@ -6,10 +6,9 @@ const {
validateVisionModel,
} = require('librechat-data-provider');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { formatMessage, createContextHandlers } = require('./prompts');
const spendTokens = require('~/models/spendTokens');
const { getModelMaxTokens } = require('~/utils');
const { formatMessage } = require('./prompts');
const { getFiles } = require('~/models/File');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
@ -67,7 +66,7 @@ class AnthropicClient extends BaseClient {
this.useMessages = this.isClaude3 || !!this.options.attachments;
this.defaultVisionModel = this.options.visionModel ?? 'claude-3-sonnet-20240229';
this.checkVisionRequest(this.options.attachments);
this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments));
this.maxContextTokens =
getModelMaxTokens(this.modelOptions.model, EModelEndpoint.anthropic) ?? 100000;
@ -134,14 +133,19 @@ class AnthropicClient extends BaseClient {
* - Sets `this.modelOptions.model` to `gpt-4-vision-preview` if the request is a vision request.
* - Sets `this.isVisionModel` to `true` if vision request.
* - Deletes `this.modelOptions.stop` if vision request.
* @param {Array<Promise<MongoFile[]> | MongoFile[]> | Record<string, MongoFile[]>} attachments
* @param {MongoFile[]} attachments
*/
checkVisionRequest(attachments) {
const availableModels = this.options.modelsConfig?.[EModelEndpoint.anthropic];
this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
const visionModelAvailable = availableModels?.includes(this.defaultVisionModel);
if (attachments && visionModelAvailable && !this.isVisionModel) {
if (
attachments &&
attachments.some((file) => file?.type && file?.type?.includes('image')) &&
visionModelAvailable &&
!this.isVisionModel
) {
this.modelOptions.model = this.defaultVisionModel;
this.isVisionModel = true;
}
@ -168,7 +172,7 @@ class AnthropicClient extends BaseClient {
attachments,
EModelEndpoint.anthropic,
);
message.image_urls = image_urls;
message.image_urls = image_urls.length ? image_urls : undefined;
return files;
}
@ -186,54 +190,6 @@ class AnthropicClient extends BaseClient {
);
}
/**
*
* @param {TMessage[]} _messages
* @returns {TMessage[]}
*/
async addPreviousAttachments(_messages) {
if (!this.options.resendImages) {
return _messages;
}
/**
*
* @param {TMessage} message
*/
const processMessage = async (message) => {
if (!this.message_file_map) {
/** @type {Record<string, MongoFile[]> */
this.message_file_map = {};
}
const fileIds = message.files.map((file) => file.file_id);
const files = await getFiles({
file_id: { $in: fileIds },
});
await this.addImageURLs(message, files);
this.message_file_map[message.messageId] = files;
return message;
};
const promises = [];
for (const message of _messages) {
if (!message.files) {
promises.push(message);
continue;
}
promises.push(processMessage(message));
}
const messages = await Promise.all(promises);
this.checkVisionRequest(this.message_file_map);
return messages;
}
async buildMessages(messages, parentMessageId) {
const orderedMessages = this.constructor.getMessagesForConversation({
messages,
@ -242,12 +198,13 @@ class AnthropicClient extends BaseClient {
logger.debug('[AnthropicClient] orderedMessages', { orderedMessages, parentMessageId });
if (!this.isVisionModel && this.options.attachments) {
throw new Error('Attachments are only supported with the Claude 3 family of models');
} else if (this.options.attachments) {
const attachments = (await this.options.attachments).filter((file) =>
file.type.includes('image'),
);
if (this.options.attachments) {
const attachments = await this.options.attachments;
const images = attachments.filter((file) => file.type.includes('image'));
if (images.length && !this.isVisionModel) {
throw new Error('Images are only supported with the Claude 3 family of models');
}
const latestMessage = orderedMessages[orderedMessages.length - 1];
@ -264,6 +221,13 @@ class AnthropicClient extends BaseClient {
this.options.attachments = files;
}
if (this.message_file_map) {
this.contextHandlers = createContextHandlers(
this.options.req,
orderedMessages[orderedMessages.length - 1].text,
);
}
const formattedMessages = orderedMessages.map((message, i) => {
const formattedMessage = this.useMessages
? formatMessage({
@ -285,6 +249,11 @@ class AnthropicClient extends BaseClient {
if (this.message_file_map && this.message_file_map[message.messageId]) {
const attachments = this.message_file_map[message.messageId];
for (const file of attachments) {
if (file.embedded) {
this.contextHandlers?.processFile(file);
continue;
}
orderedMessages[i].tokenCount += this.calculateImageTokenCost({
width: file.width,
height: file.height,
@ -296,6 +265,11 @@ class AnthropicClient extends BaseClient {
return formattedMessage;
});
if (this.contextHandlers) {
this.augmentedPrompt = await this.contextHandlers.createContext();
this.options.promptPrefix = this.augmentedPrompt + (this.options.promptPrefix ?? '');
}
let { context: messagesInWindow, remainingContextTokens } =
await this.getMessagesWithinTokenLimit(formattedMessages);
@ -389,7 +363,7 @@ class AnthropicClient extends BaseClient {
let isEdited = lastAuthor === this.assistantLabel;
const promptSuffix = isEdited ? '' : `${promptPrefix}${this.assistantLabel}\n`;
let currentTokenCount =
isEdited || this.useMEssages
isEdited || this.useMessages
? this.getTokenCount(promptPrefix)
: this.getTokenCount(promptSuffix);
@ -663,6 +637,7 @@ class AnthropicClient extends BaseClient {
return {
promptPrefix: this.options.promptPrefix,
modelLabel: this.options.modelLabel,
resendFiles: this.options.resendFiles,
...this.modelOptions,
};
}

View file

@ -3,6 +3,7 @@ const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
const checkBalance = require('~/models/checkBalance');
const { getFiles } = require('~/models/File');
const TextStream = require('./TextStream');
const { logger } = require('~/config');
@ -46,10 +47,6 @@ class BaseClient {
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', response);
}
async addPreviousAttachments(messages) {
return messages;
}
async recordTokenUsage({ promptTokens, completionTokens }) {
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', {
promptTokens,
@ -683,6 +680,54 @@ class BaseClient {
return await this.sendCompletion(payload, opts);
}
/**
*
* @param {TMessage[]} _messages
* @returns {Promise<TMessage[]>}
*/
async addPreviousAttachments(_messages) {
if (!this.options.resendFiles) {
return _messages;
}
/**
*
* @param {TMessage} message
*/
const processMessage = async (message) => {
if (!this.message_file_map) {
/** @type {Record<string, MongoFile[]> */
this.message_file_map = {};
}
const fileIds = message.files.map((file) => file.file_id);
const files = await getFiles({
file_id: { $in: fileIds },
});
await this.addImageURLs(message, files);
this.message_file_map[message.messageId] = files;
return message;
};
const promises = [];
for (const message of _messages) {
if (!message.files) {
promises.push(message);
continue;
}
promises.push(processMessage(message));
}
const messages = await Promise.all(promises);
this.checkVisionRequest(Object.values(this.message_file_map ?? {}).flat());
return messages;
}
}
module.exports = BaseClient;

View file

@ -13,8 +13,8 @@ const {
AuthKeys,
} = require('librechat-data-provider');
const { encodeAndFormat } = require('~/server/services/Files/images');
const { formatMessage, createContextHandlers } = require('./prompts');
const { getModelMaxTokens } = require('~/utils');
const { formatMessage } = require('./prompts');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
@ -124,24 +124,7 @@ class GoogleClient extends BaseClient {
// stop: modelOptions.stop // no stop method for now
};
/* Validation vision request */
this.defaultVisionModel = this.options.visionModel ?? 'gemini-pro-vision';
const availableModels = this.options.modelsConfig?.[EModelEndpoint.google];
this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
if (
this.options.attachments &&
availableModels?.includes(this.defaultVisionModel) &&
!this.isVisionModel
) {
this.modelOptions.model = this.defaultVisionModel;
this.isVisionModel = true;
}
if (this.isVisionModel && !this.options.attachments) {
this.modelOptions.model = 'gemini-pro';
this.isVisionModel = false;
}
this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments));
// TODO: as of 12/14/23, only gemini models are "Generative AI" models provided by Google
this.isGenerativeModel = this.modelOptions.model.includes('gemini');
@ -230,6 +213,33 @@ class GoogleClient extends BaseClient {
return this;
}
/**
*
* Checks if the model is a vision model based on request attachments and sets the appropriate options:
* @param {MongoFile[]} attachments
*/
checkVisionRequest(attachments) {
/* Validation vision request */
this.defaultVisionModel = this.options.visionModel ?? 'gemini-pro-vision';
const availableModels = this.options.modelsConfig?.[EModelEndpoint.google];
this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
if (
attachments &&
attachments.some((file) => file?.type && file?.type?.includes('image')) &&
availableModels?.includes(this.defaultVisionModel) &&
!this.isVisionModel
) {
this.modelOptions.model = this.defaultVisionModel;
this.isVisionModel = true;
}
if (this.isVisionModel && !attachments) {
this.modelOptions.model = 'gemini-pro';
this.isVisionModel = false;
}
}
formatMessages() {
return ((message) => ({
author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
@ -237,18 +247,45 @@ class GoogleClient extends BaseClient {
})).bind(this);
}
async buildVisionMessages(messages = [], parentMessageId) {
const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId);
const attachments = await this.options.attachments;
/**
*
* Adds image URLs to the message object and returns the files
*
* @param {TMessage[]} messages
* @param {MongoFile[]} files
* @returns {Promise<MongoFile[]>}
*/
async addImageURLs(message, attachments) {
const { files, image_urls } = await encodeAndFormat(
this.options.req,
attachments.filter((file) => file.type.includes('image')),
attachments,
EModelEndpoint.google,
);
message.image_urls = image_urls.length ? image_urls : undefined;
return files;
}
async buildVisionMessages(messages = [], parentMessageId) {
const attachments = await this.options.attachments;
const latestMessage = { ...messages[messages.length - 1] };
this.contextHandlers = createContextHandlers(this.options.req, latestMessage.text);
if (this.contextHandlers) {
for (const file of attachments) {
if (file.embedded) {
this.contextHandlers?.processFile(file);
continue;
}
}
this.augmentedPrompt = await this.contextHandlers.createContext();
this.options.promptPrefix = this.augmentedPrompt + this.options.promptPrefix;
}
const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId);
const files = await this.addImageURLs(latestMessage, attachments);
latestMessage.image_urls = image_urls;
this.options.attachments = files;
latestMessage.text = prompt;
@ -275,7 +312,7 @@ class GoogleClient extends BaseClient {
);
}
if (this.options.attachments) {
if (this.options.attachments && this.isGenerativeModel) {
return this.buildVisionMessages(messages, parentMessageId);
}

View file

@ -16,14 +16,13 @@ const {
getModelMaxTokens,
genAzureChatCompletion,
} = require('~/utils');
const { truncateText, formatMessage, createContextHandlers, CUT_OFF_PROMPT } = require('./prompts');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts');
const { handleOpenAIErrors } = require('./tools/util');
const spendTokens = require('~/models/spendTokens');
const { createLLM, RunManager } = require('./llm');
const ChatGPTClient = require('./ChatGPTClient');
const { isEnabled } = require('~/server/utils');
const { getFiles } = require('~/models/File');
const { summaryBuffer } = require('./memory');
const { runTitleChain } = require('./chains');
const { tokenSplit } = require('./document');
@ -92,7 +91,7 @@ class OpenAIClient extends BaseClient {
}
this.defaultVisionModel = this.options.visionModel ?? 'gpt-4-vision-preview';
this.checkVisionRequest(this.options.attachments);
this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments));
const { OPENROUTER_API_KEY, OPENAI_FORCE_PROMPT } = process.env ?? {};
if (OPENROUTER_API_KEY && !this.azure) {
@ -223,14 +222,19 @@ class OpenAIClient extends BaseClient {
* - Sets `this.modelOptions.model` to `gpt-4-vision-preview` if the request is a vision request.
* - Sets `this.isVisionModel` to `true` if vision request.
* - Deletes `this.modelOptions.stop` if vision request.
* @param {Array<Promise<MongoFile[]> | MongoFile[]> | Record<string, MongoFile[]>} attachments
* @param {MongoFile[]} attachments
*/
checkVisionRequest(attachments) {
const availableModels = this.options.modelsConfig?.[this.options.endpoint];
this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
const visionModelAvailable = availableModels?.includes(this.defaultVisionModel);
if (attachments && visionModelAvailable && !this.isVisionModel) {
if (
attachments &&
attachments.some((file) => file?.type && file?.type?.includes('image')) &&
visionModelAvailable &&
!this.isVisionModel
) {
this.modelOptions.model = this.defaultVisionModel;
this.isVisionModel = true;
}
@ -366,7 +370,7 @@ class OpenAIClient extends BaseClient {
return {
chatGptLabel: this.options.chatGptLabel,
promptPrefix: this.options.promptPrefix,
resendImages: this.options.resendImages,
resendFiles: this.options.resendFiles,
imageDetail: this.options.imageDetail,
...this.modelOptions,
};
@ -380,54 +384,6 @@ class OpenAIClient extends BaseClient {
};
}
/**
*
* @param {TMessage[]} _messages
* @returns {TMessage[]}
*/
async addPreviousAttachments(_messages) {
if (!this.options.resendImages) {
return _messages;
}
/**
*
* @param {TMessage} message
*/
const processMessage = async (message) => {
if (!this.message_file_map) {
/** @type {Record<string, MongoFile[]> */
this.message_file_map = {};
}
const fileIds = message.files.map((file) => file.file_id);
const files = await getFiles({
file_id: { $in: fileIds },
});
await this.addImageURLs(message, files);
this.message_file_map[message.messageId] = files;
return message;
};
const promises = [];
for (const message of _messages) {
if (!message.files) {
promises.push(message);
continue;
}
promises.push(processMessage(message));
}
const messages = await Promise.all(promises);
this.checkVisionRequest(this.message_file_map);
return messages;
}
/**
*
* Adds image URLs to the message object and returns the files
@ -438,8 +394,7 @@ class OpenAIClient extends BaseClient {
*/
async addImageURLs(message, attachments) {
const { files, image_urls } = await encodeAndFormat(this.options.req, attachments);
message.image_urls = image_urls;
message.image_urls = image_urls.length ? image_urls : undefined;
return files;
}
@ -467,23 +422,9 @@ class OpenAIClient extends BaseClient {
let promptTokens;
promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim();
if (promptPrefix) {
promptPrefix = `Instructions:\n${promptPrefix}`;
instructions = {
role: 'system',
name: 'instructions',
content: promptPrefix,
};
if (this.contextStrategy) {
instructions.tokenCount = this.getTokenCountForMessage(instructions);
}
}
if (this.options.attachments) {
const attachments = (await this.options.attachments).filter((file) =>
file.type.includes('image'),
);
const attachments = await this.options.attachments;
if (this.message_file_map) {
this.message_file_map[orderedMessages[orderedMessages.length - 1].messageId] = attachments;
@ -501,6 +442,13 @@ class OpenAIClient extends BaseClient {
this.options.attachments = files;
}
if (this.message_file_map) {
this.contextHandlers = createContextHandlers(
this.options.req,
orderedMessages[orderedMessages.length - 1].text,
);
}
const formattedMessages = orderedMessages.map((message, i) => {
const formattedMessage = formatMessage({
message,
@ -519,6 +467,11 @@ class OpenAIClient extends BaseClient {
if (this.message_file_map && this.message_file_map[message.messageId]) {
const attachments = this.message_file_map[message.messageId];
for (const file of attachments) {
if (file.embedded) {
this.contextHandlers?.processFile(file);
continue;
}
orderedMessages[i].tokenCount += this.calculateImageTokenCost({
width: file.width,
height: file.height,
@ -530,6 +483,24 @@ class OpenAIClient extends BaseClient {
return formattedMessage;
});
if (this.contextHandlers) {
this.augmentedPrompt = await this.contextHandlers.createContext();
promptPrefix = this.augmentedPrompt + promptPrefix;
}
if (promptPrefix) {
promptPrefix = `Instructions:\n${promptPrefix.trim()}`;
instructions = {
role: 'system',
name: 'instructions',
content: promptPrefix,
};
if (this.contextStrategy) {
instructions.tokenCount = this.getTokenCountForMessage(instructions);
}
}
// TODO: need to handle interleaving instructions better
if (this.contextStrategy) {
({ payload, tokenCountMap, promptTokens, messages } = await this.handleContextStrategy({

View file

@ -0,0 +1,119 @@
const axios = require('axios');
function createContextHandlers(req, userMessageContent) {
if (!process.env.RAG_API_URL) {
return;
}
const queryPromises = [];
const processedFiles = [];
const processedIds = new Set();
const jwtToken = req.headers.authorization.split(' ')[1];
const processFile = async (file) => {
if (file.embedded && !processedIds.has(file.file_id)) {
try {
const promise = axios.post(
`${process.env.RAG_API_URL}/query`,
{
file_id: file.file_id,
query: userMessageContent,
k: 4,
},
{
headers: {
Authorization: `Bearer ${jwtToken}`,
'Content-Type': 'application/json',
},
},
);
queryPromises.push(promise);
processedFiles.push(file);
processedIds.add(file.file_id);
} catch (error) {
console.error(`Error processing file ${file.filename}:`, error);
}
}
};
const createContext = async () => {
try {
if (!queryPromises.length || !processedFiles.length) {
return '';
}
const resolvedQueries = await Promise.all(queryPromises);
const context = resolvedQueries
.map((queryResult, index) => {
const file = processedFiles[index];
const contextItems = queryResult.data
.map((item) => {
const pageContent = item[0].page_content;
return `
<contextItem>
<![CDATA[${pageContent}]]>
</contextItem>
`;
})
.join('');
return `
<file>
<filename>${file.filename}</filename>
<context>
${contextItems}
</context>
</file>
`;
})
.join('');
const template = `The user has attached ${
processedFiles.length === 1 ? 'a' : processedFiles.length
} file${processedFiles.length !== 1 ? 's' : ''} to the conversation:
<files>
${processedFiles
.map(
(file) => `
<file>
<filename>${file.filename}</filename>
<type>${file.type}</type>
</file>
`,
)
.join('')}
</files>
A semantic search was executed with the user's message as the query, retrieving the following context inside <context></context> XML tags.
<context>
${context}
</context>
Use the context as your learned knowledge to better answer the user.
In your response, remember to follow these guidelines:
- If you don't know the answer, simply say that you don't know.
- If you are unsure how to answer, ask for clarification.
- Avoid mentioning that you obtained the information from the context.
Answer appropriately in the user's language.
`;
return template;
} catch (error) {
console.error('Error creating context:', error);
throw error; // Re-throw the error to propagate it to the caller
}
};
return {
processFile,
createContext,
};
}
module.exports = createContextHandlers;

View file

@ -4,6 +4,7 @@ const handleInputs = require('./handleInputs');
const instructions = require('./instructions');
const titlePrompts = require('./titlePrompts');
const truncateText = require('./truncateText');
const createContextHandlers = require('./createContextHandlers');
module.exports = {
...formatMessages,
@ -12,4 +13,5 @@ module.exports = {
...instructions,
...titlePrompts,
truncateText,
createContextHandlers,
};

View file

@ -69,7 +69,7 @@ const updateFileUsage = async (data) => {
const { file_id, inc = 1 } = data;
const updateOperation = {
$inc: { usage: inc },
$unset: { expiresAt: '' },
$unset: { expiresAt: '', temp_file_id: '' },
};
return await File.findOneAndUpdate({ file_id }, updateOperation, { new: true }).lean();
};

View file

@ -70,10 +70,14 @@ const conversationPreset = {
type: String,
},
file_ids: { type: [{ type: String }], default: undefined },
// vision
// deprecated
resendImages: {
type: Boolean,
},
// files
resendFiles: {
type: Boolean,
},
imageDetail: {
type: String,
},

View file

@ -15,6 +15,7 @@ const mongoose = require('mongoose');
* @property {'file'} object - Type of object, always 'file'
* @property {string} type - Type of file
* @property {number} usage - Number of uses of the file
* @property {boolean} [embedded] - Whether or not the file is embedded in vector db
* @property {string} [source] - The source of the file
* @property {number} [width] - Optional width of the file
* @property {number} [height] - Optional height of the file
@ -61,6 +62,9 @@ const fileSchema = mongoose.Schema(
required: true,
default: 'file',
},
embedded: {
type: Boolean,
},
type: {
type: String,
required: true,

View file

@ -1,10 +1,10 @@
const buildOptions = (endpoint, parsedBody) => {
const { modelLabel, promptPrefix, resendImages, ...rest } = parsedBody;
const { modelLabel, promptPrefix, resendFiles, ...rest } = parsedBody;
const endpointOption = {
endpoint,
modelLabel,
promptPrefix,
resendImages,
resendFiles,
modelOptions: {
...rest,
},

View file

@ -1,11 +1,11 @@
const buildOptions = (endpoint, parsedBody, endpointType) => {
const { chatGptLabel, promptPrefix, resendImages, imageDetail, ...rest } = parsedBody;
const { chatGptLabel, promptPrefix, resendFiles, imageDetail, ...rest } = parsedBody;
const endpointOption = {
endpoint,
endpointType,
chatGptLabel,
promptPrefix,
resendImages,
resendFiles,
imageDetail,
modelOptions: {
...rest,

View file

@ -1,10 +1,10 @@
const buildOptions = (endpoint, parsedBody) => {
const { chatGptLabel, promptPrefix, resendImages, imageDetail, ...rest } = parsedBody;
const { chatGptLabel, promptPrefix, resendFiles, imageDetail, ...rest } = parsedBody;
const endpointOption = {
endpoint,
chatGptLabel,
promptPrefix,
resendImages,
resendFiles,
imageDetail,
modelOptions: {
...rest,

View file

@ -1,3 +1,6 @@
const fs = require('fs');
const path = require('path');
const axios = require('axios');
const fetch = require('node-fetch');
const { ref, uploadBytes, getDownloadURL, deleteObject } = require('firebase/storage');
const { getBufferMetadata } = require('~/server/utils');
@ -160,6 +163,18 @@ function extractFirebaseFilePath(urlString) {
* Throws an error if there is an issue with deletion.
*/
const deleteFirebaseFile = async (req, file) => {
if (file.embedded && process.env.RAG_API_URL) {
const jwtToken = req.headers.authorization.split(' ')[1];
axios.delete(`${process.env.RAG_API_URL}/documents`, {
headers: {
Authorization: `Bearer ${jwtToken}`,
'Content-Type': 'application/json',
accept: 'application/json',
},
data: [file.file_id],
});
}
const fileName = extractFirebaseFilePath(file.filepath);
if (!fileName.includes(req.user.id)) {
throw new Error('Invalid file path');
@ -167,10 +182,41 @@ const deleteFirebaseFile = async (req, file) => {
await deleteFile('', fileName);
};
/**
* Uploads a file to Firebase Storage.
*
* @param {Object} params - The params object.
* @param {Express.Request} params.req - The request object from Express. It should have a `user` property with an `id`
* representing the user.
* @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should
* have a `path` property that points to the location of the uploaded file.
* @param {string} params.file_id - The file ID.
*
* @returns {Promise<{ filepath: string, bytes: number }>}
* A promise that resolves to an object containing:
* - filepath: The download URL of the uploaded file.
* - bytes: The size of the uploaded file in bytes.
*/
async function uploadFileToFirebase({ req, file, file_id }) {
const inputFilePath = file.path;
const inputBuffer = await fs.promises.readFile(inputFilePath);
const bytes = Buffer.byteLength(inputBuffer);
const userId = req.user.id;
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
const downloadURL = await saveBufferToFirebase({ userId, buffer: inputBuffer, fileName });
await fs.promises.unlink(inputFilePath);
return { filepath: downloadURL, bytes };
}
module.exports = {
deleteFile,
getFirebaseURL,
saveURLToFirebase,
deleteFirebaseFile,
uploadFileToFirebase,
saveBufferToFirebase,
};

View file

@ -188,7 +188,26 @@ const isValidPath = (req, base, subfolder, filepath) => {
* file path is invalid or if there is an error in deletion.
*/
const deleteLocalFile = async (req, file) => {
const { publicPath } = req.app.locals.paths;
const { publicPath, uploads } = req.app.locals.paths;
if (file.embedded && process.env.RAG_API_URL) {
const jwtToken = req.headers.authorization.split(' ')[1];
axios.delete(`${process.env.RAG_API_URL}/documents`, {
headers: {
Authorization: `Bearer ${jwtToken}`,
'Content-Type': 'application/json',
accept: 'application/json',
},
data: [file.file_id],
});
}
if (file.filepath.startsWith(`/uploads/${req.user.id}`)) {
const basePath = file.filepath.split('/uploads/')[1];
const filepath = path.join(uploads, basePath);
await fs.promises.unlink(filepath);
return;
}
const parts = file.filepath.split(path.sep);
const subfolder = parts[1];
const filepath = path.join(publicPath, file.filepath);
@ -200,6 +219,42 @@ const deleteLocalFile = async (req, file) => {
await fs.promises.unlink(filepath);
};
/**
* Uploads a file to the specified upload directory.
*
* @param {Object} params - The params object.
* @param {Object} params.req - The request object from Express. It should have a `user` property with an `id`
* representing the user, and an `app.locals.paths` object with an `uploads` path.
* @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should
* have a `path` property that points to the location of the uploaded file.
* @param {string} params.file_id - The file ID.
*
* @returns {Promise<{ filepath: string, bytes: number }>}
* A promise that resolves to an object containing:
* - filepath: The path where the file is saved.
* - bytes: The size of the file in bytes.
*/
async function uploadLocalFile({ req, file, file_id }) {
const inputFilePath = file.path;
const inputBuffer = await fs.promises.readFile(inputFilePath);
const bytes = Buffer.byteLength(inputBuffer);
const { uploads } = req.app.locals.paths;
const userPath = path.join(uploads, req.user.id);
if (!fs.existsSync(userPath)) {
fs.mkdirSync(userPath, { recursive: true });
}
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
const newPath = path.join(userPath, fileName);
await fs.promises.writeFile(newPath, inputBuffer);
const filepath = path.posix.join('/', 'uploads', req.user.id, path.basename(newPath));
return { filepath, bytes };
}
module.exports = {
saveLocalFile,
saveLocalImage,
@ -207,4 +262,5 @@ module.exports = {
saveFileFromURL,
getLocalFileURL,
deleteLocalFile,
uploadLocalFile,
};

View file

@ -6,13 +6,14 @@ const { logger } = require('~/config');
/**
* Uploads a file that can be used across various OpenAI services.
*
* @param {Express.Request} req - The request object from Express. It should have a `user` property with an `id`
* @param {Object} params - The params object.
* @param {Express.Request} params.req - The request object from Express. It should have a `user` property with an `id`
* representing the user, and an `app.locals.paths` object with an `imageOutput` path.
* @param {Express.Multer.File} file - The file uploaded to the server via multer.
* @param {OpenAIClient} openai - The initialized OpenAI client.
* @param {Express.Multer.File} params.file - The file uploaded to the server via multer.
* @param {OpenAIClient} params.openai - The initialized OpenAI client.
* @returns {Promise<OpenAIFile>}
*/
async function uploadOpenAIFile(req, file, openai) {
async function uploadOpenAIFile({ req, file, openai }) {
const uploadedFile = await openai.files.create({
file: fs.createReadStream(file.path),
purpose: FilePurpose.Assistants,

View file

@ -39,6 +39,11 @@ async function encodeAndFormat(req, files, endpoint) {
for (let file of files) {
const source = file.source ?? FileSources.local;
if (!file.height) {
promises.push([file, null]);
continue;
}
if (!encodingMethods[source]) {
const { prepareImagePayload } = getStrategyFunctions(source);
if (!prepareImagePayload) {
@ -70,6 +75,24 @@ async function encodeAndFormat(req, files, endpoint) {
};
for (const [file, imageContent] of formattedImages) {
const fileMetadata = {
type: file.type,
file_id: file.file_id,
filepath: file.filepath,
filename: file.filename,
embedded: !!file.embedded,
};
if (file.height && file.width) {
fileMetadata.height = file.height;
fileMetadata.width = file.width;
}
if (!imageContent) {
result.files.push(fileMetadata);
continue;
}
const imagePart = {
type: 'image_url',
image_url: {
@ -93,15 +116,7 @@ async function encodeAndFormat(req, files, endpoint) {
}
result.image_urls.push(imagePart);
result.files.push({
file_id: file.file_id,
// filepath: file.filepath,
// filename: file.filename,
// type: file.type,
// height: file.height,
// width: file.width,
});
result.files.push(fileMetadata);
}
return result;
}

View file

@ -1,5 +1,6 @@
const path = require('path');
const { v4 } = require('uuid');
const axios = require('axios');
const mime = require('mime/lite');
const {
isUUID,
@ -189,12 +190,14 @@ const processImageFile = async ({ req, res, file, metadata }) => {
const source = req.app.locals.fileStrategy;
const { handleImageUpload } = getStrategyFunctions(source);
const { file_id, temp_file_id, endpoint } = metadata;
const { filepath, bytes, width, height } = await handleImageUpload({
req,
file,
file_id,
endpoint,
});
const result = await createFile(
{
user: req.user.id,
@ -266,13 +269,46 @@ const processFileUpload = async ({ req, res, file, metadata }) => {
const { handleFileUpload } = getStrategyFunctions(source);
const { file_id, temp_file_id } = metadata;
let embedded = false;
if (process.env.RAG_API_URL) {
try {
const jwtToken = req.headers.authorization.split(' ')[1];
const filepath = `./uploads/temp/${file.path.split('uploads/temp/')[1]}`;
const response = await axios.post(
`${process.env.RAG_API_URL}/embed`,
{
filename: file.originalname,
file_content_type: file.mimetype,
filepath,
file_id,
},
{
headers: {
Authorization: `Bearer ${jwtToken}`,
'Content-Type': 'application/json',
},
},
);
if (response.status === 200) {
embedded = true;
}
} catch (error) {
logger.error('Error embedding file', error);
throw new Error(error);
}
} else if (!isAssistantUpload) {
logger.error('RAG_API_URL not set, cannot support process file upload');
throw new Error('RAG_API_URL not set, cannot support process file upload');
}
/** @type {OpenAI | undefined} */
let openai;
if (source === FileSources.openai) {
({ openai } = await initializeClient({ req }));
}
const { id, bytes, filename, filepath } = await handleFileUpload(req, file, openai);
const { id, bytes, filename, filepath } = await handleFileUpload({ req, file, file_id, openai });
if (isAssistantUpload && !metadata.message_file) {
await openai.beta.assistants.files.create(metadata.assistant_id, {
@ -289,8 +325,9 @@ const processFileUpload = async ({ req, res, file, metadata }) => {
filepath: isAssistantUpload ? `${openai.baseURL}/files/${id}` : filepath,
filename: filename ?? file.originalname,
context: isAssistantUpload ? FileContext.assistants : FileContext.message_attachment,
source,
type: file.mimetype,
embedded,
source,
},
true,
);

View file

@ -5,6 +5,7 @@ const {
saveURLToFirebase,
deleteFirebaseFile,
saveBufferToFirebase,
uploadFileToFirebase,
uploadImageToFirebase,
processFirebaseAvatar,
} = require('./Firebase');
@ -14,6 +15,7 @@ const {
saveFileFromURL,
saveLocalBuffer,
deleteLocalFile,
uploadLocalFile,
uploadLocalImage,
prepareImagesLocal,
processLocalAvatar,
@ -32,6 +34,7 @@ const firebaseStrategy = () => ({
saveBuffer: saveBufferToFirebase,
prepareImagePayload: prepareImageURL,
processAvatar: processFirebaseAvatar,
handleFileUpload: uploadFileToFirebase,
handleImageUpload: uploadImageToFirebase,
});
@ -46,6 +49,7 @@ const localStrategy = () => ({
saveBuffer: saveLocalBuffer,
deleteFile: deleteLocalFile,
processAvatar: processLocalAvatar,
handleFileUpload: uploadLocalFile,
handleImageUpload: uploadLocalImage,
prepareImagePayload: prepareImagesLocal,
});