fix(OpenAIClient/PluginsClient): allow non-v1 reverse proxy, handle "v1/completions" reverse proxy (#1029)

* fix(OpenAIClient): handle completions requests through the reverse proxy; also allow forcing a `prompt` payload via env var

* fix(reverseProxyUrl): allow URLs without /v1/ but log a server warning, as they will not be compatible with Plugins

* fix(ModelService): handle reverse proxy URLs without /v1

* refactor: make changes cleaner

* ci(OpenAIClient): add tests for OPENROUTER_API_KEY, FORCE_PROMPT, and reverseProxyUrl handling in setOptions
Danny Avila authored 2023-10-08 16:57:25 -04:00, committed by GitHub
parent d61e44742d · commit 2dd545eaa4
5 changed files with 80 additions and 14 deletions

.env.example

@@ -117,6 +117,12 @@ DEBUG_OPENAI=false # Set to true to enable debug mode for the OpenAI endpoint
 # https://github.com/waylaidwanderer/node-chatgpt-api#using-a-reverse-proxy
 # OPENAI_REVERSE_PROXY=
 
+# (Advanced) Sometimes when using Local LLM APIs, you may need to force the API
+# to be called with a `prompt` payload instead of a `messages` payload, i.e. to mimic
+# a `/v1/completions` request instead of `/v1/chat/completions`.
+# This may be the case for LocalAI with some models. To do so, uncomment the following:
+# OPENAI_FORCE_PROMPT=true
+
 ##########################
 # OpenRouter (overrides OpenAI and Plugins Endpoints):
 ##########################
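For context on what this flag changes, here is an illustrative sketch of the two OpenAI-style payload shapes (hypothetical values, not code from this commit):

// Illustrative only: a chat model normally receives a `messages` array
// at /v1/chat/completions:
const chatPayload = {
  model: 'gpt-3.5-turbo',
  messages: [{ role: 'user', content: 'Hello!' }],
};

// With OPENAI_FORCE_PROMPT=true, the request carries a single `prompt` string,
// mimicking a /v1/completions request, which some LocalAI models expect:
const promptPayload = {
  model: 'gpt-3.5-turbo',
  prompt: 'User: Hello!\nAssistant:',
};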

OpenAIClient.js

@@ -4,6 +4,7 @@ const BaseClient = require('./BaseClient');
 const { getModelMaxTokens, genAzureChatCompletion } = require('../../utils');
 const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts');
 const spendTokens = require('../../models/spendTokens');
+const { isEnabled } = require('../../server/utils');
 const { createLLM, RunManager } = require('./llm');
 const { summaryBuffer } = require('./memory');
 const { runTitleChain } = require('./chains');
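The new `isEnabled` import parses boolean-like env values. A minimal sketch of that kind of helper, assuming it simply normalizes 'true' strings (the actual implementation lives in `../../server/utils`):

// Minimal sketch of an isEnabled-style helper (assumed behavior, not the actual source):
// env vars arrive as strings, so 'true' (in any casing) or a literal true counts as enabled.
const isEnabled = (value) =>
  value === true || (typeof value === 'string' && value.toLowerCase().trim() === 'true');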
@@ -71,20 +72,22 @@ class OpenAIClient extends BaseClient {
       };
     }
 
-    if (process.env.OPENROUTER_API_KEY) {
-      this.apiKey = process.env.OPENROUTER_API_KEY;
+    const { OPENROUTER_API_KEY, OPENAI_FORCE_PROMPT } = process.env ?? {};
+    if (OPENROUTER_API_KEY) {
+      this.apiKey = OPENROUTER_API_KEY;
       this.useOpenRouter = true;
     }
 
+    const { reverseProxyUrl: reverseProxy } = this.options;
+    this.FORCE_PROMPT =
+      isEnabled(OPENAI_FORCE_PROMPT) ||
+      (reverseProxy && reverseProxy.includes('completions') && !reverseProxy.includes('chat'));
+
     const { model } = this.modelOptions;
-    this.isChatCompletion =
-      this.useOpenRouter ||
-      this.options.reverseProxyUrl ||
-      this.options.localAI ||
-      model.includes('gpt-');
+    this.isChatCompletion = this.useOpenRouter || !!reverseProxy || model.includes('gpt-');
     this.isChatGptModel = this.isChatCompletion;
-    if (model.includes('text-davinci-003') || model.includes('instruct')) {
+    if (model.includes('text-davinci-003') || model.includes('instruct') || this.FORCE_PROMPT) {
       this.isChatCompletion = false;
      this.isChatGptModel = false;
     }
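A rough sketch of how the new FORCE_PROMPT heuristic classifies reverse proxy URLs (hypothetical URLs, same logic as the expression above):

// Same check as the FORCE_PROMPT expression: a URL mentioning 'completions'
// but not 'chat' is treated as a legacy completions endpoint.
const forcesPrompt = (url) => url.includes('completions') && !url.includes('chat');

forcesPrompt('https://localhost:8080/v1/chat/completions'); // false -> messages payload
forcesPrompt('https://localhost:8080/v1/completions');      // true  -> prompt payload
forcesPrompt('https://my-proxy.example.com/api');           // false -> messages payload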
@@ -128,9 +131,13 @@ class OpenAIClient extends BaseClient {
       this.modelOptions.stop = stopTokens;
     }
 
-    if (this.options.reverseProxyUrl) {
-      this.completionsUrl = this.options.reverseProxyUrl;
-      this.langchainProxy = this.options.reverseProxyUrl.match(/.*v1/)[0];
+    if (reverseProxy) {
+      this.completionsUrl = reverseProxy;
+      this.langchainProxy = reverseProxy.match(/.*v1/)?.[0];
+      !this.langchainProxy &&
+        console.warn(`The reverse proxy URL ${reverseProxy} is not valid for Plugins.
+The URL must follow OpenAI specs, for example: https://localhost:8080/v1/chat/completions
+If your reverse proxy is compatible with OpenAI specs in every other way, it may still work without Plugins enabled.`);
     } else if (isChatGptModel) {
       this.completionsUrl = 'https://api.openai.com/v1/chat/completions';
     } else {
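To illustrate what the greedy `/.*v1/` match yields for `langchainProxy` (hypothetical URLs):

// The greedy /.*v1/ match keeps everything up to and including 'v1';
// the optional chaining (?.) avoids a TypeError when there is no match.
'https://localhost:8080/v1/chat/completions'.match(/.*v1/)?.[0]; // 'https://localhost:8080/v1'
'https://example.com/completions'.match(/.*v1/)?.[0];            // undefined -> triggers the warning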
@@ -185,7 +192,7 @@
       this.encoding = model.includes('instruct') ? 'text-davinci-003' : model;
       tokenizer = this.constructor.getTokenizer(this.encoding, true);
     } catch {
-      tokenizer = this.constructor.getTokenizer(this.encoding, true);
+      tokenizer = this.constructor.getTokenizer('text-davinci-003', true);
     }
   }

PluginsClient.js

@@ -34,7 +34,11 @@ class PluginsClient extends OpenAIClient {
     this.isGpt3 = this.modelOptions?.model?.includes('gpt-3');
 
     if (this.options.reverseProxyUrl) {
-      this.langchainProxy = this.options.reverseProxyUrl.match(/.*v1/)[0];
+      this.langchainProxy = this.options.reverseProxyUrl.match(/.*v1/)?.[0];
+      !this.langchainProxy &&
+        console.warn(`The reverse proxy URL ${this.options.reverseProxyUrl} is not valid for Plugins.
+The URL must follow OpenAI specs, for example: https://localhost:8080/v1/chat/completions
+If your reverse proxy is compatible with OpenAI specs in every other way, it may still work without Plugins enabled.`);
     }
   }

OpenAIClient.test.js

@@ -1,3 +1,4 @@
+require('dotenv').config();
 const OpenAIClient = require('../OpenAIClient');
 
 jest.mock('meilisearch');
@@ -39,6 +40,54 @@ describe('OpenAIClient', () => {
       expect(client.modelOptions.model).toBe(model);
       expect(client.modelOptions.temperature).toBe(0.7);
     });
 
+    it('should set apiKey and useOpenRouter if OPENROUTER_API_KEY is present', () => {
+      process.env.OPENROUTER_API_KEY = 'openrouter-key';
+      client.setOptions({});
+      expect(client.apiKey).toBe('openrouter-key');
+      expect(client.useOpenRouter).toBe(true);
+      delete process.env.OPENROUTER_API_KEY; // Cleanup
+    });
+
+    it('should set FORCE_PROMPT based on OPENAI_FORCE_PROMPT or reverseProxyUrl', () => {
+      process.env.OPENAI_FORCE_PROMPT = 'true';
+      client.setOptions({});
+      expect(client.FORCE_PROMPT).toBe(true);
+      delete process.env.OPENAI_FORCE_PROMPT; // Cleanup
+      client.FORCE_PROMPT = undefined;
+
+      client.setOptions({ reverseProxyUrl: 'https://example.com/completions' });
+      expect(client.FORCE_PROMPT).toBe(true);
+      client.FORCE_PROMPT = undefined;
+
+      client.setOptions({ reverseProxyUrl: 'https://example.com/chat' });
+      expect(client.FORCE_PROMPT).toBe(false);
+    });
+
+    it('should set isChatCompletion based on useOpenRouter, reverseProxyUrl, or model', () => {
+      client.setOptions({ reverseProxyUrl: null });
+      // true by default since default model will be gpt-3.5-turbo
+      expect(client.isChatCompletion).toBe(true);
+      client.isChatCompletion = undefined;
+
+      // false because completions url will force prompt payload
+      client.setOptions({ reverseProxyUrl: 'https://example.com/completions' });
+      expect(client.isChatCompletion).toBe(false);
+      client.isChatCompletion = undefined;
+
+      client.setOptions({ modelOptions: { model: 'gpt-3.5-turbo' }, reverseProxyUrl: null });
+      expect(client.isChatCompletion).toBe(true);
+    });
+
+    it('should set completionsUrl and langchainProxy based on reverseProxyUrl', () => {
+      client.setOptions({ reverseProxyUrl: 'https://localhost:8080/v1/chat/completions' });
+      expect(client.completionsUrl).toBe('https://localhost:8080/v1/chat/completions');
+      expect(client.langchainProxy).toBe('https://localhost:8080/v1');
+
+      client.setOptions({ reverseProxyUrl: 'https://example.com/completions' });
+      expect(client.completionsUrl).toBe('https://example.com/completions');
+      expect(client.langchainProxy).toBeUndefined();
+    });
   });
 
   describe('selectTokenizer', () => {

ModelService.js

@@ -28,7 +28,7 @@ const fetchOpenAIModels = async (opts = { azure: false, plugins: false }, _model
   }
 
   if (reverseProxyUrl) {
-    basePath = reverseProxyUrl.match(/.*v1/)[0];
+    basePath = reverseProxyUrl.match(/.*v1/)?.[0];
   }
 
   const cachedModels = await modelsCache.get(basePath);
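The same one-character fix matters here because `String.prototype.match` returns `null` when nothing matches (illustrative, with a hypothetical non-v1 URL):

// Before: a reverse proxy URL without 'v1' crashed model fetching.
'https://example.com/api'.match(/.*v1/)[0];   // TypeError: Cannot read properties of null
// After: the expression yields undefined instead of throwing.
'https://example.com/api'.match(/.*v1/)?.[0]; // undefined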