mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-01-01 08:08:49 +01:00
🧠 feat: Cohere support as Custom Endpoint (#2328)
* chore: bump cohere-ai, fix firebase vulnerabilities by going down versions * feat: cohere rates and context windows * feat(createCoherePayload): transform openai payload for cohere compatibility * feat: cohere backend support * refactor(UnknownIcon): optimize icon render and add cohere * docs: add cohere to Compatible AI Endpoints * Update ai_endpoints.md
This commit is contained in:
parent
daa5f43ac6
commit
cd7f3a51e1
18 changed files with 1007 additions and 622 deletions
|
|
@ -23,7 +23,7 @@ class BaseClient {
|
|||
throw new Error('Method \'setOptions\' must be implemented.');
|
||||
}
|
||||
|
||||
getCompletion() {
|
||||
async getCompletion() {
|
||||
throw new Error('Method \'getCompletion\' must be implemented.');
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,10 +3,13 @@ const crypto = require('crypto');
|
|||
const {
|
||||
EModelEndpoint,
|
||||
resolveHeaders,
|
||||
CohereConstants,
|
||||
mapModelToAzureConfig,
|
||||
} = require('librechat-data-provider');
|
||||
const { CohereClient } = require('cohere-ai');
|
||||
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
|
||||
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
|
||||
const { createCoherePayload } = require('./llm');
|
||||
const { Agent, ProxyAgent } = require('undici');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { logger } = require('~/config');
|
||||
|
|
@ -147,7 +150,8 @@ class ChatGPTClient extends BaseClient {
|
|||
return tokenizer;
|
||||
}
|
||||
|
||||
async getCompletion(input, onProgress, abortController = null) {
|
||||
/** @type {getCompletion} */
|
||||
async getCompletion(input, onProgress, onTokenProgress, abortController = null) {
|
||||
if (!abortController) {
|
||||
abortController = new AbortController();
|
||||
}
|
||||
|
|
@ -305,6 +309,11 @@ class ChatGPTClient extends BaseClient {
|
|||
});
|
||||
}
|
||||
|
||||
if (baseURL.startsWith(CohereConstants.API_URL)) {
|
||||
const payload = createCoherePayload({ modelOptions });
|
||||
return await this.cohereChatCompletion({ payload, onTokenProgress });
|
||||
}
|
||||
|
||||
if (baseURL.includes('v1') && !baseURL.includes('/completions') && !this.isChatCompletion) {
|
||||
baseURL = baseURL.split('v1')[0] + 'v1/completions';
|
||||
} else if (
|
||||
|
|
@ -408,6 +417,35 @@ class ChatGPTClient extends BaseClient {
|
|||
return response.json();
|
||||
}
|
||||
|
||||
/** @type {cohereChatCompletion} */
|
||||
async cohereChatCompletion({ payload, onTokenProgress }) {
|
||||
const cohere = new CohereClient({
|
||||
token: this.apiKey,
|
||||
environment: this.completionsUrl,
|
||||
});
|
||||
|
||||
if (!payload.stream) {
|
||||
const chatResponse = await cohere.chat(payload);
|
||||
return chatResponse.text;
|
||||
}
|
||||
|
||||
const chatStream = await cohere.chatStream(payload);
|
||||
let reply = '';
|
||||
for await (const message of chatStream) {
|
||||
if (!message) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (message.eventType === 'text-generation' && message.text) {
|
||||
onTokenProgress(message.text);
|
||||
} else if (message.eventType === 'stream-end' && message.response) {
|
||||
reply = message.response.text;
|
||||
}
|
||||
}
|
||||
|
||||
return reply;
|
||||
}
|
||||
|
||||
async generateTitle(userMessage, botMessage) {
|
||||
const instructionsPayload = {
|
||||
role: 'system',
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ const {
|
|||
EModelEndpoint,
|
||||
resolveHeaders,
|
||||
ImageDetailCost,
|
||||
CohereConstants,
|
||||
getResponseSender,
|
||||
validateVisionModel,
|
||||
mapModelToAzureConfig,
|
||||
|
|
@ -16,7 +17,13 @@ const {
|
|||
getModelMaxTokens,
|
||||
genAzureChatCompletion,
|
||||
} = require('~/utils');
|
||||
const { truncateText, formatMessage, createContextHandlers, CUT_OFF_PROMPT } = require('./prompts');
|
||||
const {
|
||||
truncateText,
|
||||
formatMessage,
|
||||
createContextHandlers,
|
||||
CUT_OFF_PROMPT,
|
||||
titleInstruction,
|
||||
} = require('./prompts');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { handleOpenAIErrors } = require('./tools/util');
|
||||
const spendTokens = require('~/models/spendTokens');
|
||||
|
|
@ -39,7 +46,10 @@ class OpenAIClient extends BaseClient {
|
|||
super(apiKey, options);
|
||||
this.ChatGPTClient = new ChatGPTClient();
|
||||
this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this);
|
||||
/** @type {getCompletion} */
|
||||
this.getCompletion = this.ChatGPTClient.getCompletion.bind(this);
|
||||
/** @type {cohereChatCompletion} */
|
||||
this.cohereChatCompletion = this.ChatGPTClient.cohereChatCompletion.bind(this);
|
||||
this.contextStrategy = options.contextStrategy
|
||||
? options.contextStrategy.toLowerCase()
|
||||
: 'discard';
|
||||
|
|
@ -48,6 +58,9 @@ class OpenAIClient extends BaseClient {
|
|||
this.azure = options.azure || false;
|
||||
this.setOptions(options);
|
||||
this.metadata = {};
|
||||
|
||||
/** @type {string | undefined} - The API Completions URL */
|
||||
this.completionsUrl;
|
||||
}
|
||||
|
||||
// TODO: PluginsClient calls this 3x, unneeded
|
||||
|
|
@ -533,6 +546,7 @@ class OpenAIClient extends BaseClient {
|
|||
return result;
|
||||
}
|
||||
|
||||
/** @type {sendCompletion} */
|
||||
async sendCompletion(payload, opts = {}) {
|
||||
let reply = '';
|
||||
let result = null;
|
||||
|
|
@ -541,7 +555,7 @@ class OpenAIClient extends BaseClient {
|
|||
const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null;
|
||||
const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion || typeof Bun !== 'undefined');
|
||||
if (typeof opts.onProgress === 'function' && useOldMethod) {
|
||||
await this.getCompletion(
|
||||
const completionResult = await this.getCompletion(
|
||||
payload,
|
||||
(progressMessage) => {
|
||||
if (progressMessage === '[DONE]') {
|
||||
|
|
@ -574,8 +588,13 @@ class OpenAIClient extends BaseClient {
|
|||
opts.onProgress(token);
|
||||
reply += token;
|
||||
},
|
||||
opts.onProgress,
|
||||
opts.abortController || new AbortController(),
|
||||
);
|
||||
|
||||
if (completionResult && typeof completionResult === 'string') {
|
||||
reply = completionResult;
|
||||
}
|
||||
} else if (typeof opts.onProgress === 'function' || this.options.useChatCompletion) {
|
||||
reply = await this.chatCompletion({
|
||||
payload,
|
||||
|
|
@ -586,9 +605,14 @@ class OpenAIClient extends BaseClient {
|
|||
result = await this.getCompletion(
|
||||
payload,
|
||||
null,
|
||||
opts.onProgress,
|
||||
opts.abortController || new AbortController(),
|
||||
);
|
||||
|
||||
if (result && typeof result === 'string') {
|
||||
return result.trim();
|
||||
}
|
||||
|
||||
logger.debug('[OpenAIClient] sendCompletion: result', result);
|
||||
|
||||
if (this.isChatCompletion) {
|
||||
|
|
@ -760,8 +784,7 @@ class OpenAIClient extends BaseClient {
|
|||
const instructionsPayload = [
|
||||
{
|
||||
role: 'system',
|
||||
content: `Detect user language and write in the same language an extremely concise title for this conversation, which you must accurately detect.
|
||||
Write in the detected language. Title in 5 Words or Less. No Punctuation or Quotation. Do not mention the language. All first letters of every word should be capitalized and write the title in User Language only.
|
||||
content: `Please generate ${titleInstruction}
|
||||
|
||||
${convo}
|
||||
|
||||
|
|
@ -770,8 +793,12 @@ ${convo}
|
|||
];
|
||||
|
||||
try {
|
||||
let useChatCompletion = true;
|
||||
if (CohereConstants.API_URL) {
|
||||
useChatCompletion = false;
|
||||
}
|
||||
title = (
|
||||
await this.sendPayload(instructionsPayload, { modelOptions, useChatCompletion: true })
|
||||
await this.sendPayload(instructionsPayload, { modelOptions, useChatCompletion })
|
||||
).replaceAll('"', '');
|
||||
} catch (e) {
|
||||
logger.error(
|
||||
|
|
|
|||
85
api/app/clients/llm/createCoherePayload.js
Normal file
85
api/app/clients/llm/createCoherePayload.js
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
const { CohereConstants } = require('librechat-data-provider');
|
||||
const { titleInstruction } = require('../prompts/titlePrompts');
|
||||
|
||||
// Mapping OpenAI roles to Cohere roles
|
||||
const roleMap = {
|
||||
user: CohereConstants.ROLE_USER,
|
||||
assistant: CohereConstants.ROLE_CHATBOT,
|
||||
system: CohereConstants.ROLE_SYSTEM, // Recognize and map the system role explicitly
|
||||
};
|
||||
|
||||
/**
|
||||
* Adjusts an OpenAI ChatCompletionPayload to conform with Cohere's expected chat payload format.
|
||||
* Now includes handling for "system" roles explicitly mentioned.
|
||||
*
|
||||
* @param {Object} options - Object containing the model options.
|
||||
* @param {ChatCompletionPayload} options.modelOptions - The OpenAI model payload options.
|
||||
* @returns {CohereChatStreamRequest} Cohere-compatible chat API payload.
|
||||
*/
|
||||
function createCoherePayload({ modelOptions }) {
|
||||
/** @type {string | undefined} */
|
||||
let preamble;
|
||||
let latestUserMessageContent = '';
|
||||
const {
|
||||
stream,
|
||||
stop,
|
||||
top_p,
|
||||
temperature,
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
max_tokens,
|
||||
messages,
|
||||
model,
|
||||
...rest
|
||||
} = modelOptions;
|
||||
|
||||
// Filter out the latest user message and transform remaining messages to Cohere's chat_history format
|
||||
let chatHistory = messages.reduce((acc, message, index, arr) => {
|
||||
const isLastUserMessage = index === arr.length - 1 && message.role === 'user';
|
||||
|
||||
const messageContent =
|
||||
typeof message.content === 'string'
|
||||
? message.content
|
||||
: message.content.map((part) => (part.type === 'text' ? part.text : '')).join(' ');
|
||||
|
||||
if (isLastUserMessage) {
|
||||
latestUserMessageContent = messageContent;
|
||||
} else {
|
||||
acc.push({
|
||||
role: roleMap[message.role] || CohereConstants.ROLE_USER,
|
||||
message: messageContent,
|
||||
});
|
||||
}
|
||||
|
||||
return acc;
|
||||
}, []);
|
||||
|
||||
if (
|
||||
chatHistory.length === 1 &&
|
||||
chatHistory[0].role === CohereConstants.ROLE_SYSTEM &&
|
||||
!latestUserMessageContent.length
|
||||
) {
|
||||
const message = chatHistory[0].message;
|
||||
latestUserMessageContent = message.includes(titleInstruction)
|
||||
? CohereConstants.TITLE_MESSAGE
|
||||
: '.';
|
||||
preamble = message;
|
||||
}
|
||||
|
||||
return {
|
||||
message: latestUserMessageContent,
|
||||
model: model,
|
||||
chat_history: chatHistory,
|
||||
stream: stream ?? false,
|
||||
temperature: temperature,
|
||||
frequency_penalty: frequency_penalty,
|
||||
presence_penalty: presence_penalty,
|
||||
max_tokens: max_tokens,
|
||||
stop_sequences: stop,
|
||||
preamble,
|
||||
p: top_p,
|
||||
...rest,
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = createCoherePayload;
|
||||
|
|
@ -1,7 +1,9 @@
|
|||
const createLLM = require('./createLLM');
|
||||
const RunManager = require('./RunManager');
|
||||
const createCoherePayload = require('./createCoherePayload');
|
||||
|
||||
module.exports = {
|
||||
createLLM,
|
||||
RunManager,
|
||||
createCoherePayload,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ ${convo}`,
|
|||
return titlePrompt;
|
||||
};
|
||||
|
||||
const titleInstruction =
|
||||
'a concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. For English, use AP Stylebook Title Case. Never directly mention the language name or the word "title"';
|
||||
const titleFunctionPrompt = `In this environment you have access to a set of tools you can use to generate the conversation title.
|
||||
|
||||
You may call them like this:
|
||||
|
|
@ -51,7 +53,7 @@ Submit a brief title in the conversation's language, following the parameter des
|
|||
<parameter>
|
||||
<name>title</name>
|
||||
<type>string</type>
|
||||
<description>A concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. For English, use AP Stylebook Title Case. Never directly mention the language name or the word "title"</description>
|
||||
<description>${titleInstruction}</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
</tool_description>
|
||||
|
|
@ -80,6 +82,7 @@ function parseTitleFromPrompt(prompt) {
|
|||
|
||||
module.exports = {
|
||||
langPrompt,
|
||||
titleInstruction,
|
||||
createTitlePrompt,
|
||||
titleFunctionPrompt,
|
||||
parseTitleFromPrompt,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ const defaultRate = 6;
|
|||
|
||||
/**
|
||||
* Mapping of model token sizes to their respective multipliers for prompt and completion.
|
||||
* The rates are 1 USD per 1M tokens.
|
||||
* @type {Object.<string, {prompt: number, completion: number}>}
|
||||
*/
|
||||
const tokenValues = {
|
||||
|
|
@ -19,6 +20,11 @@ const tokenValues = {
|
|||
'claude-2.1': { prompt: 8, completion: 24 },
|
||||
'claude-2': { prompt: 8, completion: 24 },
|
||||
'claude-': { prompt: 0.8, completion: 2.4 },
|
||||
'command-r-plus': { prompt: 3, completion: 15 },
|
||||
'command-r': { prompt: 0.5, completion: 1.5 },
|
||||
/* cohere doesn't have rates for the older command models,
|
||||
so this was from https://artificialanalysis.ai/models/command-light/providers */
|
||||
command: { prompt: 0.38, completion: 0.38 },
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@
|
|||
"axios": "^1.3.4",
|
||||
"bcryptjs": "^2.4.3",
|
||||
"cheerio": "^1.0.0-rc.12",
|
||||
"cohere-ai": "^6.0.0",
|
||||
"cohere-ai": "^7.9.1",
|
||||
"connect-redis": "^7.1.0",
|
||||
"cookie": "^0.5.0",
|
||||
"cors": "^2.8.5",
|
||||
|
|
@ -52,7 +52,7 @@
|
|||
"express-rate-limit": "^6.9.0",
|
||||
"express-session": "^1.17.3",
|
||||
"file-type": "^18.7.0",
|
||||
"firebase": "^10.8.0",
|
||||
"firebase": "^10.6.0",
|
||||
"googleapis": "^126.0.1",
|
||||
"handlebars": "^4.7.7",
|
||||
"html": "^1.0.0",
|
||||
|
|
|
|||
|
|
@ -44,6 +44,30 @@
|
|||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports ChatCompletionPayload
|
||||
* @typedef {import('openai').OpenAI.ChatCompletionCreateParams} ChatCompletionPayload
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports ChatCompletionMessages
|
||||
* @typedef {import('openai').OpenAI.ChatCompletionMessageParam} ChatCompletionMessages
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports CohereChatStreamRequest
|
||||
* @typedef {import('cohere-ai').Cohere.ChatStreamRequest} CohereChatStreamRequest
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports CohereChatRequest
|
||||
* @typedef {import('cohere-ai').Cohere.ChatRequest} CohereChatRequest
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports OpenAIRequestOptions
|
||||
* @typedef {import('openai').OpenAI.RequestOptions} OpenAIRequestOptions
|
||||
|
|
@ -1062,3 +1086,44 @@
|
|||
* @method handleMessageEvent Handles events related to messages within the run.
|
||||
* @method messageCompleted Handles the completion of a message processing.
|
||||
*/
|
||||
|
||||
/* Native app/client methods */
|
||||
|
||||
/**
|
||||
* Accumulates tokens and sends them to the client for processing.
|
||||
* @callback onTokenProgress
|
||||
* @param {string} token - The current token generated by the model.
|
||||
* @returns {Promise<void>}
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* Main entrypoint for API completion calls
|
||||
* @callback sendCompletion
|
||||
* @param {Array<ChatCompletionMessages> | string} payload - The messages or prompt to send to the model
|
||||
* @param {object} opts - Options for the completion
|
||||
* @param {onTokenProgress} opts.onProgress - Callback function to handle token progress
|
||||
* @param {AbortController} opts.abortController - AbortController instance
|
||||
* @returns {Promise<string>}
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* Legacy completion handler for OpenAI API.
|
||||
* @callback getCompletion
|
||||
* @param {Array<ChatCompletionMessages> | string} input - Array of messages or a single prompt string
|
||||
* @param {(event: object | string) => Promise<void>} onProgress - SSE progress handler
|
||||
* @param {onTokenProgress} onTokenProgress - Token progress handler
|
||||
* @param {AbortController} [abortController] - AbortController instance
|
||||
* @returns {Promise<Object | string>} - Completion response
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* Cohere Stream handling. Note: abortController is not supported here.
|
||||
* @callback cohereChatCompletion
|
||||
* @param {object} params
|
||||
* @param {CohereChatStreamRequest | CohereChatRequest} params.payload
|
||||
* @param {onTokenProgress} params.onTokenProgress
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
const { CohereConstants } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* Extracts a valid OpenAI baseURL from a given string, matching "url/v1," followed by an optional suffix.
|
||||
* The suffix can be one of several predefined values (e.g., 'openai', 'azure-openai', etc.),
|
||||
|
|
@ -19,6 +21,10 @@ function extractBaseURL(url) {
|
|||
return undefined;
|
||||
}
|
||||
|
||||
if (url.startsWith(CohereConstants.API_URL)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!url.includes('/v1')) {
|
||||
return url;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,6 +59,15 @@ const openAIModels = {
|
|||
'mistral-': 31990, // -10 from max
|
||||
};
|
||||
|
||||
const cohereModels = {
|
||||
'command-light': 4086, // -10 from max
|
||||
'command-light-nightly': 8182, // -10 from max
|
||||
command: 4086, // -10 from max
|
||||
'command-nightly': 8182, // -10 from max
|
||||
'command-r': 127500, // -500 from max
|
||||
'command-r-plus:': 127500, // -500 from max
|
||||
};
|
||||
|
||||
const googleModels = {
|
||||
/* Max I/O is combined so we subtract the amount from max response tokens for actual total */
|
||||
gemini: 32750, // -10 from max
|
||||
|
|
@ -83,11 +92,13 @@ const anthropicModels = {
|
|||
'claude-3-opus': 200000,
|
||||
};
|
||||
|
||||
const aggregateModels = { ...openAIModels, ...googleModels, ...anthropicModels, ...cohereModels };
|
||||
|
||||
// Order is important here: by model series and context size (gpt-4 then gpt-3, ascending)
|
||||
const maxTokensMap = {
|
||||
[EModelEndpoint.azureOpenAI]: openAIModels,
|
||||
[EModelEndpoint.openAI]: { ...openAIModels, ...googleModels, ...anthropicModels },
|
||||
[EModelEndpoint.custom]: { ...openAIModels, ...googleModels, ...anthropicModels },
|
||||
[EModelEndpoint.openAI]: aggregateModels,
|
||||
[EModelEndpoint.custom]: aggregateModels,
|
||||
[EModelEndpoint.google]: googleModels,
|
||||
[EModelEndpoint.anthropic]: anthropicModels,
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue