diff --git a/api/app/clients/prompts/addCacheControl.js b/api/app/clients/prompts/addCacheControl.js index 375e08e9b6..eed5910dc9 100644 --- a/api/app/clients/prompts/addCacheControl.js +++ b/api/app/clients/prompts/addCacheControl.js @@ -9,31 +9,31 @@ function addCacheControl(messages) { } const updatedMessages = [...messages]; - let userMessagesFound = 0; + let userMessagesModified = 0; - for (let i = updatedMessages.length - 1; i >= 0 && userMessagesFound < 2; i--) { - if (updatedMessages[i].role === 'user') { - if (typeof updatedMessages[i].content === 'string') { - updatedMessages[i] = { - ...updatedMessages[i], - content: [ - { - type: 'text', - text: updatedMessages[i].content, - cache_control: { type: 'ephemeral' }, - }, - ], - }; - } else if (Array.isArray(updatedMessages[i].content)) { - updatedMessages[i] = { - ...updatedMessages[i], - content: updatedMessages[i].content.map((item) => ({ - ...item, - cache_control: { type: 'ephemeral' }, - })), - }; + for (let i = updatedMessages.length - 1; i >= 0 && userMessagesModified < 2; i--) { + const message = updatedMessages[i]; + if (message.role !== 'user') { + continue; + } + + if (typeof message.content === 'string') { + message.content = [ + { + type: 'text', + text: message.content, + cache_control: { type: 'ephemeral' }, + }, + ]; + userMessagesModified++; + } else if (Array.isArray(message.content)) { + for (let j = message.content.length - 1; j >= 0; j--) { + if (message.content[j].type === 'text') { + message.content[j].cache_control = { type: 'ephemeral' }; + userMessagesModified++; + break; + } } - userMessagesFound++; } } diff --git a/api/app/clients/prompts/addCacheControl.spec.js b/api/app/clients/prompts/addCacheControl.spec.js index d95ccaf34b..c46ffd95e3 100644 --- a/api/app/clients/prompts/addCacheControl.spec.js +++ b/api/app/clients/prompts/addCacheControl.spec.js @@ -116,6 +116,7 @@ describe('addCacheControl', () => { content: [ { type: 'text', text: 'Hello' }, { type: 'image', url: 'http://example.com/image.jpg' }, + { type: 'text', text: 'This is an image' }, ], }, { role: 'assistant', content: 'Hi there' }, @@ -124,8 +125,9 @@ describe('addCacheControl', () => { const result = addCacheControl(messages); - expect(result[0].content[0].cache_control).toEqual({ type: 'ephemeral' }); - expect(result[0].content[1].cache_control).toEqual({ type: 'ephemeral' }); + expect(result[0].content[0]).not.toHaveProperty('cache_control'); + expect(result[0].content[1]).not.toHaveProperty('cache_control'); + expect(result[0].content[2].cache_control).toEqual({ type: 'ephemeral' }); expect(result[2].content[0]).toEqual({ type: 'text', text: 'How are you?', @@ -143,7 +145,6 @@ describe('addCacheControl', () => { ]; const result = addCacheControl(messages); - console.dir(result, { depth: null }); expect(result[0].content).toEqual('Hello'); expect(result[2].content[0]).toEqual({ @@ -161,4 +162,66 @@ describe('addCacheControl', () => { expect(result[1].content).toBe('Hi there'); expect(result[3].content).toBe('I\'m doing well, thanks!'); }); + + test('should handle edge case with multiple content types', () => { + const messages = [ + { + role: 'user', + content: [ + { + type: 'image', + source: { type: 'base64', media_type: 'image/png', data: 'some_base64_string' }, + }, + { + type: 'image', + source: { type: 'base64', media_type: 'image/png', data: 'another_base64_string' }, + }, + { type: 'text', text: 'what do all these images have in common' }, + ], + }, + { role: 'assistant', content: 'I see multiple images.' }, + { role: 'user', content: 'Correct!' }, + ]; + + const result = addCacheControl(messages); + + expect(result[0].content[0]).not.toHaveProperty('cache_control'); + expect(result[0].content[1]).not.toHaveProperty('cache_control'); + expect(result[0].content[2].cache_control).toEqual({ type: 'ephemeral' }); + expect(result[2].content[0]).toEqual({ + type: 'text', + text: 'Correct!', + cache_control: { type: 'ephemeral' }, + }); + }); + + test('should handle user message with no text block', () => { + const messages = [ + { + role: 'user', + content: [ + { + type: 'image', + source: { type: 'base64', media_type: 'image/png', data: 'some_base64_string' }, + }, + { + type: 'image', + source: { type: 'base64', media_type: 'image/png', data: 'another_base64_string' }, + }, + ], + }, + { role: 'assistant', content: 'I see two images.' }, + { role: 'user', content: 'Correct!' }, + ]; + + const result = addCacheControl(messages); + + expect(result[0].content[0]).not.toHaveProperty('cache_control'); + expect(result[0].content[1]).not.toHaveProperty('cache_control'); + expect(result[2].content[0]).toEqual({ + type: 'text', + text: 'Correct!', + cache_control: { type: 'ephemeral' }, + }); + }); });