mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-01-01 08:08:49 +01:00
🗃️ feat: General File Support for OpenAI, Azure, Custom, Anthropic and Google (RAG) (#2143)
* refactor: re-purpose `resendImages` as `resendFiles` * refactor: re-purpose `resendImages` as `resendFiles` * feat: upload general files * feat: embed file during upload * feat: delete file embeddings on file deletion * chore(fileConfig): add epub+zip type * feat(encodeAndFormat): handle non-image files * feat(createContextHandlers): build context prompt from file attachments and successful RAG * fix: prevent non-temp files as well as embedded files to be deleted on new conversation * fix: remove temp_file_id on usage, prevent non-temp files as well as embedded files to be deleted on new conversation * fix: prevent non-temp files as well as embedded files to be deleted on new conversation * feat(OpenAI/Anthropic/Google): basic RAG support * fix: delete `resendFiles` only when true (Default) * refactor(RAG): update endpoints and pass JWT * fix(resendFiles): default values * fix(context/processFile): query unique ids only * feat: rag-api.yaml * feat: file upload improved ux for longer uploads * chore: await embed call and catch embedding errors * refactor: store augmentedPrompt in Client * refactor(processFileUpload): throw error if not assistant file upload * fix(useFileHandling): handle markdown empty mimetype issue * chore: necessary compose file changes
This commit is contained in:
parent
af347cccde
commit
f7761df52c
38 changed files with 683 additions and 261 deletions
|
|
@ -6,10 +6,9 @@ const {
|
|||
validateVisionModel,
|
||||
} = require('librechat-data-provider');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { formatMessage, createContextHandlers } = require('./prompts');
|
||||
const spendTokens = require('~/models/spendTokens');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { formatMessage } = require('./prompts');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
|
|
@ -67,7 +66,7 @@ class AnthropicClient extends BaseClient {
|
|||
this.useMessages = this.isClaude3 || !!this.options.attachments;
|
||||
|
||||
this.defaultVisionModel = this.options.visionModel ?? 'claude-3-sonnet-20240229';
|
||||
this.checkVisionRequest(this.options.attachments);
|
||||
this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments));
|
||||
|
||||
this.maxContextTokens =
|
||||
getModelMaxTokens(this.modelOptions.model, EModelEndpoint.anthropic) ?? 100000;
|
||||
|
|
@ -134,14 +133,19 @@ class AnthropicClient extends BaseClient {
|
|||
* - Sets `this.modelOptions.model` to `gpt-4-vision-preview` if the request is a vision request.
|
||||
* - Sets `this.isVisionModel` to `true` if vision request.
|
||||
* - Deletes `this.modelOptions.stop` if vision request.
|
||||
* @param {Array<Promise<MongoFile[]> | MongoFile[]> | Record<string, MongoFile[]>} attachments
|
||||
* @param {MongoFile[]} attachments
|
||||
*/
|
||||
checkVisionRequest(attachments) {
|
||||
const availableModels = this.options.modelsConfig?.[EModelEndpoint.anthropic];
|
||||
this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
|
||||
|
||||
const visionModelAvailable = availableModels?.includes(this.defaultVisionModel);
|
||||
if (attachments && visionModelAvailable && !this.isVisionModel) {
|
||||
if (
|
||||
attachments &&
|
||||
attachments.some((file) => file?.type && file?.type?.includes('image')) &&
|
||||
visionModelAvailable &&
|
||||
!this.isVisionModel
|
||||
) {
|
||||
this.modelOptions.model = this.defaultVisionModel;
|
||||
this.isVisionModel = true;
|
||||
}
|
||||
|
|
@ -168,7 +172,7 @@ class AnthropicClient extends BaseClient {
|
|||
attachments,
|
||||
EModelEndpoint.anthropic,
|
||||
);
|
||||
message.image_urls = image_urls;
|
||||
message.image_urls = image_urls.length ? image_urls : undefined;
|
||||
return files;
|
||||
}
|
||||
|
||||
|
|
@ -186,54 +190,6 @@ class AnthropicClient extends BaseClient {
|
|||
);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {TMessage[]} _messages
|
||||
* @returns {TMessage[]}
|
||||
*/
|
||||
async addPreviousAttachments(_messages) {
|
||||
if (!this.options.resendImages) {
|
||||
return _messages;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {TMessage} message
|
||||
*/
|
||||
const processMessage = async (message) => {
|
||||
if (!this.message_file_map) {
|
||||
/** @type {Record<string, MongoFile[]> */
|
||||
this.message_file_map = {};
|
||||
}
|
||||
|
||||
const fileIds = message.files.map((file) => file.file_id);
|
||||
const files = await getFiles({
|
||||
file_id: { $in: fileIds },
|
||||
});
|
||||
|
||||
await this.addImageURLs(message, files);
|
||||
|
||||
this.message_file_map[message.messageId] = files;
|
||||
return message;
|
||||
};
|
||||
|
||||
const promises = [];
|
||||
|
||||
for (const message of _messages) {
|
||||
if (!message.files) {
|
||||
promises.push(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
promises.push(processMessage(message));
|
||||
}
|
||||
|
||||
const messages = await Promise.all(promises);
|
||||
|
||||
this.checkVisionRequest(this.message_file_map);
|
||||
return messages;
|
||||
}
|
||||
|
||||
async buildMessages(messages, parentMessageId) {
|
||||
const orderedMessages = this.constructor.getMessagesForConversation({
|
||||
messages,
|
||||
|
|
@ -242,12 +198,13 @@ class AnthropicClient extends BaseClient {
|
|||
|
||||
logger.debug('[AnthropicClient] orderedMessages', { orderedMessages, parentMessageId });
|
||||
|
||||
if (!this.isVisionModel && this.options.attachments) {
|
||||
throw new Error('Attachments are only supported with the Claude 3 family of models');
|
||||
} else if (this.options.attachments) {
|
||||
const attachments = (await this.options.attachments).filter((file) =>
|
||||
file.type.includes('image'),
|
||||
);
|
||||
if (this.options.attachments) {
|
||||
const attachments = await this.options.attachments;
|
||||
const images = attachments.filter((file) => file.type.includes('image'));
|
||||
|
||||
if (images.length && !this.isVisionModel) {
|
||||
throw new Error('Images are only supported with the Claude 3 family of models');
|
||||
}
|
||||
|
||||
const latestMessage = orderedMessages[orderedMessages.length - 1];
|
||||
|
||||
|
|
@ -264,6 +221,13 @@ class AnthropicClient extends BaseClient {
|
|||
this.options.attachments = files;
|
||||
}
|
||||
|
||||
if (this.message_file_map) {
|
||||
this.contextHandlers = createContextHandlers(
|
||||
this.options.req,
|
||||
orderedMessages[orderedMessages.length - 1].text,
|
||||
);
|
||||
}
|
||||
|
||||
const formattedMessages = orderedMessages.map((message, i) => {
|
||||
const formattedMessage = this.useMessages
|
||||
? formatMessage({
|
||||
|
|
@ -285,6 +249,11 @@ class AnthropicClient extends BaseClient {
|
|||
if (this.message_file_map && this.message_file_map[message.messageId]) {
|
||||
const attachments = this.message_file_map[message.messageId];
|
||||
for (const file of attachments) {
|
||||
if (file.embedded) {
|
||||
this.contextHandlers?.processFile(file);
|
||||
continue;
|
||||
}
|
||||
|
||||
orderedMessages[i].tokenCount += this.calculateImageTokenCost({
|
||||
width: file.width,
|
||||
height: file.height,
|
||||
|
|
@ -296,6 +265,11 @@ class AnthropicClient extends BaseClient {
|
|||
return formattedMessage;
|
||||
});
|
||||
|
||||
if (this.contextHandlers) {
|
||||
this.augmentedPrompt = await this.contextHandlers.createContext();
|
||||
this.options.promptPrefix = this.augmentedPrompt + (this.options.promptPrefix ?? '');
|
||||
}
|
||||
|
||||
let { context: messagesInWindow, remainingContextTokens } =
|
||||
await this.getMessagesWithinTokenLimit(formattedMessages);
|
||||
|
||||
|
|
@ -389,7 +363,7 @@ class AnthropicClient extends BaseClient {
|
|||
let isEdited = lastAuthor === this.assistantLabel;
|
||||
const promptSuffix = isEdited ? '' : `${promptPrefix}${this.assistantLabel}\n`;
|
||||
let currentTokenCount =
|
||||
isEdited || this.useMEssages
|
||||
isEdited || this.useMessages
|
||||
? this.getTokenCount(promptPrefix)
|
||||
: this.getTokenCount(promptSuffix);
|
||||
|
||||
|
|
@ -663,6 +637,7 @@ class AnthropicClient extends BaseClient {
|
|||
return {
|
||||
promptPrefix: this.options.promptPrefix,
|
||||
modelLabel: this.options.modelLabel,
|
||||
resendFiles: this.options.resendFiles,
|
||||
...this.modelOptions,
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
|
|||
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
|
||||
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
|
||||
const checkBalance = require('~/models/checkBalance');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const TextStream = require('./TextStream');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
|
|
@ -46,10 +47,6 @@ class BaseClient {
|
|||
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', response);
|
||||
}
|
||||
|
||||
async addPreviousAttachments(messages) {
|
||||
return messages;
|
||||
}
|
||||
|
||||
async recordTokenUsage({ promptTokens, completionTokens }) {
|
||||
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', {
|
||||
promptTokens,
|
||||
|
|
@ -683,6 +680,54 @@ class BaseClient {
|
|||
|
||||
return await this.sendCompletion(payload, opts);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {TMessage[]} _messages
|
||||
* @returns {Promise<TMessage[]>}
|
||||
*/
|
||||
async addPreviousAttachments(_messages) {
|
||||
if (!this.options.resendFiles) {
|
||||
return _messages;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {TMessage} message
|
||||
*/
|
||||
const processMessage = async (message) => {
|
||||
if (!this.message_file_map) {
|
||||
/** @type {Record<string, MongoFile[]> */
|
||||
this.message_file_map = {};
|
||||
}
|
||||
|
||||
const fileIds = message.files.map((file) => file.file_id);
|
||||
const files = await getFiles({
|
||||
file_id: { $in: fileIds },
|
||||
});
|
||||
|
||||
await this.addImageURLs(message, files);
|
||||
|
||||
this.message_file_map[message.messageId] = files;
|
||||
return message;
|
||||
};
|
||||
|
||||
const promises = [];
|
||||
|
||||
for (const message of _messages) {
|
||||
if (!message.files) {
|
||||
promises.push(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
promises.push(processMessage(message));
|
||||
}
|
||||
|
||||
const messages = await Promise.all(promises);
|
||||
|
||||
this.checkVisionRequest(Object.values(this.message_file_map ?? {}).flat());
|
||||
return messages;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = BaseClient;
|
||||
|
|
|
|||
|
|
@ -13,8 +13,8 @@ const {
|
|||
AuthKeys,
|
||||
} = require('librechat-data-provider');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images');
|
||||
const { formatMessage, createContextHandlers } = require('./prompts');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { formatMessage } = require('./prompts');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
|
|
@ -124,24 +124,7 @@ class GoogleClient extends BaseClient {
|
|||
// stop: modelOptions.stop // no stop method for now
|
||||
};
|
||||
|
||||
/* Validation vision request */
|
||||
this.defaultVisionModel = this.options.visionModel ?? 'gemini-pro-vision';
|
||||
const availableModels = this.options.modelsConfig?.[EModelEndpoint.google];
|
||||
this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
|
||||
|
||||
if (
|
||||
this.options.attachments &&
|
||||
availableModels?.includes(this.defaultVisionModel) &&
|
||||
!this.isVisionModel
|
||||
) {
|
||||
this.modelOptions.model = this.defaultVisionModel;
|
||||
this.isVisionModel = true;
|
||||
}
|
||||
|
||||
if (this.isVisionModel && !this.options.attachments) {
|
||||
this.modelOptions.model = 'gemini-pro';
|
||||
this.isVisionModel = false;
|
||||
}
|
||||
this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments));
|
||||
|
||||
// TODO: as of 12/14/23, only gemini models are "Generative AI" models provided by Google
|
||||
this.isGenerativeModel = this.modelOptions.model.includes('gemini');
|
||||
|
|
@ -230,6 +213,33 @@ class GoogleClient extends BaseClient {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* Checks if the model is a vision model based on request attachments and sets the appropriate options:
|
||||
* @param {MongoFile[]} attachments
|
||||
*/
|
||||
checkVisionRequest(attachments) {
|
||||
/* Validation vision request */
|
||||
this.defaultVisionModel = this.options.visionModel ?? 'gemini-pro-vision';
|
||||
const availableModels = this.options.modelsConfig?.[EModelEndpoint.google];
|
||||
this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
|
||||
|
||||
if (
|
||||
attachments &&
|
||||
attachments.some((file) => file?.type && file?.type?.includes('image')) &&
|
||||
availableModels?.includes(this.defaultVisionModel) &&
|
||||
!this.isVisionModel
|
||||
) {
|
||||
this.modelOptions.model = this.defaultVisionModel;
|
||||
this.isVisionModel = true;
|
||||
}
|
||||
|
||||
if (this.isVisionModel && !attachments) {
|
||||
this.modelOptions.model = 'gemini-pro';
|
||||
this.isVisionModel = false;
|
||||
}
|
||||
}
|
||||
|
||||
formatMessages() {
|
||||
return ((message) => ({
|
||||
author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
|
||||
|
|
@ -237,18 +247,45 @@ class GoogleClient extends BaseClient {
|
|||
})).bind(this);
|
||||
}
|
||||
|
||||
async buildVisionMessages(messages = [], parentMessageId) {
|
||||
const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId);
|
||||
const attachments = await this.options.attachments;
|
||||
/**
|
||||
*
|
||||
* Adds image URLs to the message object and returns the files
|
||||
*
|
||||
* @param {TMessage[]} messages
|
||||
* @param {MongoFile[]} files
|
||||
* @returns {Promise<MongoFile[]>}
|
||||
*/
|
||||
async addImageURLs(message, attachments) {
|
||||
const { files, image_urls } = await encodeAndFormat(
|
||||
this.options.req,
|
||||
attachments.filter((file) => file.type.includes('image')),
|
||||
attachments,
|
||||
EModelEndpoint.google,
|
||||
);
|
||||
message.image_urls = image_urls.length ? image_urls : undefined;
|
||||
return files;
|
||||
}
|
||||
|
||||
async buildVisionMessages(messages = [], parentMessageId) {
|
||||
const attachments = await this.options.attachments;
|
||||
const latestMessage = { ...messages[messages.length - 1] };
|
||||
this.contextHandlers = createContextHandlers(this.options.req, latestMessage.text);
|
||||
|
||||
if (this.contextHandlers) {
|
||||
for (const file of attachments) {
|
||||
if (file.embedded) {
|
||||
this.contextHandlers?.processFile(file);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
this.augmentedPrompt = await this.contextHandlers.createContext();
|
||||
this.options.promptPrefix = this.augmentedPrompt + this.options.promptPrefix;
|
||||
}
|
||||
|
||||
const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId);
|
||||
|
||||
const files = await this.addImageURLs(latestMessage, attachments);
|
||||
|
||||
latestMessage.image_urls = image_urls;
|
||||
this.options.attachments = files;
|
||||
|
||||
latestMessage.text = prompt;
|
||||
|
|
@ -275,7 +312,7 @@ class GoogleClient extends BaseClient {
|
|||
);
|
||||
}
|
||||
|
||||
if (this.options.attachments) {
|
||||
if (this.options.attachments && this.isGenerativeModel) {
|
||||
return this.buildVisionMessages(messages, parentMessageId);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,14 +16,13 @@ const {
|
|||
getModelMaxTokens,
|
||||
genAzureChatCompletion,
|
||||
} = require('~/utils');
|
||||
const { truncateText, formatMessage, createContextHandlers, CUT_OFF_PROMPT } = require('./prompts');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts');
|
||||
const { handleOpenAIErrors } = require('./tools/util');
|
||||
const spendTokens = require('~/models/spendTokens');
|
||||
const { createLLM, RunManager } = require('./llm');
|
||||
const ChatGPTClient = require('./ChatGPTClient');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const { summaryBuffer } = require('./memory');
|
||||
const { runTitleChain } = require('./chains');
|
||||
const { tokenSplit } = require('./document');
|
||||
|
|
@ -92,7 +91,7 @@ class OpenAIClient extends BaseClient {
|
|||
}
|
||||
|
||||
this.defaultVisionModel = this.options.visionModel ?? 'gpt-4-vision-preview';
|
||||
this.checkVisionRequest(this.options.attachments);
|
||||
this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments));
|
||||
|
||||
const { OPENROUTER_API_KEY, OPENAI_FORCE_PROMPT } = process.env ?? {};
|
||||
if (OPENROUTER_API_KEY && !this.azure) {
|
||||
|
|
@ -223,14 +222,19 @@ class OpenAIClient extends BaseClient {
|
|||
* - Sets `this.modelOptions.model` to `gpt-4-vision-preview` if the request is a vision request.
|
||||
* - Sets `this.isVisionModel` to `true` if vision request.
|
||||
* - Deletes `this.modelOptions.stop` if vision request.
|
||||
* @param {Array<Promise<MongoFile[]> | MongoFile[]> | Record<string, MongoFile[]>} attachments
|
||||
* @param {MongoFile[]} attachments
|
||||
*/
|
||||
checkVisionRequest(attachments) {
|
||||
const availableModels = this.options.modelsConfig?.[this.options.endpoint];
|
||||
this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
|
||||
|
||||
const visionModelAvailable = availableModels?.includes(this.defaultVisionModel);
|
||||
if (attachments && visionModelAvailable && !this.isVisionModel) {
|
||||
if (
|
||||
attachments &&
|
||||
attachments.some((file) => file?.type && file?.type?.includes('image')) &&
|
||||
visionModelAvailable &&
|
||||
!this.isVisionModel
|
||||
) {
|
||||
this.modelOptions.model = this.defaultVisionModel;
|
||||
this.isVisionModel = true;
|
||||
}
|
||||
|
|
@ -366,7 +370,7 @@ class OpenAIClient extends BaseClient {
|
|||
return {
|
||||
chatGptLabel: this.options.chatGptLabel,
|
||||
promptPrefix: this.options.promptPrefix,
|
||||
resendImages: this.options.resendImages,
|
||||
resendFiles: this.options.resendFiles,
|
||||
imageDetail: this.options.imageDetail,
|
||||
...this.modelOptions,
|
||||
};
|
||||
|
|
@ -380,54 +384,6 @@ class OpenAIClient extends BaseClient {
|
|||
};
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {TMessage[]} _messages
|
||||
* @returns {TMessage[]}
|
||||
*/
|
||||
async addPreviousAttachments(_messages) {
|
||||
if (!this.options.resendImages) {
|
||||
return _messages;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {TMessage} message
|
||||
*/
|
||||
const processMessage = async (message) => {
|
||||
if (!this.message_file_map) {
|
||||
/** @type {Record<string, MongoFile[]> */
|
||||
this.message_file_map = {};
|
||||
}
|
||||
|
||||
const fileIds = message.files.map((file) => file.file_id);
|
||||
const files = await getFiles({
|
||||
file_id: { $in: fileIds },
|
||||
});
|
||||
|
||||
await this.addImageURLs(message, files);
|
||||
|
||||
this.message_file_map[message.messageId] = files;
|
||||
return message;
|
||||
};
|
||||
|
||||
const promises = [];
|
||||
|
||||
for (const message of _messages) {
|
||||
if (!message.files) {
|
||||
promises.push(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
promises.push(processMessage(message));
|
||||
}
|
||||
|
||||
const messages = await Promise.all(promises);
|
||||
|
||||
this.checkVisionRequest(this.message_file_map);
|
||||
return messages;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* Adds image URLs to the message object and returns the files
|
||||
|
|
@ -438,8 +394,7 @@ class OpenAIClient extends BaseClient {
|
|||
*/
|
||||
async addImageURLs(message, attachments) {
|
||||
const { files, image_urls } = await encodeAndFormat(this.options.req, attachments);
|
||||
|
||||
message.image_urls = image_urls;
|
||||
message.image_urls = image_urls.length ? image_urls : undefined;
|
||||
return files;
|
||||
}
|
||||
|
||||
|
|
@ -467,23 +422,9 @@ class OpenAIClient extends BaseClient {
|
|||
let promptTokens;
|
||||
|
||||
promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim();
|
||||
if (promptPrefix) {
|
||||
promptPrefix = `Instructions:\n${promptPrefix}`;
|
||||
instructions = {
|
||||
role: 'system',
|
||||
name: 'instructions',
|
||||
content: promptPrefix,
|
||||
};
|
||||
|
||||
if (this.contextStrategy) {
|
||||
instructions.tokenCount = this.getTokenCountForMessage(instructions);
|
||||
}
|
||||
}
|
||||
|
||||
if (this.options.attachments) {
|
||||
const attachments = (await this.options.attachments).filter((file) =>
|
||||
file.type.includes('image'),
|
||||
);
|
||||
const attachments = await this.options.attachments;
|
||||
|
||||
if (this.message_file_map) {
|
||||
this.message_file_map[orderedMessages[orderedMessages.length - 1].messageId] = attachments;
|
||||
|
|
@ -501,6 +442,13 @@ class OpenAIClient extends BaseClient {
|
|||
this.options.attachments = files;
|
||||
}
|
||||
|
||||
if (this.message_file_map) {
|
||||
this.contextHandlers = createContextHandlers(
|
||||
this.options.req,
|
||||
orderedMessages[orderedMessages.length - 1].text,
|
||||
);
|
||||
}
|
||||
|
||||
const formattedMessages = orderedMessages.map((message, i) => {
|
||||
const formattedMessage = formatMessage({
|
||||
message,
|
||||
|
|
@ -519,6 +467,11 @@ class OpenAIClient extends BaseClient {
|
|||
if (this.message_file_map && this.message_file_map[message.messageId]) {
|
||||
const attachments = this.message_file_map[message.messageId];
|
||||
for (const file of attachments) {
|
||||
if (file.embedded) {
|
||||
this.contextHandlers?.processFile(file);
|
||||
continue;
|
||||
}
|
||||
|
||||
orderedMessages[i].tokenCount += this.calculateImageTokenCost({
|
||||
width: file.width,
|
||||
height: file.height,
|
||||
|
|
@ -530,6 +483,24 @@ class OpenAIClient extends BaseClient {
|
|||
return formattedMessage;
|
||||
});
|
||||
|
||||
if (this.contextHandlers) {
|
||||
this.augmentedPrompt = await this.contextHandlers.createContext();
|
||||
promptPrefix = this.augmentedPrompt + promptPrefix;
|
||||
}
|
||||
|
||||
if (promptPrefix) {
|
||||
promptPrefix = `Instructions:\n${promptPrefix.trim()}`;
|
||||
instructions = {
|
||||
role: 'system',
|
||||
name: 'instructions',
|
||||
content: promptPrefix,
|
||||
};
|
||||
|
||||
if (this.contextStrategy) {
|
||||
instructions.tokenCount = this.getTokenCountForMessage(instructions);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: need to handle interleaving instructions better
|
||||
if (this.contextStrategy) {
|
||||
({ payload, tokenCountMap, promptTokens, messages } = await this.handleContextStrategy({
|
||||
|
|
|
|||
119
api/app/clients/prompts/createContextHandlers.js
Normal file
119
api/app/clients/prompts/createContextHandlers.js
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
const axios = require('axios');
|
||||
|
||||
function createContextHandlers(req, userMessageContent) {
|
||||
if (!process.env.RAG_API_URL) {
|
||||
return;
|
||||
}
|
||||
|
||||
const queryPromises = [];
|
||||
const processedFiles = [];
|
||||
const processedIds = new Set();
|
||||
const jwtToken = req.headers.authorization.split(' ')[1];
|
||||
|
||||
const processFile = async (file) => {
|
||||
if (file.embedded && !processedIds.has(file.file_id)) {
|
||||
try {
|
||||
const promise = axios.post(
|
||||
`${process.env.RAG_API_URL}/query`,
|
||||
{
|
||||
file_id: file.file_id,
|
||||
query: userMessageContent,
|
||||
k: 4,
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
Authorization: `Bearer ${jwtToken}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
queryPromises.push(promise);
|
||||
processedFiles.push(file);
|
||||
processedIds.add(file.file_id);
|
||||
} catch (error) {
|
||||
console.error(`Error processing file ${file.filename}:`, error);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const createContext = async () => {
|
||||
try {
|
||||
if (!queryPromises.length || !processedFiles.length) {
|
||||
return '';
|
||||
}
|
||||
|
||||
const resolvedQueries = await Promise.all(queryPromises);
|
||||
|
||||
const context = resolvedQueries
|
||||
.map((queryResult, index) => {
|
||||
const file = processedFiles[index];
|
||||
const contextItems = queryResult.data
|
||||
.map((item) => {
|
||||
const pageContent = item[0].page_content;
|
||||
return `
|
||||
<contextItem>
|
||||
<![CDATA[${pageContent}]]>
|
||||
</contextItem>
|
||||
`;
|
||||
})
|
||||
.join('');
|
||||
|
||||
return `
|
||||
<file>
|
||||
<filename>${file.filename}</filename>
|
||||
<context>
|
||||
${contextItems}
|
||||
</context>
|
||||
</file>
|
||||
`;
|
||||
})
|
||||
.join('');
|
||||
|
||||
const template = `The user has attached ${
|
||||
processedFiles.length === 1 ? 'a' : processedFiles.length
|
||||
} file${processedFiles.length !== 1 ? 's' : ''} to the conversation:
|
||||
|
||||
<files>
|
||||
${processedFiles
|
||||
.map(
|
||||
(file) => `
|
||||
<file>
|
||||
<filename>${file.filename}</filename>
|
||||
<type>${file.type}</type>
|
||||
</file>
|
||||
`,
|
||||
)
|
||||
.join('')}
|
||||
</files>
|
||||
|
||||
A semantic search was executed with the user's message as the query, retrieving the following context inside <context></context> XML tags.
|
||||
|
||||
<context>
|
||||
${context}
|
||||
</context>
|
||||
|
||||
Use the context as your learned knowledge to better answer the user.
|
||||
|
||||
In your response, remember to follow these guidelines:
|
||||
- If you don't know the answer, simply say that you don't know.
|
||||
- If you are unsure how to answer, ask for clarification.
|
||||
- Avoid mentioning that you obtained the information from the context.
|
||||
|
||||
Answer appropriately in the user's language.
|
||||
`;
|
||||
|
||||
return template;
|
||||
} catch (error) {
|
||||
console.error('Error creating context:', error);
|
||||
throw error; // Re-throw the error to propagate it to the caller
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
processFile,
|
||||
createContext,
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = createContextHandlers;
|
||||
|
|
@ -4,6 +4,7 @@ const handleInputs = require('./handleInputs');
|
|||
const instructions = require('./instructions');
|
||||
const titlePrompts = require('./titlePrompts');
|
||||
const truncateText = require('./truncateText');
|
||||
const createContextHandlers = require('./createContextHandlers');
|
||||
|
||||
module.exports = {
|
||||
...formatMessages,
|
||||
|
|
@ -12,4 +13,5 @@ module.exports = {
|
|||
...instructions,
|
||||
...titlePrompts,
|
||||
truncateText,
|
||||
createContextHandlers,
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue