fix(OpenAIClient): use official SDK to identify client and avoid false Rate Limit Error (#1161)

* chore: add eslint ignore unused var pattern

* feat: add extractBaseURL helper for valid OpenAI reverse proxies, with tests

* feat(OpenAIClient): add new chatCompletion using official OpenAI node SDK

* fix(ci): revert change to FORCE_PROMPT condition
Author: Danny Avila (committed via GitHub)
Date: 2023-11-09 14:04:36 -05:00
parent ed3d7c9f80
commit 5ab9802aa9
10 changed files with 297 additions and 12 deletions
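The diff below leans on the new extractBaseURL helper, which is added in api/utils (with tests) and is therefore not part of this file's diff. Only its name and contract appear here: given an OpenAI-compatible reverse-proxy URL it yields the `.../v1` base, and it yields null when the URL does not follow OpenAI specs. A minimal hypothetical sketch of that contract (the function body is a reconstruction, not the committed implementation):

// Hypothetical sketch of the extractBaseURL contract used in the diff below.
// The real helper lives in api/utils and may be stricter.
// Valid:   https://localhost:8080/v1/chat/completions -> https://localhost:8080/v1
// Invalid: https://example.com/api/chat               -> null
function extractBaseURL(url) {
  if (typeof url !== 'string' || !url.startsWith('http')) {
    return null;
  }
  // Keep everything up to and including the `/v1` path segment.
  const match = url.match(/^(https?:\/\/[^\s]+?\/v1)(?:\/|$)/);
  return match ? match[1] : null;
}

console.log(extractBaseURL('https://localhost:8080/v1/chat/completions')); // https://localhost:8080/v1
console.log(extractBaseURL('https://example.com/api/chat')); // null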

@@ -1,15 +1,17 @@
-const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
+const OpenAI = require('openai');
+const { HttpsProxyAgent } = require('https-proxy-agent');
-const ChatGPTClient = require('./ChatGPTClient');
-const BaseClient = require('./BaseClient');
-const { getModelMaxTokens, genAzureChatCompletion } = require('../../utils');
+const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
+const { getModelMaxTokens, genAzureChatCompletion, extractBaseURL } = require('../../utils');
 const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts');
 const spendTokens = require('../../models/spendTokens');
 const { handleOpenAIErrors } = require('./tools/util');
 const { isEnabled } = require('../../server/utils');
 const { createLLM, RunManager } = require('./llm');
+const ChatGPTClient = require('./ChatGPTClient');
 const { summaryBuffer } = require('./memory');
 const { runTitleChain } = require('./chains');
 const { tokenSplit } = require('./document');
+const BaseClient = require('./BaseClient');
 
 // Cache to store Tiktoken instances
 const tokenizersCache = {};
@@ -74,7 +76,7 @@ class OpenAIClient extends BaseClient {
     }
 
     const { OPENROUTER_API_KEY, OPENAI_FORCE_PROMPT } = process.env ?? {};
-    if (OPENROUTER_API_KEY) {
+    if (OPENROUTER_API_KEY && !this.azure) {
       this.apiKey = OPENROUTER_API_KEY;
       this.useOpenRouter = true;
     }
@@ -88,7 +90,11 @@
     this.isChatCompletion = this.useOpenRouter || !!reverseProxy || model.includes('gpt-');
     this.isChatGptModel = this.isChatCompletion;
 
-    if (model.includes('text-davinci-003') || model.includes('instruct') || this.FORCE_PROMPT) {
+    if (
+      model.includes('text-davinci') ||
+      model.includes('gpt-3.5-turbo-instruct') ||
+      this.FORCE_PROMPT
+    ) {
       this.isChatCompletion = false;
       this.isChatGptModel = false;
     }
@@ -134,7 +140,7 @@
     if (reverseProxy) {
       this.completionsUrl = reverseProxy;
-      this.langchainProxy = reverseProxy.match(/.*v1/)?.[0];
+      this.langchainProxy = extractBaseURL(reverseProxy);
       !this.langchainProxy &&
         console.warn(`The reverse proxy URL ${reverseProxy} is not valid for Plugins.
         The url must follow OpenAI specs, for example: https://localhost:8080/v1/chat/completions
@@ -356,7 +362,9 @@ If your reverse proxy is compatible to OpenAI specs in every other way, it may s
     let result = null;
     let streamResult = null;
     this.modelOptions.user = this.user;
-    if (typeof opts.onProgress === 'function') {
+    const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null;
+    const useOldMethod = !!(this.azure || invalidBaseUrl);
+    if (typeof opts.onProgress === 'function' && useOldMethod) {
       await this.getCompletion(
         payload,
         (progressMessage) => {
@@ -399,6 +407,13 @@ If your reverse proxy is compatible to OpenAI specs in every other way, it may s
         },
         opts.abortController || new AbortController(),
       );
+    } else if (typeof opts.onProgress === 'function') {
+      reply = await this.chatCompletion({
+        payload,
+        clientOptions: opts,
+        onProgress: opts.onProgress,
+        abortController: opts.abortController,
+      });
     } else {
       result = await this.getCompletion(
         payload,
@@ -669,6 +684,135 @@ ${convo}
       content: response.text,
     });
   }
+
+  async chatCompletion({ payload, onProgress, clientOptions, abortController = null }) {
+    let error = null;
+    const errorCallback = (err) => (error = err);
+    let intermediateReply = '';
+    try {
+      if (!abortController) {
+        abortController = new AbortController();
+      }
+      const modelOptions = { ...this.modelOptions };
+      if (typeof onProgress === 'function') {
+        modelOptions.stream = true;
+      }
+      if (this.isChatGptModel) {
+        modelOptions.messages = payload;
+      } else {
+        modelOptions.prompt = payload;
+      }
+
+      const { debug } = this.options;
+      const url = extractBaseURL(this.completionsUrl);
+      if (debug) {
+        console.debug('baseURL', url);
+        console.debug('modelOptions', modelOptions);
+      }
+      const opts = {
+        baseURL: url,
+      };
+
+      if (this.useOpenRouter) {
+        opts.defaultHeaders = {
+          'HTTP-Referer': 'https://librechat.ai',
+          'X-Title': 'LibreChat',
+        };
+      }
+
+      if (this.options.headers) {
+        opts.defaultHeaders = { ...opts.defaultHeaders, ...this.options.headers };
+      }
+
+      if (this.options.proxy) {
+        opts.httpAgent = new HttpsProxyAgent(this.options.proxy);
+      }
+
+      let chatCompletion;
+      const openai = new OpenAI({
+        apiKey: this.apiKey,
+        ...opts,
+      });
+
+      if (modelOptions.stream) {
+        const stream = await openai.beta.chat.completions
+          .stream({
+            ...modelOptions,
+            stream: true,
+          })
+          .on('abort', () => {
+            /* Do nothing here */
+          })
+          .on('error', (err) => {
+            handleOpenAIErrors(err, errorCallback, 'stream');
+          });
+
+        for await (const chunk of stream) {
+          const token = chunk.choices[0]?.delta?.content || '';
+          intermediateReply += token;
+          onProgress(token);
+          if (abortController.signal.aborted) {
+            stream.controller.abort();
+            break;
+          }
+        }
+
+        chatCompletion = await stream.finalChatCompletion().catch((err) => {
+          handleOpenAIErrors(err, errorCallback, 'finalChatCompletion');
+        });
+      }
+      // regular completion
+      else {
+        chatCompletion = await openai.chat.completions
+          .create({
+            ...modelOptions,
+          })
+          .catch((err) => {
+            handleOpenAIErrors(err, errorCallback, 'create');
+          });
+      }
+
+      if (!chatCompletion && error) {
+        throw new Error(error);
+      } else if (!chatCompletion) {
+        throw new Error('Chat completion failed');
+      }
+
+      const { message, finish_reason } = chatCompletion.choices[0];
+      if (chatCompletion && typeof clientOptions.addMetadata === 'function') {
+        clientOptions.addMetadata({ finish_reason });
+      }
+
+      return message.content;
+    } catch (err) {
+      if (
+        err?.message?.includes('abort') ||
+        (err instanceof OpenAI.APIError && err?.message?.includes('abort'))
+      ) {
+        return '';
+      }
+      if (
+        err?.message?.includes('missing finish_reason') ||
+        (err instanceof OpenAI.OpenAIError && err?.message?.includes('missing finish_reason'))
+      ) {
+        await abortController.abortCompletion();
+        return intermediateReply;
+      } else if (err instanceof OpenAI.APIError) {
+        console.log(err.name);
+        console.log(err.status);
+        console.log(err.headers);
+        if (intermediateReply) {
+          return intermediateReply;
+        } else {
+          throw err;
+        }
+      } else {
+        console.warn('[OpenAIClient.chatCompletion] Unhandled error type');
+        console.error(err);
+        throw err;
+      }
+    }
+  }
 }
 
 module.exports = OpenAIClient;
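
The new chatCompletion method above builds directly on the official openai Node SDK (v4) and its beta streaming helper. For reference, a self-contained sketch of that streaming pattern outside of LibreChat — the model name, prompt, and env-provided API key are placeholder assumptions, not part of the commit:

const OpenAI = require('openai');

// Standalone sketch of the SDK streaming pattern used in chatCompletion above.
// new OpenAI() reads OPENAI_API_KEY from the environment by default.
async function main() {
  const openai = new OpenAI();
  const stream = openai.beta.chat.completions.stream({
    model: 'gpt-3.5-turbo',
    messages: [{ role: 'user', content: 'Say hello' }],
    stream: true,
  });

  // Tokens arrive as chunks; this mirrors the onProgress loop in the diff.
  for await (const chunk of stream) {
    process.stdout.write(chunk.choices[0]?.delta?.content || '');
  }

  // Resolves to the assembled ChatCompletion once the stream ends.
  const chatCompletion = await stream.finalChatCompletion();
  console.log('\nfinish_reason:', chatCompletion.choices[0].finish_reason);
}

main().catch(console.error);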