📎 feat: Attachment Handling for v1/completions (#4205)

* refactor: add handling of attachments in v1/completions method

* ci: update OpenAIClient.test.js
This commit is contained in:
Danny Avila 2024-09-23 11:03:28 -04:00 committed by GitHub
parent 4328a25b6b
commit 17e59349ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 60 additions and 25 deletions

View file

@@ -1,19 +1,21 @@
const Keyv = require('keyv'); const Keyv = require('keyv');
const crypto = require('crypto'); const crypto = require('crypto');
const { CohereClient } = require('cohere-ai');
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const { const {
ImageDetail,
EModelEndpoint, EModelEndpoint,
resolveHeaders, resolveHeaders,
CohereConstants, CohereConstants,
mapModelToAzureConfig, mapModelToAzureConfig,
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const { CohereClient } = require('cohere-ai'); const { extractBaseURL, constructAzureURL, genAzureChatCompletion } = require('~/utils');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); const { createContextHandlers } = require('./prompts');
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
const { createCoherePayload } = require('./llm'); const { createCoherePayload } = require('./llm');
const { Agent, ProxyAgent } = require('undici'); const { Agent, ProxyAgent } = require('undici');
const BaseClient = require('./BaseClient'); const BaseClient = require('./BaseClient');
const { logger } = require('~/config'); const { logger } = require('~/config');
const { extractBaseURL, constructAzureURL, genAzureChatCompletion } = require('~/utils');
const CHATGPT_MODEL = 'gpt-3.5-turbo'; const CHATGPT_MODEL = 'gpt-3.5-turbo';
const tokenizersCache = {}; const tokenizersCache = {};
@@ -612,21 +614,66 @@ ${botMessage.message}
async buildPrompt(messages, { isChatGptModel = false, promptPrefix = null }) { async buildPrompt(messages, { isChatGptModel = false, promptPrefix = null }) {
promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim(); promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim();
// Handle attachments and create augmentedPrompt
if (this.options.attachments) {
const attachments = await this.options.attachments;
const lastMessage = messages[messages.length - 1];
if (this.message_file_map) {
this.message_file_map[lastMessage.messageId] = attachments;
} else {
this.message_file_map = {
[lastMessage.messageId]: attachments,
};
}
const files = await this.addImageURLs(lastMessage, attachments);
this.options.attachments = files;
this.contextHandlers = createContextHandlers(this.options.req, lastMessage.text);
}
if (this.message_file_map) {
this.contextHandlers = createContextHandlers(
this.options.req,
messages[messages.length - 1].text,
);
}
// Calculate image token cost and process embedded files
messages.forEach((message, i) => {
if (this.message_file_map && this.message_file_map[message.messageId]) {
const attachments = this.message_file_map[message.messageId];
for (const file of attachments) {
if (file.embedded) {
this.contextHandlers?.processFile(file);
continue;
}
messages[i].tokenCount =
(messages[i].tokenCount || 0) +
this.calculateImageTokenCost({
width: file.width,
height: file.height,
detail: this.options.imageDetail ?? ImageDetail.auto,
});
}
}
});
if (this.contextHandlers) {
this.augmentedPrompt = await this.contextHandlers.createContext();
promptPrefix = this.augmentedPrompt + promptPrefix;
}
if (promptPrefix) { if (promptPrefix) {
// If the prompt prefix doesn't end with the end token, add it. // If the prompt prefix doesn't end with the end token, add it.
if (!promptPrefix.endsWith(`${this.endToken}`)) { if (!promptPrefix.endsWith(`${this.endToken}`)) {
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`; promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
} }
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`; promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
} else {
const currentDateString = new Date().toLocaleDateString('en-us', {
year: 'numeric',
month: 'long',
day: 'numeric',
});
promptPrefix = `${this.startToken}Instructions:\nYou are ChatGPT, a large language model trained by OpenAI. Respond conversationally.\nCurrent date: ${currentDateString}${this.endToken}\n\n`;
} }
const promptSuffix = `${this.startToken}${this.chatGptLabel}:\n`; // Prompt ChatGPT to respond. const promptSuffix = `${this.startToken}${this.chatGptLabel}:\n`; // Prompt ChatGPT to respond.
const instructionsPayload = { const instructionsPayload = {
@@ -714,10 +761,6 @@ ${botMessage.message}
this.maxResponseTokens, this.maxResponseTokens,
); );
if (this.options.debug) {
console.debug(`Prompt : ${prompt}`);
}
if (isChatGptModel) { if (isChatGptModel) {
return { prompt: [instructionsPayload, messagePayload], context }; return { prompt: [instructionsPayload, messagePayload], context };
} }

View file

@@ -611,15 +611,7 @@ describe('OpenAIClient', () => {
expect(getCompletion).toHaveBeenCalled(); expect(getCompletion).toHaveBeenCalled();
expect(getCompletion.mock.calls.length).toBe(1); expect(getCompletion.mock.calls.length).toBe(1);
const currentDateString = new Date().toLocaleDateString('en-us', { expect(getCompletion.mock.calls[0][0]).toBe('||>User:\nHi mom!\n||>Assistant:\n');
year: 'numeric',
month: 'long',
day: 'numeric',
});
expect(getCompletion.mock.calls[0][0]).toBe(
`||>Instructions:\nYou are ChatGPT, a large language model trained by OpenAI. Respond conversationally.\nCurrent date: ${currentDateString}\n\n||>User:\nHi mom!\n||>Assistant:\n`,
);
expect(fetchEventSource).toHaveBeenCalled(); expect(fetchEventSource).toHaveBeenCalled();
expect(fetchEventSource.mock.calls.length).toBe(1); expect(fetchEventSource.mock.calls.length).toBe(1);