🔧 Fix: Resolve Anthropic Client Issues 🧠 (#1226)

* fix: correct preset title for Anthropic endpoint

* fix(Settings/Anthropic): show correct default value for LLM temperature

* fix(AnthropicClient): use `getModelMaxTokens` to get the correct LLM max context tokens; correctly set the default temperature to 1; use only 2 params for the class constructor; use `getResponseSender` to set the correct sender on the response message

* refactor(/api/ask|edit/anthropic): save messages to the database after the final response is sent to the client, and do not save the conversation from the route controller

* fix(initializeClient/anthropic): correctly pass client options (endpointOption) to class initialization

* feat(ModelService/Anthropic): add claude-1.2
Danny Avila 2023-11-26 14:44:57 -05:00 committed by GitHub
parent 4b289640f2
commit d7ef4590ea
9 changed files with 73 additions and 45 deletions
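
Taken together, the client changes below derive the context limit and the response sender from the model name instead of hardcoded values. A minimal sketch of the two lookups, using the same helpers the diffs import (the model name is only illustrative):

    const { getResponseSender, EModelEndpoint } = require('~/server/routes/endpoints/schemas');
    const { getModelMaxTokens } = require('~/utils');

    const model = 'claude-2.1';
    // per-model context window, falling back to 100k for unrecognized models
    const maxContextTokens = getModelMaxTokens(model) ?? 100000;
    // sender label derived from endpoint/model instead of the hardcoded 'Anthropic'
    const sender = getResponseSender({
      model,
      endpoint: EModelEndpoint.anthropic,
      modelLabel: undefined, // optional user-set label
    });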


@@ -1,7 +1,8 @@
-// const { Agent, ProxyAgent } = require('undici');
-const BaseClient = require('./BaseClient');
-const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const Anthropic = require('@anthropic-ai/sdk');
+const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
+const { getResponseSender, EModelEndpoint } = require('~/server/routes/endpoints/schemas');
+const { getModelMaxTokens } = require('~/utils');
+const BaseClient = require('./BaseClient');
 
 const HUMAN_PROMPT = '\n\nHuman:';
 const AI_PROMPT = '\n\nAssistant:';
@@ -9,13 +10,9 @@ const AI_PROMPT = '\n\nAssistant:';
 const tokenizersCache = {};
 
 class AnthropicClient extends BaseClient {
-  constructor(apiKey, options = {}, cacheOptions = {}, baseURL) {
-    super(apiKey, options, cacheOptions);
+  constructor(apiKey, options = {}) {
+    super(apiKey, options);
     this.apiKey = apiKey || process.env.ANTHROPIC_API_KEY;
-    this.sender = 'Anthropic';
-    if (baseURL) {
-      this.baseURL = baseURL;
-    }
     this.userLabel = HUMAN_PROMPT;
     this.assistantLabel = AI_PROMPT;
     this.setOptions(options);
@@ -43,13 +40,13 @@ class AnthropicClient extends BaseClient {
       ...modelOptions,
       // set some good defaults (check for undefined in some cases because they may be 0)
       model: modelOptions.model || 'claude-1',
-      temperature: typeof modelOptions.temperature === 'undefined' ? 0.7 : modelOptions.temperature, // 0 - 1, 0.7 is recommended
+      temperature: typeof modelOptions.temperature === 'undefined' ? 1 : modelOptions.temperature, // 0 - 1, 1 is default
       topP: typeof modelOptions.topP === 'undefined' ? 0.7 : modelOptions.topP, // 0 - 1, default: 0.7
       topK: typeof modelOptions.topK === 'undefined' ? 40 : modelOptions.topK, // 1-40, default: 40
       stop: modelOptions.stop, // no stop method for now
     };
 
-    this.maxContextTokens = this.options.maxContextTokens || 99999;
+    this.maxContextTokens = getModelMaxTokens(this.modelOptions.model) ?? 100000;
     this.maxResponseTokens = this.modelOptions.maxOutputTokens || 1500;
     this.maxPromptTokens =
       this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
@@ -62,6 +59,14 @@ class AnthropicClient extends BaseClient {
       );
     }
 
+    this.sender =
+      this.options.sender ??
+      getResponseSender({
+        model: this.modelOptions.model,
+        endpoint: EModelEndpoint.anthropic,
+        modelLabel: this.options.modelLabel,
+      });
+
     this.startToken = '||>';
     this.endToken = '';
     this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
@@ -81,16 +86,15 @@ class AnthropicClient extends BaseClient {
   }
 
   getClient() {
-    if (this.baseURL) {
-      return new Anthropic({
-        apiKey: this.apiKey,
-        baseURL: this.baseURL,
-      });
-    } else {
-      return new Anthropic({
-        apiKey: this.apiKey,
-      });
-    }
+    const options = {
+      apiKey: this.apiKey,
+    };
+
+    if (this.options.reverseProxyUrl) {
+      options.baseURL = this.options.reverseProxyUrl;
+    }
+
+    return new Anthropic(options);
   }
 
   async buildMessages(messages, parentMessageId) {
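
With `getClient()` reworked as above, a reverse proxy now flows through the options object rather than a dedicated constructor argument. A hedged usage sketch (the proxy URL is hypothetical):

    const { AnthropicClient } = require('~/app');

    const client = new AnthropicClient(process.env.ANTHROPIC_API_KEY, {
      reverseProxyUrl: 'https://my-anthropic-proxy.example.com', // hypothetical URL
    });
    // getClient() passes reverseProxyUrl through as the Anthropic SDK's baseURL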


@@ -9,9 +9,9 @@ const {
   setHeaders,
   validateEndpoint,
   buildEndpointOption,
-} = require('../../middleware');
-const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
-const { sendMessage, createOnProgress } = require('../../utils');
+} = require('~/server/middleware');
+const { saveMessage, getConvoTitle, getConvo } = require('~/models');
+const { sendMessage, createOnProgress } = require('~/server/utils');
 
 router.post('/abort', handleAbort());
@@ -109,14 +109,6 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
       response.parentMessageId = overrideParentMessageId;
     }
 
-    await saveConvo(user, {
-      ...endpointOption,
-      ...endpointOption.modelOptions,
-      conversationId,
-      endpoint: 'anthropic',
-    });
-
-    await saveMessage({ ...response, user });
     sendMessage(res, {
       title: await getConvoTitle(user, conversationId),
       final: true,
@@ -126,6 +118,9 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
     });
     res.end();
 
+    await saveMessage({ ...response, user });
+    await saveMessage(userMessage);
+
     // TODO: add anthropic titling
   } catch (error) {
     const partialText = getPartialText();
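
The net effect of the two route hunks above: the final event is flushed to the client first and persistence happens afterwards, so a slow database write no longer delays the response, and the route no longer saves the conversation itself. The same reordering is applied to the edit route in the next diff. In outline (field list abbreviated):

    sendMessage(res, {
      title: await getConvoTitle(user, conversationId),
      final: true,
      // ...other final-event fields
    });
    res.end();

    // message writes moved after the final event
    await saveMessage({ ...response, user });
    await saveMessage(userMessage);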


@@ -9,9 +9,9 @@ const {
   setHeaders,
   validateEndpoint,
   buildEndpointOption,
-} = require('../../middleware');
-const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
-const { sendMessage, createOnProgress } = require('../../utils');
+} = require('~/server/middleware');
+const { saveMessage, getConvoTitle, getConvo } = require('~/models');
+const { sendMessage, createOnProgress } = require('~/server/utils');
 
 router.post('/abort', handleAbort());
@@ -119,7 +119,6 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
       response.parentMessageId = overrideParentMessageId;
     }
 
-    await saveMessage({ ...response, user });
     sendMessage(res, {
       title: await getConvoTitle(user, conversationId),
       final: true,
@@ -129,6 +128,9 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
     });
     res.end();
 
+    await saveMessage({ ...response, user });
+    await saveMessage(userMessage);
+
     // TODO: add anthropic titling
   } catch (error) {
     const partialText = getPartialText();


@@ -1,14 +1,14 @@
-const { AnthropicClient } = require('../../../../app');
-const { getUserKey, checkUserKeyExpiry } = require('../../../services/UserService');
+const { AnthropicClient } = require('~/app');
+const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
 
-const initializeClient = async ({ req, res }) => {
-  const ANTHROPIC_API_KEY = process.env.ANTHROPIC_API_KEY;
+const initializeClient = async ({ req, res, endpointOption }) => {
+  const { ANTHROPIC_API_KEY, ANTHROPIC_REVERSE_PROXY } = process.env;
   const expiresAt = req.body.key;
   const isUserProvided = ANTHROPIC_API_KEY === 'user_provided';
 
-  let anthropicApiKey = isUserProvided ? await getAnthropicUserKey(req.user.id) : ANTHROPIC_API_KEY;
-  let reverseProxy = process.env.ANTHROPIC_REVERSE_PROXY || undefined;
-  console.log('ANTHROPIC_REVERSE_PROXY', reverseProxy);
+  const anthropicApiKey = isUserProvided
+    ? await getAnthropicUserKey(req.user.id)
+    : ANTHROPIC_API_KEY;
 
   if (expiresAt && isUserProvided) {
     checkUserKeyExpiry(
@@ -17,7 +17,12 @@ const initializeClient = async ({ req, res }) => {
     );
   }
 
-  const client = new AnthropicClient(anthropicApiKey, { req, res }, {}, reverseProxy);
+  const client = new AnthropicClient(anthropicApiKey, {
+    req,
+    res,
+    reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
+    ...endpointOption,
+  });
 
   return {
     client,
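
The constructor call now matches the two-parameter signature introduced in AnthropicClient: the reverse proxy and the request's `endpointOption` travel in one options object. For illustration, an assumed (hypothetical) `endpointOption` shape as produced by the `buildEndpointOption` middleware, with `anthropicApiKey`, `req`, and `res` coming from the surrounding handler:

    // hypothetical endpointOption shape, for illustration only
    const endpointOption = {
      modelOptions: { model: 'claude-2.1', temperature: 1 },
      modelLabel: 'Claude',
    };

    const { ANTHROPIC_REVERSE_PROXY } = process.env;
    const client = new AnthropicClient(anthropicApiKey, {
      req,
      res,
      reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
      ...endpointOption, // spreads modelOptions, modelLabel, etc. into the client options
    });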


@@ -117,6 +117,7 @@ const getAnthropicModels = () => {
   let models = [
     'claude-2.1',
     'claude-2',
+    'claude-1.2',
     'claude-1',
     'claude-1-100k',
     'claude-instant-1',


@@ -51,6 +51,8 @@ const maxTokensMap = {
   'gpt-3.5-turbo-16k-0613': 15999,
   'gpt-3.5-turbo-1106': 16380, // -5 from max
   'gpt-4-1106': 127995, // -5 from max
+  'claude-2.1': 200000,
+  'claude-': 100000,
 };
 
 /**
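
The bare `'claude-'` key acts as a prefix fallback: any Claude model without an exact entry resolves to 100,000 tokens, while `claude-2.1` gets its own 200k entry. The test added in the next diff exercises exactly this behavior:

    const { getModelMaxTokens } = require('~/utils');

    getModelMaxTokens('claude-2.1');       // 200000 — exact entry
    getModelMaxTokens('claude-instant-1'); // 100000 — falls back to the 'claude-' prefix entry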


@@ -62,6 +62,25 @@ describe('getModelMaxTokens', () => {
     expect(getModelMaxTokens('gpt-4-1106-preview')).toBe(maxTokensMap['gpt-4-1106']);
     expect(getModelMaxTokens('gpt-4-1106-vision-preview')).toBe(maxTokensMap['gpt-4-1106']);
   });
 
+  test('should return correct tokens for Anthropic models', () => {
+    const models = [
+      'claude-2.1',
+      'claude-2',
+      'claude-1.2',
+      'claude-1',
+      'claude-1-100k',
+      'claude-instant-1',
+      'claude-instant-1-100k',
+    ];
+
+    const claude21MaxTokens = maxTokensMap['claude-2.1'];
+    const claudeMaxTokens = maxTokensMap['claude-'];
+    models.forEach((model) => {
+      const expectedTokens = model === 'claude-2.1' ? claude21MaxTokens : claudeMaxTokens;
+      expect(getModelMaxTokens(model)).toEqual(expectedTokens);
+    });
+  });
 });
 
 describe('matchModelName', () => {


@@ -86,7 +86,7 @@ export default function Settings({ conversation, setOption, models, readonly }:
         <div className="flex justify-between">
           <Label htmlFor="temp-int" className="text-left text-sm font-medium">
             {localize('com_endpoint_temperature')}{' '}
-            <small className="opacity-40">({localize('com_endpoint_default')}: 0.2)</small>
+            <small className="opacity-40">({localize('com_endpoint_default')}: 1)</small>
           </Label>
           <InputNumber
             id="temp-int"


@@ -24,7 +24,7 @@ export const getPresetTitle = (preset: TPreset) => {
     if (model) {
       _title += `: ${model}`;
     }
-  } else if (endpoint === EModelEndpoint.google) {
+  } else if (endpoint === EModelEndpoint.google || endpoint === EModelEndpoint.anthropic) {
     if (modelLabel) {
       _title = modelLabel;
     }
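
For Anthropic presets this mirrors the existing Google behavior: a custom `modelLabel`, when set, becomes the preset title. A hedged sketch of the outcome (the preset object is abbreviated and assumed):

    // assumed minimal preset, for illustration only
    getPresetTitle({ endpoint: EModelEndpoint.anthropic, modelLabel: 'My Claude' });
    // -> 'My Claude'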