diff --git a/api/app/clients/PluginsClient.js b/api/app/clients/PluginsClient.js index 585cf05742..4b4919e334 100644 --- a/api/app/clients/PluginsClient.js +++ b/api/app/clients/PluginsClient.js @@ -53,7 +53,9 @@ class PluginsClient extends OpenAIClient { } getFunctionModelName(input) { - if (input.includes('gpt-3.5-turbo')) { + if (/-(?!0314)\d{4}/.test(input)) { + return input; + } else if (input.includes('gpt-3.5-turbo')) { return 'gpt-3.5-turbo'; } else if (input.includes('gpt-4')) { return 'gpt-4'; diff --git a/api/app/clients/specs/PluginsClient.test.js b/api/app/clients/specs/PluginsClient.test.js index 009167637a..b4e42b1fc5 100644 --- a/api/app/clients/specs/PluginsClient.test.js +++ b/api/app/clients/specs/PluginsClient.test.js @@ -144,4 +144,47 @@ describe('PluginsClient', () => { expect(chatMessages[0].text).toEqual(userMessage); }); }); + + describe('getFunctionModelName', () => { + let client; + + beforeEach(() => { + client = new PluginsClient('dummy_api_key'); + }); + + test('should return the input when it includes a dash followed by four digits', () => { + expect(client.getFunctionModelName('-1234')).toBe('-1234'); + expect(client.getFunctionModelName('gpt-4-5678-preview')).toBe('gpt-4-5678-preview'); + }); + + test('should return the input for all function-capable models (`0613` models and above)', () => { + expect(client.getFunctionModelName('gpt-4-0613')).toBe('gpt-4-0613'); + expect(client.getFunctionModelName('gpt-4-32k-0613')).toBe('gpt-4-32k-0613'); + expect(client.getFunctionModelName('gpt-3.5-turbo-0613')).toBe('gpt-3.5-turbo-0613'); + expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0613')).toBe('gpt-3.5-turbo-16k-0613'); + expect(client.getFunctionModelName('gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106'); + expect(client.getFunctionModelName('gpt-4-1106-preview')).toBe('gpt-4-1106-preview'); + expect(client.getFunctionModelName('gpt-4-1106')).toBe('gpt-4-1106'); + }); + + test('should return the corresponding model if input is non-function capable (`0314` models)', () => { + expect(client.getFunctionModelName('gpt-4-0314')).toBe('gpt-4'); + expect(client.getFunctionModelName('gpt-4-32k-0314')).toBe('gpt-4'); + expect(client.getFunctionModelName('gpt-3.5-turbo-0314')).toBe('gpt-3.5-turbo'); + expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0314')).toBe('gpt-3.5-turbo'); + }); + + test('should return "gpt-3.5-turbo" when the input includes "gpt-3.5-turbo"', () => { + expect(client.getFunctionModelName('test gpt-3.5-turbo model')).toBe('gpt-3.5-turbo'); + }); + + test('should return "gpt-4" when the input includes "gpt-4"', () => { + expect(client.getFunctionModelName('testing gpt-4')).toBe('gpt-4'); + }); + + test('should return "gpt-3.5-turbo" for input that does not meet any specific condition', () => { + expect(client.getFunctionModelName('random string')).toBe('gpt-3.5-turbo'); + expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo'); + }); + }); });