🧠 feat: Implement O1 Model Support for Max Tokens Handling (#4376)

This commit is contained in:
Danny Avila 2024-10-10 08:36:36 +02:00 committed by GitHub
parent bdc2fd307f
commit 873e0473ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -68,6 +68,8 @@ class OpenAIClient extends BaseClient {
/** @type {OpenAIUsageMetadata | undefined} */ /** @type {OpenAIUsageMetadata | undefined} */
this.usage; this.usage;
/** @type {boolean|undefined} */
this.isO1Model;
} }
// TODO: PluginsClient calls this 3x, unneeded // TODO: PluginsClient calls this 3x, unneeded
@ -98,6 +100,8 @@ class OpenAIClient extends BaseClient {
this.options.modelOptions, this.options.modelOptions,
); );
this.isO1Model = /\bo1\b/i.test(this.modelOptions.model);
this.defaultVisionModel = this.options.visionModel ?? 'gpt-4-vision-preview'; this.defaultVisionModel = this.options.visionModel ?? 'gpt-4-vision-preview';
if (typeof this.options.attachments?.then === 'function') { if (typeof this.options.attachments?.then === 'function') {
this.options.attachments.then((attachments) => this.checkVisionRequest(attachments)); this.options.attachments.then((attachments) => this.checkVisionRequest(attachments));
@ -545,8 +549,7 @@ class OpenAIClient extends BaseClient {
promptPrefix = this.augmentedPrompt + promptPrefix; promptPrefix = this.augmentedPrompt + promptPrefix;
} }
const isO1Model = /\bo1\b/i.test(this.modelOptions.model); if (promptPrefix && this.isO1Model !== true) {
if (promptPrefix && !isO1Model) {
promptPrefix = `Instructions:\n${promptPrefix.trim()}`; promptPrefix = `Instructions:\n${promptPrefix.trim()}`;
instructions = { instructions = {
role: 'system', role: 'system',
@ -575,7 +578,7 @@ class OpenAIClient extends BaseClient {
}; };
/** EXPERIMENTAL */ /** EXPERIMENTAL */
if (promptPrefix && isO1Model) { if (promptPrefix && this.isO1Model === true) {
const lastUserMessageIndex = payload.findLastIndex((message) => message.role === 'user'); const lastUserMessageIndex = payload.findLastIndex((message) => message.role === 'user');
if (lastUserMessageIndex !== -1) { if (lastUserMessageIndex !== -1) {
payload[ payload[
@ -1227,6 +1230,11 @@ ${convo}
opts.defaultHeaders = { ...opts.defaultHeaders, 'api-key': this.apiKey }; opts.defaultHeaders = { ...opts.defaultHeaders, 'api-key': this.apiKey };
} }
if (this.isO1Model === true && modelOptions.max_tokens != null) {
modelOptions.max_completion_tokens = modelOptions.max_tokens;
delete modelOptions.max_tokens;
}
if (process.env.OPENAI_ORGANIZATION) { if (process.env.OPENAI_ORGANIZATION) {
opts.organization = process.env.OPENAI_ORGANIZATION; opts.organization = process.env.OPENAI_ORGANIZATION;
} }