mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-21 02:40:14 +01:00
💾 feat: Anthropic Prompt Caching (#3670)
* wip: initial cache control implementation, add typing for transactions handling * feat: first pass of Anthropic Prompt Caching * feat: standardize stream usage as pass in when calculating token counts * feat: Add getCacheMultiplier function to calculate cache multiplier for different valueKeys and cacheTypes * chore: imports order * refactor: token usage recording in AnthropicClient, no need to "correct" as we have the correct amount * feat: more accurate token counting using stream usage data * feat: Improve token counting accuracy with stream usage data * refactor: ensure more accurate than not token estimations if custom instructions or files are not being resent with every request * refactor: cleanup updateUserMessageTokenCount to allow transactions to be as accurate as possible even if we shouldn't update user message token counts * ci: fix tests
This commit is contained in:
parent
9f4c516615
commit
a45b384bbc
17 changed files with 973 additions and 34 deletions
|
|
@ -54,10 +54,22 @@ class BaseClient {
|
|||
throw new Error('Subclasses attempted to call summarizeMessages without implementing it');
|
||||
}
|
||||
|
||||
async getTokenCountForResponse(response) {
|
||||
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', response);
|
||||
/**
|
||||
* Abstract method to get the token count for a message. Subclasses must implement this method.
|
||||
* @param {TMessage} responseMessage
|
||||
* @returns {number}
|
||||
*/
|
||||
getTokenCountForResponse(responseMessage) {
|
||||
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', responseMessage);
|
||||
}
|
||||
|
||||
/**
|
||||
* Abstract method to record token usage. Subclasses must implement this method.
|
||||
* If a correction to the token usage is needed, the method should return an object with the corrected token counts.
|
||||
* @param {number} promptTokens
|
||||
* @param {number} completionTokens
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async recordTokenUsage({ promptTokens, completionTokens }) {
|
||||
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', {
|
||||
promptTokens,
|
||||
|
|
@ -536,13 +548,31 @@ class BaseClient {
|
|||
this.getTokenCountForResponse &&
|
||||
this.getTokenCount
|
||||
) {
|
||||
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
|
||||
const completionTokens = this.getTokenCount(completion);
|
||||
await this.recordTokenUsage({ promptTokens, completionTokens });
|
||||
let completionTokens;
|
||||
|
||||
/**
|
||||
* Metadata about input/output costs for the current message. The client
|
||||
* should provide a function to get the current stream usage metadata; if not,
|
||||
* use the legacy token estimations.
|
||||
* @type {StreamUsage | null} */
|
||||
const usage = this.getStreamUsage != null ? this.getStreamUsage() : null;
|
||||
|
||||
if (usage != null && Number(usage.output_tokens) > 0) {
|
||||
responseMessage.tokenCount = usage.output_tokens;
|
||||
completionTokens = responseMessage.tokenCount;
|
||||
await this.updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts });
|
||||
} else {
|
||||
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
|
||||
completionTokens = this.getTokenCount(completion);
|
||||
}
|
||||
|
||||
await this.recordTokenUsage({ promptTokens, completionTokens, usage });
|
||||
}
|
||||
|
||||
if (this.userMessagePromise) {
|
||||
await this.userMessagePromise;
|
||||
}
|
||||
|
||||
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
messageCache.set(
|
||||
|
|
@ -557,6 +587,66 @@ class BaseClient {
|
|||
return responseMessage;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream usage should only be used for user message token count re-calculation if:
|
||||
* - The stream usage is available, with input tokens greater than 0,
|
||||
* - the client provides a function to calculate the current token count,
|
||||
* - files are being resent with every message (default behavior; or if `false`, with no attachments),
|
||||
* - the `promptPrefix` (custom instructions) is not set.
|
||||
*
|
||||
* In these cases, the legacy token estimations would be more accurate.
|
||||
*
|
||||
* TODO: included system messages in the `orderedMessages` accounting, potentially as a
|
||||
* separate message in the UI. ChatGPT does this through "hidden" system messages.
|
||||
* @param {object} params
|
||||
* @param {StreamUsage} params.usage
|
||||
* @param {Record<string, number>} params.tokenCountMap
|
||||
* @param {TMessage} params.userMessage
|
||||
* @param {object} params.opts
|
||||
*/
|
||||
async updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts }) {
|
||||
/** @type {boolean} */
|
||||
const shouldUpdateCount =
|
||||
this.calculateCurrentTokenCount != null &&
|
||||
Number(usage.input_tokens) > 0 &&
|
||||
(this.options.resendFiles ||
|
||||
(!this.options.resendFiles && !this.options.attachments?.length)) &&
|
||||
!this.options.promptPrefix;
|
||||
|
||||
if (!shouldUpdateCount) {
|
||||
return;
|
||||
}
|
||||
|
||||
const userMessageTokenCount = this.calculateCurrentTokenCount({
|
||||
currentMessageId: userMessage.messageId,
|
||||
tokenCountMap,
|
||||
usage,
|
||||
});
|
||||
|
||||
if (userMessageTokenCount === userMessage.tokenCount) {
|
||||
return;
|
||||
}
|
||||
|
||||
userMessage.tokenCount = userMessageTokenCount;
|
||||
/*
|
||||
Note: `AskController` saves the user message, so we update the count of its `userMessage` reference
|
||||
*/
|
||||
if (typeof opts?.getReqData === 'function') {
|
||||
opts.getReqData({
|
||||
userMessage,
|
||||
});
|
||||
}
|
||||
/*
|
||||
Note: we update the user message to be sure it gets the calculated token count;
|
||||
though `AskController` saves the user message, EditController does not
|
||||
*/
|
||||
await this.userMessagePromise;
|
||||
await this.updateMessageInDatabase({
|
||||
messageId: userMessage.messageId,
|
||||
tokenCount: userMessageTokenCount,
|
||||
});
|
||||
}
|
||||
|
||||
async loadHistory(conversationId, parentMessageId = null) {
|
||||
logger.debug('[BaseClient] Loading history:', { conversationId, parentMessageId });
|
||||
|
||||
|
|
@ -644,6 +734,10 @@ class BaseClient {
|
|||
return { message: savedMessage, conversation };
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a message in the database.
|
||||
* @param {Partial<TMessage>} message
|
||||
*/
|
||||
async updateMessageInDatabase(message) {
|
||||
await updateMessage(this.options.req, message);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue