🧠 feat: Cohere support as Custom Endpoint (#2328)

* chore: bump cohere-ai, fix firebase vulnerabilities by going down versions

* feat: cohere rates and context windows

* feat(createCoherePayload): transform openai payload for cohere compatibility

* feat: cohere backend support

* refactor(UnknownIcon): optimize icon render and add cohere

* docs: add cohere to Compatible AI Endpoints

* Update ai_endpoints.md
This commit is contained in:
Danny Avila 2024-04-05 15:19:41 -04:00 committed by GitHub
parent daa5f43ac6
commit cd7f3a51e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 1007 additions and 622 deletions

View file

@ -23,7 +23,7 @@ class BaseClient {
throw new Error('Method \'setOptions\' must be implemented.');
}
getCompletion() {
async getCompletion() {
throw new Error('Method \'getCompletion\' must be implemented.');
}

View file

@ -3,10 +3,13 @@ const crypto = require('crypto');
const {
EModelEndpoint,
resolveHeaders,
CohereConstants,
mapModelToAzureConfig,
} = require('librechat-data-provider');
const { CohereClient } = require('cohere-ai');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
const { createCoherePayload } = require('./llm');
const { Agent, ProxyAgent } = require('undici');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
@ -147,7 +150,8 @@ class ChatGPTClient extends BaseClient {
return tokenizer;
}
async getCompletion(input, onProgress, abortController = null) {
/** @type {getCompletion} */
async getCompletion(input, onProgress, onTokenProgress, abortController = null) {
if (!abortController) {
abortController = new AbortController();
}
@ -305,6 +309,11 @@ class ChatGPTClient extends BaseClient {
});
}
if (baseURL.startsWith(CohereConstants.API_URL)) {
const payload = createCoherePayload({ modelOptions });
return await this.cohereChatCompletion({ payload, onTokenProgress });
}
if (baseURL.includes('v1') && !baseURL.includes('/completions') && !this.isChatCompletion) {
baseURL = baseURL.split('v1')[0] + 'v1/completions';
} else if (
@ -408,6 +417,35 @@ class ChatGPTClient extends BaseClient {
return response.json();
}
/** @type {cohereChatCompletion} */
// Handles chat completions against the Cohere API via the official SDK,
// supporting both one-shot and streamed responses.
// Note: abortController is not wired through here — Cohere SDK streams are
// not cancellable from this call site.
async cohereChatCompletion({ payload, onTokenProgress }) {
const cohere = new CohereClient({
token: this.apiKey,
// `environment` is given the completions URL — presumably the Cohere API
// base URL for this client; TODO confirm against CohereClient's config docs.
environment: this.completionsUrl,
});
// Non-streaming: a single request/response round trip; return the text only.
if (!payload.stream) {
const chatResponse = await cohere.chat(payload);
return chatResponse.text;
}
// Streaming: forward each generated token to the caller as it arrives,
// then return the full reply captured from the final stream event.
const chatStream = await cohere.chatStream(payload);
let reply = '';
for await (const message of chatStream) {
if (!message) {
continue;
}
if (message.eventType === 'text-generation' && message.text) {
onTokenProgress(message.text);
} else if (message.eventType === 'stream-end' && message.response) {
// The stream-end event carries the complete response text.
reply = message.response.text;
}
}
return reply;
}
async generateTitle(userMessage, botMessage) {
const instructionsPayload = {
role: 'system',

View file

@ -5,6 +5,7 @@ const {
EModelEndpoint,
resolveHeaders,
ImageDetailCost,
CohereConstants,
getResponseSender,
validateVisionModel,
mapModelToAzureConfig,
@ -16,7 +17,13 @@ const {
getModelMaxTokens,
genAzureChatCompletion,
} = require('~/utils');
const { truncateText, formatMessage, createContextHandlers, CUT_OFF_PROMPT } = require('./prompts');
const {
truncateText,
formatMessage,
createContextHandlers,
CUT_OFF_PROMPT,
titleInstruction,
} = require('./prompts');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { handleOpenAIErrors } = require('./tools/util');
const spendTokens = require('~/models/spendTokens');
@ -39,7 +46,10 @@ class OpenAIClient extends BaseClient {
super(apiKey, options);
this.ChatGPTClient = new ChatGPTClient();
this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this);
/** @type {getCompletion} */
this.getCompletion = this.ChatGPTClient.getCompletion.bind(this);
/** @type {cohereChatCompletion} */
this.cohereChatCompletion = this.ChatGPTClient.cohereChatCompletion.bind(this);
this.contextStrategy = options.contextStrategy
? options.contextStrategy.toLowerCase()
: 'discard';
@ -48,6 +58,9 @@ class OpenAIClient extends BaseClient {
this.azure = options.azure || false;
this.setOptions(options);
this.metadata = {};
/** @type {string | undefined} - The API Completions URL */
this.completionsUrl;
}
// TODO: PluginsClient calls this 3x, unneeded
@ -533,6 +546,7 @@ class OpenAIClient extends BaseClient {
return result;
}
/** @type {sendCompletion} */
async sendCompletion(payload, opts = {}) {
let reply = '';
let result = null;
@ -541,7 +555,7 @@ class OpenAIClient extends BaseClient {
const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null;
const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion || typeof Bun !== 'undefined');
if (typeof opts.onProgress === 'function' && useOldMethod) {
await this.getCompletion(
const completionResult = await this.getCompletion(
payload,
(progressMessage) => {
if (progressMessage === '[DONE]') {
@ -574,8 +588,13 @@ class OpenAIClient extends BaseClient {
opts.onProgress(token);
reply += token;
},
opts.onProgress,
opts.abortController || new AbortController(),
);
if (completionResult && typeof completionResult === 'string') {
reply = completionResult;
}
} else if (typeof opts.onProgress === 'function' || this.options.useChatCompletion) {
reply = await this.chatCompletion({
payload,
@ -586,9 +605,14 @@ class OpenAIClient extends BaseClient {
result = await this.getCompletion(
payload,
null,
opts.onProgress,
opts.abortController || new AbortController(),
);
if (result && typeof result === 'string') {
return result.trim();
}
logger.debug('[OpenAIClient] sendCompletion: result', result);
if (this.isChatCompletion) {
@ -760,8 +784,7 @@ class OpenAIClient extends BaseClient {
const instructionsPayload = [
{
role: 'system',
content: `Detect user language and write in the same language an extremely concise title for this conversation, which you must accurately detect.
Write in the detected language. Title in 5 Words or Less. No Punctuation or Quotation. Do not mention the language. All first letters of every word should be capitalized and write the title in User Language only.
content: `Please generate ${titleInstruction}
${convo}
@ -770,8 +793,12 @@ ${convo}
];
try {
let useChatCompletion = true;
if (CohereConstants.API_URL) {
useChatCompletion = false;
}
title = (
await this.sendPayload(instructionsPayload, { modelOptions, useChatCompletion: true })
await this.sendPayload(instructionsPayload, { modelOptions, useChatCompletion })
).replaceAll('"', '');
} catch (e) {
logger.error(

View file

@ -0,0 +1,85 @@
const { CohereConstants } = require('librechat-data-provider');
const { titleInstruction } = require('../prompts/titlePrompts');
// Mapping OpenAI roles to Cohere roles
const roleMap = {
  user: CohereConstants.ROLE_USER,
  assistant: CohereConstants.ROLE_CHATBOT,
  system: CohereConstants.ROLE_SYSTEM, // Recognize and map the system role explicitly
};

/**
 * Adjusts an OpenAI ChatCompletionPayload to conform with Cohere's expected chat payload format.
 * Handles "system" roles explicitly and maps shared parameters to their Cohere
 * equivalents (`stop` -> `stop_sequences`, `top_p` -> `p`).
 *
 * @param {Object} options - Object containing the model options.
 * @param {ChatCompletionPayload} options.modelOptions - The OpenAI model payload options.
 * @returns {CohereChatStreamRequest} Cohere-compatible chat API payload.
 */
function createCoherePayload({ modelOptions }) {
  /** @type {string | undefined} */
  let preamble;
  let latestUserMessageContent = '';
  const {
    stream,
    stop,
    top_p,
    temperature,
    frequency_penalty,
    presence_penalty,
    max_tokens,
    messages,
    model,
    ...rest
  } = modelOptions;

  // Cohere takes the latest user message as `message` and the rest of the
  // conversation as `chat_history`; split the OpenAI messages accordingly.
  const chatHistory = messages.reduce((acc, message, index, arr) => {
    const isLastUserMessage = index === arr.length - 1 && message.role === 'user';

    // Content may be a plain string or an array of content parts; flatten
    // text parts into a single string (non-text parts are dropped).
    const messageContent =
      typeof message.content === 'string'
        ? message.content
        : message.content.map((part) => (part.type === 'text' ? part.text : '')).join(' ');

    if (isLastUserMessage) {
      latestUserMessageContent = messageContent;
    } else {
      acc.push({
        role: roleMap[message.role] || CohereConstants.ROLE_USER,
        message: messageContent,
      });
    }

    return acc;
  }, []);

  // Special case: a lone system message with no user message (e.g. the title
  // generation prompt) is promoted to the `preamble`, with a minimal
  // placeholder `message` since Cohere requires one.
  // NOTE(review): the system entry also remains in chat_history here — Cohere
  // appears to tolerate this, but confirm it is intentional.
  if (
    chatHistory.length === 1 &&
    chatHistory[0].role === CohereConstants.ROLE_SYSTEM &&
    !latestUserMessageContent.length
  ) {
    const message = chatHistory[0].message;
    latestUserMessageContent = message.includes(titleInstruction)
      ? CohereConstants.TITLE_MESSAGE
      : '.';
    preamble = message;
  }

  return {
    message: latestUserMessageContent,
    model,
    chat_history: chatHistory,
    stream: stream ?? false,
    temperature,
    frequency_penalty,
    presence_penalty,
    max_tokens,
    // OpenAI allows `stop` to be a string or string[]; Cohere's
    // `stop_sequences` expects string[], so wrap a bare string.
    stop_sequences: typeof stop === 'string' ? [stop] : stop,
    preamble,
    p: top_p,
    ...rest,
  };
}

module.exports = createCoherePayload;

View file

@ -1,7 +1,9 @@
// Barrel module re-exporting the LLM helper utilities.
const createLLM = require('./createLLM');
const RunManager = require('./RunManager');
// Transforms OpenAI-style model options into Cohere-compatible payloads.
const createCoherePayload = require('./createCoherePayload');
module.exports = {
createLLM,
RunManager,
createCoherePayload,
};

View file

@ -27,6 +27,8 @@ ${convo}`,
return titlePrompt;
};
// Shared instruction text reused by both the plain title prompt and the
// title function-call prompt (see titleFunctionPrompt) so the wording
// stays in sync between the two code paths.
const titleInstruction =
'a concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. For English, use AP Stylebook Title Case. Never directly mention the language name or the word "title"';
const titleFunctionPrompt = `In this environment you have access to a set of tools you can use to generate the conversation title.
You may call them like this:
@ -51,7 +53,7 @@ Submit a brief title in the conversation's language, following the parameter des
<parameter>
<name>title</name>
<type>string</type>
<description>A concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. For English, use AP Stylebook Title Case. Never directly mention the language name or the word "title"</description>
<description>${titleInstruction}</description>
</parameter>
</parameters>
</tool_description>
@ -80,6 +82,7 @@ function parseTitleFromPrompt(prompt) {
module.exports = {
langPrompt,
titleInstruction,
createTitlePrompt,
titleFunctionPrompt,
parseTitleFromPrompt,

View file

@ -3,6 +3,7 @@ const defaultRate = 6;
/**
* Mapping of model token sizes to their respective multipliers for prompt and completion.
* The rates are 1 USD per 1M tokens.
* @type {Object.<string, {prompt: number, completion: number}>}
*/
const tokenValues = {
@ -19,6 +20,11 @@ const tokenValues = {
'claude-2.1': { prompt: 8, completion: 24 },
'claude-2': { prompt: 8, completion: 24 },
'claude-': { prompt: 0.8, completion: 2.4 },
'command-r-plus': { prompt: 3, completion: 15 },
'command-r': { prompt: 0.5, completion: 1.5 },
/* cohere doesn't have rates for the older command models,
so this was from https://artificialanalysis.ai/models/command-light/providers */
command: { prompt: 0.38, completion: 0.38 },
};
/**

View file

@ -42,7 +42,7 @@
"axios": "^1.3.4",
"bcryptjs": "^2.4.3",
"cheerio": "^1.0.0-rc.12",
"cohere-ai": "^6.0.0",
"cohere-ai": "^7.9.1",
"connect-redis": "^7.1.0",
"cookie": "^0.5.0",
"cors": "^2.8.5",
@ -52,7 +52,7 @@
"express-rate-limit": "^6.9.0",
"express-session": "^1.17.3",
"file-type": "^18.7.0",
"firebase": "^10.8.0",
"firebase": "^10.6.0",
"googleapis": "^126.0.1",
"handlebars": "^4.7.7",
"html": "^1.0.0",

View file

@ -44,6 +44,30 @@
* @memberof typedefs
*/
/**
* @exports ChatCompletionPayload
* @typedef {import('openai').OpenAI.ChatCompletionCreateParams} ChatCompletionPayload
* @memberof typedefs
*/
/**
* @exports ChatCompletionMessages
* @typedef {import('openai').OpenAI.ChatCompletionMessageParam} ChatCompletionMessages
* @memberof typedefs
*/
/**
* @exports CohereChatStreamRequest
* @typedef {import('cohere-ai').Cohere.ChatStreamRequest} CohereChatStreamRequest
* @memberof typedefs
*/
/**
* @exports CohereChatRequest
* @typedef {import('cohere-ai').Cohere.ChatRequest} CohereChatRequest
* @memberof typedefs
*/
/**
* @exports OpenAIRequestOptions
* @typedef {import('openai').OpenAI.RequestOptions} OpenAIRequestOptions
@ -1062,3 +1086,44 @@
* @method handleMessageEvent Handles events related to messages within the run.
* @method messageCompleted Handles the completion of a message processing.
*/
/* Native app/client methods */
/**
* Accumulates tokens and sends them to the client for processing.
* @callback onTokenProgress
* @param {string} token - The current token generated by the model.
* @returns {Promise<void>}
* @memberof typedefs
*/
/**
* Main entrypoint for API completion calls
* @callback sendCompletion
* @param {Array<ChatCompletionMessages> | string} payload - The messages or prompt to send to the model
* @param {object} opts - Options for the completion
* @param {onTokenProgress} opts.onProgress - Callback function to handle token progress
* @param {AbortController} opts.abortController - AbortController instance
* @returns {Promise<string>}
* @memberof typedefs
*/
/**
* Legacy completion handler for OpenAI API.
* @callback getCompletion
* @param {Array<ChatCompletionMessages> | string} input - Array of messages or a single prompt string
* @param {(event: object | string) => Promise<void>} onProgress - SSE progress handler
* @param {onTokenProgress} onTokenProgress - Token progress handler
* @param {AbortController} [abortController] - AbortController instance
* @returns {Promise<Object | string>} - Completion response
* @memberof typedefs
*/
/**
* Cohere Stream handling. Note: abortController is not supported here.
* @callback cohereChatCompletion
* @param {object} params
* @param {CohereChatStreamRequest | CohereChatRequest} params.payload
* @param {onTokenProgress} params.onTokenProgress
* @memberof typedefs
*/

View file

@ -1,3 +1,5 @@
const { CohereConstants } = require('librechat-data-provider');
/**
* Extracts a valid OpenAI baseURL from a given string, matching "url/v1," followed by an optional suffix.
* The suffix can be one of several predefined values (e.g., 'openai', 'azure-openai', etc.),
@ -19,6 +21,10 @@ function extractBaseURL(url) {
return undefined;
}
if (url.startsWith(CohereConstants.API_URL)) {
return null;
}
if (!url.includes('/v1')) {
return url;
}

View file

@ -59,6 +59,15 @@ const openAIModels = {
'mistral-': 31990, // -10 from max
};
/**
 * Maximum context token counts for Cohere models, reduced slightly from each
 * model's true maximum to leave headroom (see inline notes).
 * @type {Object.<string, number>}
 */
const cohereModels = {
  'command-light': 4086, // -10 from max
  'command-light-nightly': 8182, // -10 from max
  command: 4086, // -10 from max
  'command-nightly': 8182, // -10 from max
  'command-r': 127500, // -500 from max
  // Fixed: key previously had a stray trailing ':' ('command-r-plus:'),
  // so lookups for 'command-r-plus' never matched this entry.
  'command-r-plus': 127500, // -500 from max
};
const googleModels = {
/* Max I/O is combined so we subtract the amount from max response tokens for actual total */
gemini: 32750, // -10 from max
@ -83,11 +92,13 @@ const anthropicModels = {
'claude-3-opus': 200000,
};
const aggregateModels = { ...openAIModels, ...googleModels, ...anthropicModels, ...cohereModels };
// Order is important here: by model series and context size (gpt-4 then gpt-3, ascending)
const maxTokensMap = {
[EModelEndpoint.azureOpenAI]: openAIModels,
[EModelEndpoint.openAI]: { ...openAIModels, ...googleModels, ...anthropicModels },
[EModelEndpoint.custom]: { ...openAIModels, ...googleModels, ...anthropicModels },
[EModelEndpoint.openAI]: aggregateModels,
[EModelEndpoint.custom]: aggregateModels,
[EModelEndpoint.google]: googleModels,
[EModelEndpoint.anthropic]: anthropicModels,
};

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

View file

@ -20,6 +20,13 @@ export type GenericSetter<T> = (value: T | ((currentValue: T) => T)) => void;
export type LastSelectedModels = Record<EModelEndpoint, string>;
/**
 * Rendering contexts an endpoint icon can appear in; used to vary icon
 * styling per location (e.g. landing screen vs. nav vs. message row).
 */
export enum IconContext {
landing = 'landing',
menuItem = 'menu-item',
nav = 'nav',
message = 'message',
}
export type NavLink = {
title: string;
label?: string;

View file

@ -1,5 +1,44 @@
import { EModelEndpoint, KnownEndpoints } from 'librechat-data-provider';
import { CustomMinimalIcon } from '~/components/svg';
import { IconContext } from '~/common';
// Static icon asset paths for known custom endpoints, keyed by endpoint name.
const knownEndpointAssets = {
[KnownEndpoints.mistral]: '/assets/mistral.png',
[KnownEndpoints.openrouter]: '/assets/openrouter.png',
[KnownEndpoints.groq]: '/assets/groq.png',
[KnownEndpoints.shuttleai]: '/assets/shuttleai.png',
[KnownEndpoints.anyscale]: '/assets/anyscale.png',
[KnownEndpoints.fireworks]: '/assets/fireworks.png',
[KnownEndpoints.ollama]: '/assets/ollama.png',
[KnownEndpoints.perplexity]: '/assets/perplexity.png',
[KnownEndpoints['together.ai']]: '/assets/together.png',
[KnownEndpoints.cohere]: '/assets/cohere.png',
};
// Per-endpoint, per-context className overrides (e.g. extra padding for the
// Cohere icon on the landing screen).
const knownEndpointClasses = {
[KnownEndpoints.cohere]: {
[IconContext.landing]: 'p-2',
},
};
/**
 * Resolves the className for a known endpoint's icon, preferring any
 * per-endpoint/per-context override, then the context default.
 */
const getKnownClass = ({
  currentEndpoint,
  context = '',
  className,
}: {
  currentEndpoint: string;
  context?: string;
  className: string;
}) => {
  // OpenRouter always keeps the caller-provided class, regardless of context.
  if (currentEndpoint === KnownEndpoints.openrouter) {
    return className;
  }

  const overrides = knownEndpointClasses[currentEndpoint];
  const override = overrides?.[context];
  if (override != null) {
    return override;
  }

  // Landing icons default to no class; everything else keeps the caller's.
  return context === IconContext.landing ? '' : className;
};
export default function UnknownIcon({
className = '',
@ -20,73 +59,23 @@ export default function UnknownIcon({
if (iconURL) {
return <img className={className} src={iconURL} alt={`${endpoint} Icon`} />;
} else if (currentEndpoint === KnownEndpoints.mistral) {
return (
<img
className={context === 'landing' ? '' : className}
src="/assets/mistral.png"
alt="Mistral AI Icon"
/>
);
} else if (currentEndpoint === KnownEndpoints.openrouter) {
return <img className={className} src="/assets/openrouter.png" alt="OpenRouter Icon" />;
} else if (currentEndpoint === KnownEndpoints.groq) {
return (
<img
className={context === 'landing' ? '' : className}
src="/assets/groq.png"
alt="Groq Cloud Icon"
/>
);
} else if (currentEndpoint === KnownEndpoints.shuttleai) {
return (
<img
className={context === 'landing' ? '' : className}
src="/assets/shuttleai.png"
alt="ShuttleAI Icon"
/>
);
} else if (currentEndpoint === KnownEndpoints.anyscale) {
return (
<img
className={context === 'landing' ? '' : className}
src="/assets/anyscale.png"
alt="Anyscale Icon"
/>
);
} else if (currentEndpoint === KnownEndpoints.fireworks) {
return (
<img
className={context === 'landing' ? '' : className}
src="/assets/fireworks.png"
alt="Fireworks Icon"
/>
);
} else if (currentEndpoint === KnownEndpoints.ollama) {
return (
<img
className={context === 'landing' ? '' : className}
src="/assets/ollama.png"
alt="Ollama Icon"
/>
);
} else if (currentEndpoint === KnownEndpoints.perplexity) {
return (
<img
className={context === 'landing' ? '' : className}
src="/assets/perplexity.png"
alt="Perplexity Icon"
/>
);
} else if (currentEndpoint === KnownEndpoints['together.ai']) {
return (
<img
className={context === 'landing' ? '' : className}
src="/assets/together.png"
alt="together.ai Icon"
/>
);
}
return <CustomMinimalIcon className={className} />;
const assetPath = knownEndpointAssets[currentEndpoint];
if (!assetPath) {
return <CustomMinimalIcon className={className} />;
}
return (
<img
className={getKnownClass({
currentEndpoint,
context: context,
className,
})}
src={assetPath}
alt={`${currentEndpoint} Icon`}
/>
);
}

View file

@ -14,6 +14,39 @@ In all of the examples, arbitrary environment variable names are defined but you
Some of the endpoints are marked as **Known,** which means they might have special handling and/or an icon already provided in the app for you.
## Cohere
> Cohere API key: [dashboard.cohere.com](https://dashboard.cohere.com/)
**Notes:**
- **Known:** icon provided.
- Experimental: Cohere does not follow the OpenAI spec; LibreChat uses a dedicated compatibility method for this endpoint, mapping the parameters the two APIs share.
- For a full list of Cohere-specific parameters, see the [Cohere API documentation](https://docs.cohere.com/reference/chat).
- Note: The following OpenAI parameters are recognized and mapped to their Cohere equivalents. Most are removed in the example config below to prefer Cohere's default settings:
- `stop`: mapped to `stop_sequences`
- `top_p`: mapped to `p`, different min/max values
- `frequency_penalty`: different min/max values
- `presence_penalty`: different min/max values
- `model`: shared, included by default.
- `stream`: shared, included by default.
- `max_tokens`: shared, not included by default.
```yaml
- name: "cohere"
apiKey: "${COHERE_API_KEY}"
baseURL: "https://api.cohere.ai/v1"
models:
default: ["command-r","command-r-plus","command-light","command-light-nightly","command","command-nightly"]
fetch: false
modelDisplayLabel: "cohere"
titleModel: "command"
dropParams: ["stop", "user", "frequency_penalty", "presence_penalty", "temperature", "top_p"]
```
![image](https://github.com/danny-avila/LibreChat/assets/110412045/03549e00-243c-4539-ac9a-0d782af7cd6c)
## Groq
> groq API key: [wow.groq.com](https://wow.groq.com/)

1170
package-lock.json generated

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
{
"name": "librechat-data-provider",
"version": "0.5.1",
"version": "0.5.2",
"description": "data services for librechat apps",
"main": "dist/index.js",
"module": "dist/index.es.js",

View file

@ -238,6 +238,7 @@ export enum KnownEndpoints {
ollama = 'ollama',
perplexity = 'perplexity',
'together.ai' = 'together.ai',
cohere = 'cohere',
}
export enum FetchTokenConfig {
@ -532,6 +533,32 @@ export enum Constants {
NO_PARENT = '00000000-0000-0000-0000-000000000000',
}
/**
* Enum for Cohere related constants
*/
export enum CohereConstants {
/**
 * Cohere API Endpoint; baseURLs starting with this value receive
 * Cohere-specific handling instead of the OpenAI-compatible path.
 */
API_URL = 'https://api.cohere.ai/v1',
/**
 * Role for "USER" messages in Cohere chat_history entries
 */
ROLE_USER = 'USER',
/**
 * Role for "SYSTEM" messages in Cohere chat_history entries
 */
ROLE_SYSTEM = 'SYSTEM',
/**
 * Role for "CHATBOT" (assistant) messages in Cohere chat_history entries
 */
ROLE_CHATBOT = 'CHATBOT',
/**
 * Placeholder user message sent when generating a title, since Cohere
 * requires a non-empty `message` field
 */
TITLE_MESSAGE = 'TITLE:',
}
export const defaultOrderQuery: {
order: 'desc';
limit: 100;