💾 feat: Anthropic Prompt Caching (#3670)

* wip: initial cache control implementation, add typing for transactions handling

* feat: first pass of Anthropic Prompt Caching

* feat: standardize stream usage as pass in when calculating token counts

* feat: Add getCacheMultiplier function to calculate cache multiplier for different valueKeys and cacheTypes

* chore: imports order

* refactor: token usage recording in AnthropicClient, no need to "correct" as we have the correct amount

* feat: more accurate token counting using stream usage data

* feat: Improve token counting accuracy with stream usage data

* refactor: ensure more accurate than not token estimations if custom instructions or files are not being resent with every request

* refactor: cleanup updateUserMessageTokenCount to allow transactions to be as accurate as possible even if we shouldn't update user message token counts

* ci: fix tests
This commit is contained in:
Danny Avila 2024-08-17 03:24:09 -04:00 committed by GitHub
parent 9f4c516615
commit a45b384bbc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 973 additions and 34 deletions

View file

@ -54,10 +54,22 @@ class BaseClient {
throw new Error('Subclasses attempted to call summarizeMessages without implementing it');
}
async getTokenCountForResponse(response) {
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', response);
/**
* Abstract method to get the token count for a message. Subclasses must implement this method.
* @param {TMessage} responseMessage
* @returns {number}
*/
getTokenCountForResponse(responseMessage) {
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', responseMessage);
}
/**
* Abstract method to record token usage. Subclasses must implement this method.
* If a correction to the token usage is needed, the method should return an object with the corrected token counts.
* @param {number} promptTokens
* @param {number} completionTokens
* @returns {Promise<void>}
*/
async recordTokenUsage({ promptTokens, completionTokens }) {
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', {
promptTokens,
@ -536,13 +548,31 @@ class BaseClient {
this.getTokenCountForResponse &&
this.getTokenCount
) {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
const completionTokens = this.getTokenCount(completion);
await this.recordTokenUsage({ promptTokens, completionTokens });
let completionTokens;
/**
* Metadata about input/output costs for the current message. The client
* should provide a function to get the current stream usage metadata; if not,
* use the legacy token estimations.
* @type {StreamUsage | null} */
const usage = this.getStreamUsage != null ? this.getStreamUsage() : null;
if (usage != null && Number(usage.output_tokens) > 0) {
responseMessage.tokenCount = usage.output_tokens;
completionTokens = responseMessage.tokenCount;
await this.updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts });
} else {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
completionTokens = this.getTokenCount(completion);
}
await this.recordTokenUsage({ promptTokens, completionTokens, usage });
}
if (this.userMessagePromise) {
await this.userMessagePromise;
}
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
@ -557,6 +587,66 @@ class BaseClient {
return responseMessage;
}
/**
* Stream usage should only be used for user message token count re-calculation if:
* - The stream usage is available, with input tokens greater than 0,
* - the client provides a function to calculate the current token count,
* - files are being resent with every message (default behavior; or if `false`, with no attachments),
* - the `promptPrefix` (custom instructions) is not set.
*
* In these cases, the legacy token estimations would be more accurate.
*
* TODO: included system messages in the `orderedMessages` accounting, potentially as a
* separate message in the UI. ChatGPT does this through "hidden" system messages.
* @param {object} params
* @param {StreamUsage} params.usage
* @param {Record<string, number>} params.tokenCountMap
* @param {TMessage} params.userMessage
* @param {object} params.opts
*/
async updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts }) {
/** @type {boolean} */
const shouldUpdateCount =
this.calculateCurrentTokenCount != null &&
Number(usage.input_tokens) > 0 &&
(this.options.resendFiles ||
(!this.options.resendFiles && !this.options.attachments?.length)) &&
!this.options.promptPrefix;
if (!shouldUpdateCount) {
return;
}
const userMessageTokenCount = this.calculateCurrentTokenCount({
currentMessageId: userMessage.messageId,
tokenCountMap,
usage,
});
if (userMessageTokenCount === userMessage.tokenCount) {
return;
}
userMessage.tokenCount = userMessageTokenCount;
/*
Note: `AskController` saves the user message, so we update the count of its `userMessage` reference
*/
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
userMessage,
});
}
/*
Note: we update the user message to be sure it gets the calculated token count;
though `AskController` saves the user message, EditController does not
*/
await this.userMessagePromise;
await this.updateMessageInDatabase({
messageId: userMessage.messageId,
tokenCount: userMessageTokenCount,
});
}
async loadHistory(conversationId, parentMessageId = null) {
logger.debug('[BaseClient] Loading history:', { conversationId, parentMessageId });
@ -644,6 +734,10 @@ class BaseClient {
return { message: savedMessage, conversation };
}
/**
* Update a message in the database.
* @param {Partial<TMessage>} message
*/
async updateMessageInDatabase(message) {
await updateMessage(this.options.req, message);
}