fix(OpenAIClient): use official SDK to identify client and avoid false Rate Limit Error (#1161)

* chore: add eslint ignore unused var pattern

* feat: add extractBaseURL helper for valid OpenAI reverse proxies, with tests

* feat(OpenAIClient): add new chatCompletion using official OpenAI node SDK

* fix(ci): revert change to FORCE_PROMPT condition
Danny Avila 2023-11-09 14:04:36 -05:00 committed by GitHub
parent ed3d7c9f80
commit 5ab9802aa9
10 changed files with 297 additions and 12 deletions

View file

@@ -61,6 +61,7 @@ module.exports = {
     'no-restricted-syntax': 'off',
     'react/prop-types': ['off'],
     'react/display-name': ['off'],
+    'no-unused-vars': ['error', { varsIgnorePattern: '^_' }],
     quotes: ['error', 'single'],
   },
   overrides: [
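A note on the new lint rule: `varsIgnorePattern: '^_'` makes ESLint's `no-unused-vars` skip declared variables whose names start with an underscore, while every other unused variable is still reported. A minimal, hypothetical illustration (the variable names below are placeholders, not from this commit):

/* eslint no-unused-vars: ["error", { "varsIgnorePattern": "^_" }] */
// Destructure a value but only use part of it; `_meta` is tolerated because
// its name matches the ^_ pattern, while `data` must be used (it is, below).
const { data, _meta } = { data: 42, _meta: { cached: true } };
console.log(data);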

View file

@ -1,15 +1,17 @@
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); const OpenAI = require('openai');
const { HttpsProxyAgent } = require('https-proxy-agent'); const { HttpsProxyAgent } = require('https-proxy-agent');
const ChatGPTClient = require('./ChatGPTClient'); const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const BaseClient = require('./BaseClient'); const { getModelMaxTokens, genAzureChatCompletion, extractBaseURL } = require('../../utils');
const { getModelMaxTokens, genAzureChatCompletion } = require('../../utils');
const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts'); const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts');
const spendTokens = require('../../models/spendTokens'); const spendTokens = require('../../models/spendTokens');
const { handleOpenAIErrors } = require('./tools/util');
const { isEnabled } = require('../../server/utils'); const { isEnabled } = require('../../server/utils');
const { createLLM, RunManager } = require('./llm'); const { createLLM, RunManager } = require('./llm');
const ChatGPTClient = require('./ChatGPTClient');
const { summaryBuffer } = require('./memory'); const { summaryBuffer } = require('./memory');
const { runTitleChain } = require('./chains'); const { runTitleChain } = require('./chains');
const { tokenSplit } = require('./document'); const { tokenSplit } = require('./document');
const BaseClient = require('./BaseClient');
// Cache to store Tiktoken instances // Cache to store Tiktoken instances
const tokenizersCache = {}; const tokenizersCache = {};
@@ -74,7 +76,7 @@ class OpenAIClient extends BaseClient {
     }

     const { OPENROUTER_API_KEY, OPENAI_FORCE_PROMPT } = process.env ?? {};
-    if (OPENROUTER_API_KEY) {
+    if (OPENROUTER_API_KEY && !this.azure) {
       this.apiKey = OPENROUTER_API_KEY;
       this.useOpenRouter = true;
     }
@@ -88,7 +90,11 @@ class OpenAIClient extends BaseClient {
     this.isChatCompletion = this.useOpenRouter || !!reverseProxy || model.includes('gpt-');
     this.isChatGptModel = this.isChatCompletion;
-    if (model.includes('text-davinci-003') || model.includes('instruct') || this.FORCE_PROMPT) {
+    if (
+      model.includes('text-davinci') ||
+      model.includes('gpt-3.5-turbo-instruct') ||
+      this.FORCE_PROMPT
+    ) {
       this.isChatCompletion = false;
       this.isChatGptModel = false;
     }
@@ -134,7 +140,7 @@ class OpenAIClient extends BaseClient {
     if (reverseProxy) {
       this.completionsUrl = reverseProxy;
-      this.langchainProxy = reverseProxy.match(/.*v1/)?.[0];
+      this.langchainProxy = extractBaseURL(reverseProxy);
       !this.langchainProxy &&
         console.warn(`The reverse proxy URL ${reverseProxy} is not valid for Plugins.
 The url must follow OpenAI specs, for example: https://localhost:8080/v1/chat/completions
@@ -356,7 +362,9 @@ If your reverse proxy is compatible to OpenAI specs in every other way, it may s
     let result = null;
     let streamResult = null;
     this.modelOptions.user = this.user;
-    if (typeof opts.onProgress === 'function') {
+    const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null;
+    const useOldMethod = !!(this.azure || invalidBaseUrl);
+    if (typeof opts.onProgress === 'function' && useOldMethod) {
       await this.getCompletion(
         payload,
         (progressMessage) => {
@@ -399,6 +407,13 @@ If your reverse proxy is compatible to OpenAI specs in every other way, it may s
         },
         opts.abortController || new AbortController(),
       );
+    } else if (typeof opts.onProgress === 'function') {
+      reply = await this.chatCompletion({
+        payload,
+        clientOptions: opts,
+        onProgress: opts.onProgress,
+        abortController: opts.abortController,
+      });
     } else {
       result = await this.getCompletion(
         payload,
@@ -669,6 +684,135 @@ ${convo}
       content: response.text,
     });
   }
+
+  async chatCompletion({ payload, onProgress, clientOptions, abortController = null }) {
+    let error = null;
+    const errorCallback = (err) => (error = err);
+    let intermediateReply = '';
+    try {
+      if (!abortController) {
+        abortController = new AbortController();
+      }
+
+      const modelOptions = { ...this.modelOptions };
+      if (typeof onProgress === 'function') {
+        modelOptions.stream = true;
+      }
+      if (this.isChatGptModel) {
+        modelOptions.messages = payload;
+      } else {
+        modelOptions.prompt = payload;
+      }
+
+      const { debug } = this.options;
+      const url = extractBaseURL(this.completionsUrl);
+      if (debug) {
+        console.debug('baseURL', url);
+        console.debug('modelOptions', modelOptions);
+      }
+
+      const opts = {
+        baseURL: url,
+      };
+
+      if (this.useOpenRouter) {
+        opts.defaultHeaders = {
+          'HTTP-Referer': 'https://librechat.ai',
+          'X-Title': 'LibreChat',
+        };
+      }
+
+      if (this.options.headers) {
+        opts.defaultHeaders = { ...opts.defaultHeaders, ...this.options.headers };
+      }
+
+      if (this.options.proxy) {
+        opts.httpAgent = new HttpsProxyAgent(this.options.proxy);
+      }
+
+      let chatCompletion;
+      const openai = new OpenAI({
+        apiKey: this.apiKey,
+        ...opts,
+      });
+
+      if (modelOptions.stream) {
+        const stream = await openai.beta.chat.completions
+          .stream({
+            ...modelOptions,
+            stream: true,
+          })
+          .on('abort', () => {
+            /* Do nothing here */
+          })
+          .on('error', (err) => {
+            handleOpenAIErrors(err, errorCallback, 'stream');
+          });
+
+        for await (const chunk of stream) {
+          const token = chunk.choices[0]?.delta?.content || '';
+          intermediateReply += token;
+          onProgress(token);
+          if (abortController.signal.aborted) {
+            stream.controller.abort();
+            break;
+          }
+        }
+
+        chatCompletion = await stream.finalChatCompletion().catch((err) => {
+          handleOpenAIErrors(err, errorCallback, 'finalChatCompletion');
+        });
+      }
+      // regular completion
+      else {
+        chatCompletion = await openai.chat.completions
+          .create({
+            ...modelOptions,
+          })
+          .catch((err) => {
+            handleOpenAIErrors(err, errorCallback, 'create');
+          });
+      }
+
+      if (!chatCompletion && error) {
+        throw new Error(error);
+      } else if (!chatCompletion) {
+        throw new Error('Chat completion failed');
+      }
+
+      const { message, finish_reason } = chatCompletion.choices[0];
+      if (chatCompletion && typeof clientOptions.addMetadata === 'function') {
+        clientOptions.addMetadata({ finish_reason });
+      }
+
+      return message.content;
+    } catch (err) {
+      if (
+        err?.message?.includes('abort') ||
+        (err instanceof OpenAI.APIError && err?.message?.includes('abort'))
+      ) {
+        return '';
+      }
+
+      if (
+        err?.message?.includes('missing finish_reason') ||
+        (err instanceof OpenAI.OpenAIError && err?.message?.includes('missing finish_reason'))
+      ) {
+        await abortController.abortCompletion();
+        return intermediateReply;
+      } else if (err instanceof OpenAI.APIError) {
+        console.log(err.name);
+        console.log(err.status);
+        console.log(err.headers);
+        if (intermediateReply) {
+          return intermediateReply;
+        } else {
+          throw err;
+        }
+      } else {
+        console.warn('[OpenAIClient.chatCompletion] Unhandled error type');
+        console.error(err);
+        throw err;
+      }
+    }
+  }
 }

 module.exports = OpenAIClient;
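To summarize the routing above: the legacy `getCompletion` path is now kept only for Azure, for reverse proxies whose URL `extractBaseURL` cannot parse (`useOldMethod`), and for non-streaming calls; other streaming requests go through the new `chatCompletion`, which delegates to the official `openai` SDK. Below is a minimal, standalone sketch of that streaming pattern; the API key, model name, and messages are placeholders, not values from this commit:

const OpenAI = require('openai');

async function streamSketch() {
  const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });
  // The beta streaming helper yields incremental chunks and can also
  // assemble the final, non-streamed completion object at the end.
  const stream = openai.beta.chat.completions.stream({
    model: 'gpt-3.5-turbo',
    messages: [{ role: 'user', content: 'Say hello' }],
    stream: true,
  });

  let text = '';
  for await (const chunk of stream) {
    // Each chunk carries a token delta in choices[0].delta.content.
    text += chunk.choices[0]?.delta?.content || '';
  }

  const completion = await stream.finalChatCompletion();
  return { text, finish_reason: completion.choices[0].finish_reason };
}

streamSketch().then(console.log).catch(console.error);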

View file

@@ -6,6 +6,7 @@ const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_pars
 const checkBalance = require('../../models/checkBalance');
 const { formatLangChainMessages } = require('./prompts');
 const { isEnabled } = require('../../server/utils');
+const { extractBaseURL } = require('../../utils');
 const { SelfReflectionTool } = require('./tools');
 const { loadTools } = require('./tools/util');
@@ -34,7 +35,7 @@ class PluginsClient extends OpenAIClient {
     this.isGpt3 = this.modelOptions?.model?.includes('gpt-3');

     if (this.options.reverseProxyUrl) {
-      this.langchainProxy = this.options.reverseProxyUrl.match(/.*v1/)?.[0];
+      this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl);
       !this.langchainProxy &&
         console.warn(`The reverse proxy URL ${this.options.reverseProxyUrl} is not valid for Plugins.
 The url must follow OpenAI specs, for example: https://localhost:8080/v1/chat/completions

View file

@@ -86,7 +86,7 @@ describe('OpenAIClient', () => {
       client.setOptions({ reverseProxyUrl: 'https://example.com/completions' });
       expect(client.completionsUrl).toBe('https://example.com/completions');
-      expect(client.langchainProxy).toBeUndefined();
+      expect(client.langchainProxy).toBe(null);
     });
   });

View file

@@ -0,0 +1,30 @@
+const OpenAI = require('openai');
+
+/**
+ * Handles errors that may occur when making requests to OpenAI's API.
+ * It checks the instance of the error and prints a specific warning message
+ * to the console depending on the type of error encountered.
+ * It then calls an optional error callback function with the error object.
+ *
+ * @param {Error} err - The error object thrown by OpenAI API.
+ * @param {Function} errorCallback - A callback function that is called with the error object.
+ * @param {string} [context='stream'] - A string providing context where the error occurred, defaults to 'stream'.
+ */
+async function handleOpenAIErrors(err, errorCallback, context = 'stream') {
+  if (err instanceof OpenAI.APIError && err?.message?.includes('abort')) {
+    console.warn(`[OpenAIClient.chatCompletion][${context}] Aborted Message`);
+  }
+
+  if (err instanceof OpenAI.OpenAIError && err?.message?.includes('missing finish_reason')) {
+    console.warn(`[OpenAIClient.chatCompletion][${context}] Missing finish_reason`);
+  } else if (err instanceof OpenAI.APIError) {
+    console.warn(`[OpenAIClient.chatCompletion][${context}] API Error`);
+  } else {
+    console.warn(`[OpenAIClient.chatCompletion][${context}] Unhandled error type`);
+  }
+
+  if (errorCallback) {
+    errorCallback(err);
+  }
+}
+
+module.exports = handleOpenAIErrors;
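For orientation, the new `chatCompletion` method above consumes this helper by passing a callback that records the error for the caller to inspect afterward. A minimal sketch of that call shape (the `openai` client and `params` arguments here are placeholders):

const handleOpenAIErrors = require('./handleOpenAIErrors');

async function createWithErrorHandling(openai, params) {
  let error = null;
  const errorCallback = (err) => (error = err);

  // The helper only logs and records the error; deciding what to throw
  // (or whether to return a partial reply) stays with the caller.
  const completion = await openai.chat.completions
    .create(params)
    .catch((err) => handleOpenAIErrors(err, errorCallback, 'create'));

  if (!completion) {
    throw error ?? new Error('Chat completion failed');
  }
  return completion;
}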

View file

@@ -1,6 +1,8 @@
 const { validateTools, loadTools } = require('./handleTools');
+const handleOpenAIErrors = require('./handleOpenAIErrors');

 module.exports = {
+  handleOpenAIErrors,
   validateTools,
   loadTools,
 };

View file

@@ -1,6 +1,7 @@
 const Keyv = require('keyv');
 const axios = require('axios');
 const { isEnabled } = require('../utils');
+const { extractBaseURL } = require('../../utils');
 const keyvRedis = require('../../cache/keyvRedis');
 // const { getAzureCredentials, genAzureChatCompletion } = require('../../utils/');
 const { openAIApiKey, userProvidedOpenAI } = require('./EndpointService').config;
@@ -30,7 +31,7 @@ const fetchOpenAIModels = async (opts = { azure: false, plugins: false }, _model
   }

   if (reverseProxyUrl) {
-    basePath = reverseProxyUrl.match(/.*v1/)?.[0];
+    basePath = extractBaseURL(reverseProxyUrl);
   }

   const cachedModels = await modelsCache.get(basePath);

View file

@@ -0,0 +1,48 @@
+/**
+ * Extracts a valid OpenAI baseURL from a given string, matching "url/v1," also an added suffix,
+ * ending with "/openai" (to allow the Cloudflare, LiteLLM pattern).
+ *
+ * Examples:
+ * - `https://open.ai/v1/chat` -> `https://open.ai/v1`
+ * - `https://open.ai/v1/chat/completions` -> `https://open.ai/v1`
+ * - `https://open.ai/v1/ACCOUNT/GATEWAY/openai/completions` -> `https://open.ai/v1/ACCOUNT/GATEWAY/openai`
+ * - `https://open.ai/v1/hi/openai` -> `https://open.ai/v1/hi/openai`
+ *
+ * @param {string} url - The URL to be processed.
+ * @returns {string|null} The matched pattern or null if no match is found.
+ */
+function extractBaseURL(url) {
+  // First, let's make sure the URL contains '/v1'.
+  if (!url.includes('/v1')) {
+    return null;
+  }
+
+  // Find the index of '/v1' to use it as a reference point.
+  const v1Index = url.indexOf('/v1');
+
+  // Extract the part of the URL up to and including '/v1'.
+  let baseUrl = url.substring(0, v1Index + 3);
+
+  // Check if the URL has '/openai' immediately after '/v1'.
+  const openaiIndex = url.indexOf('/openai', v1Index + 3);
+
+  // If '/openai' is found right after '/v1', include it in the base URL.
+  if (openaiIndex === v1Index + 3) {
+    // Find the next slash or the end of the URL after '/openai'.
+    const nextSlashIndex = url.indexOf('/', openaiIndex + 7);
+    if (nextSlashIndex === -1) {
+      // If there is no next slash, the rest of the URL is the base URL.
+      baseUrl = url.substring(0, openaiIndex + 7);
+    } else {
+      // If there is a next slash, the base URL goes up to but not including the slash.
+      baseUrl = url.substring(0, nextSlashIndex);
+    }
+  } else if (openaiIndex > 0) {
+    // If '/openai' is present but not immediately after '/v1', we need to include the reverse proxy pattern.
+    baseUrl = url.substring(0, openaiIndex + 7);
+  }
+
+  return baseUrl;
+}
+
+module.exports = extractBaseURL; // Export the function for use in your test file.
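To see what this helper fixes compared with the `reverseProxyUrl.match(/.*v1/)?.[0]` pattern it replaces: the regex always cuts the URL at `/v1`, dropping the account/gateway segments that Cloudflare- or LiteLLM-style proxies require, and it yields `undefined` instead of `null` when no `/v1` segment exists (hence the updated expectation in the OpenAIClient test above). A quick comparison, reusing the Cloudflare example from the spec file below:

const extractBaseURL = require('./extractBaseURL');

const url = 'https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai/completions';

console.log(url.match(/.*v1/)?.[0]);
// 'https://gateway.ai.cloudflare.com/v1'  (old behavior: gateway path lost)

console.log(extractBaseURL(url));
// 'https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai'

console.log('https://example.com/completions'.match(/.*v1/)?.[0]); // undefined
console.log(extractBaseURL('https://example.com/completions')); // null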

View file

@ -0,0 +1,56 @@
const extractBaseURL = require('./extractBaseURL');
describe('extractBaseURL', () => {
test('should extract base URL up to /v1 for standard endpoints', () => {
const url = 'https://localhost:8080/v1/chat/completions';
expect(extractBaseURL(url)).toBe('https://localhost:8080/v1');
});
test('should include /openai in the extracted URL when present', () => {
const url = 'https://localhost:8080/v1/openai';
expect(extractBaseURL(url)).toBe('https://localhost:8080/v1/openai');
});
test('should stop at /openai and not include any additional paths', () => {
const url = 'https://fake.open.ai/v1/openai/you-are-cool';
expect(extractBaseURL(url)).toBe('https://fake.open.ai/v1/openai');
});
test('should return the correct base URL for official openai endpoints', () => {
const url = 'https://api.openai.com/v1/chat/completions';
expect(extractBaseURL(url)).toBe('https://api.openai.com/v1');
});
test('should handle URLs with reverse proxy pattern correctly', () => {
const url = 'https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai/completions';
expect(extractBaseURL(url)).toBe(
'https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai',
);
});
test('should return null if the URL does not match the expected pattern', () => {
const url = 'https://someotherdomain.com/notv1';
expect(extractBaseURL(url)).toBeNull();
});
// Test our JSDoc examples.
test('should extract base URL up to /v1 for open.ai standard endpoint', () => {
const url = 'https://open.ai/v1/chat';
expect(extractBaseURL(url)).toBe('https://open.ai/v1');
});
test('should extract base URL up to /v1 for open.ai standard endpoint with additional path', () => {
const url = 'https://open.ai/v1/chat/completions';
expect(extractBaseURL(url)).toBe('https://open.ai/v1');
});
test('should handle URLs with ACCOUNT/GATEWAY pattern followed by /openai', () => {
const url = 'https://open.ai/v1/ACCOUNT/GATEWAY/openai/completions';
expect(extractBaseURL(url)).toBe('https://open.ai/v1/ACCOUNT/GATEWAY/openai');
});
test('should include /openai in the extracted URL with additional segments', () => {
const url = 'https://open.ai/v1/hi/openai';
expect(extractBaseURL(url)).toBe('https://open.ai/v1/hi/openai');
});
});

View file

@ -1,9 +1,11 @@
const azureUtils = require('./azureUtils');
const tokenHelpers = require('./tokens'); const tokenHelpers = require('./tokens');
const azureUtils = require('./azureUtils');
const extractBaseURL = require('./extractBaseURL');
const findMessageContent = require('./findMessageContent'); const findMessageContent = require('./findMessageContent');
module.exports = { module.exports = {
...azureUtils, ...azureUtils,
...tokenHelpers, ...tokenHelpers,
extractBaseURL,
findMessageContent, findMessageContent,
}; };