🧠 feat: Cohere support as Custom Endpoint (#2328)

* chore: bump cohere-ai, fix firebase vulnerabilities by going down versions

* feat: cohere rates and context windows

* feat(createCoherePayload): transform openai payload for cohere compatibility

* feat: cohere backend support

* refactor(UnknownIcon): optimize icon render and add cohere

* docs: add cohere to Compatible AI Endpoints

* Update ai_endpoints.md
This commit is contained in:
Danny Avila 2024-04-05 15:19:41 -04:00 committed by GitHub
parent daa5f43ac6
commit cd7f3a51e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 1007 additions and 622 deletions

View file

@ -5,6 +5,7 @@ const {
EModelEndpoint,
resolveHeaders,
ImageDetailCost,
CohereConstants,
getResponseSender,
validateVisionModel,
mapModelToAzureConfig,
@ -16,7 +17,13 @@ const {
getModelMaxTokens,
genAzureChatCompletion,
} = require('~/utils');
const { truncateText, formatMessage, createContextHandlers, CUT_OFF_PROMPT } = require('./prompts');
const {
truncateText,
formatMessage,
createContextHandlers,
CUT_OFF_PROMPT,
titleInstruction,
} = require('./prompts');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { handleOpenAIErrors } = require('./tools/util');
const spendTokens = require('~/models/spendTokens');
@ -39,7 +46,10 @@ class OpenAIClient extends BaseClient {
super(apiKey, options);
this.ChatGPTClient = new ChatGPTClient();
this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this);
/** @type {getCompletion} */
this.getCompletion = this.ChatGPTClient.getCompletion.bind(this);
/** @type {cohereChatCompletion} */
this.cohereChatCompletion = this.ChatGPTClient.cohereChatCompletion.bind(this);
this.contextStrategy = options.contextStrategy
? options.contextStrategy.toLowerCase()
: 'discard';
@ -48,6 +58,9 @@ class OpenAIClient extends BaseClient {
this.azure = options.azure || false;
this.setOptions(options);
this.metadata = {};
/** @type {string | undefined} - The API Completions URL */
this.completionsUrl;
}
// TODO: PluginsClient calls this 3x, unneeded
@ -533,6 +546,7 @@ class OpenAIClient extends BaseClient {
return result;
}
/** @type {sendCompletion} */
async sendCompletion(payload, opts = {}) {
let reply = '';
let result = null;
@ -541,7 +555,7 @@ class OpenAIClient extends BaseClient {
const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null;
const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion || typeof Bun !== 'undefined');
if (typeof opts.onProgress === 'function' && useOldMethod) {
await this.getCompletion(
const completionResult = await this.getCompletion(
payload,
(progressMessage) => {
if (progressMessage === '[DONE]') {
@ -574,8 +588,13 @@ class OpenAIClient extends BaseClient {
opts.onProgress(token);
reply += token;
},
opts.onProgress,
opts.abortController || new AbortController(),
);
if (completionResult && typeof completionResult === 'string') {
reply = completionResult;
}
} else if (typeof opts.onProgress === 'function' || this.options.useChatCompletion) {
reply = await this.chatCompletion({
payload,
@ -586,9 +605,14 @@ class OpenAIClient extends BaseClient {
result = await this.getCompletion(
payload,
null,
opts.onProgress,
opts.abortController || new AbortController(),
);
if (result && typeof result === 'string') {
return result.trim();
}
logger.debug('[OpenAIClient] sendCompletion: result', result);
if (this.isChatCompletion) {
@ -760,8 +784,7 @@ class OpenAIClient extends BaseClient {
const instructionsPayload = [
{
role: 'system',
content: `Detect user language and write in the same language an extremely concise title for this conversation, which you must accurately detect.
Write in the detected language. Title in 5 Words or Less. No Punctuation or Quotation. Do not mention the language. All first letters of every word should be capitalized and write the title in User Language only.
content: `Please generate ${titleInstruction}
${convo}
@ -770,8 +793,12 @@ ${convo}
];
try {
let useChatCompletion = true;
if (CohereConstants.API_URL) {
useChatCompletion = false;
}
title = (
await this.sendPayload(instructionsPayload, { modelOptions, useChatCompletion: true })
await this.sendPayload(instructionsPayload, { modelOptions, useChatCompletion })
).replaceAll('"', '');
} catch (e) {
logger.error(