🧠 feat: Cohere support as Custom Endpoint (#2328)

* chore: bump cohere-ai, fix firebase vulnerabilities by going down versions

* feat: cohere rates and context windows

* feat(createCoherePayload): transform openai payload for cohere compatibility

* feat: cohere backend support

* refactor(UnknownIcon): optimize icon render and add cohere

* docs: add cohere to Compatible AI Endpoints

* Update ai_endpoints.md
This commit is contained in:
Danny Avila 2024-04-05 15:19:41 -04:00 committed by GitHub
parent daa5f43ac6
commit cd7f3a51e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 1007 additions and 622 deletions

View file

@ -3,10 +3,13 @@ const crypto = require('crypto');
const {
EModelEndpoint,
resolveHeaders,
CohereConstants,
mapModelToAzureConfig,
} = require('librechat-data-provider');
const { CohereClient } = require('cohere-ai');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
const { createCoherePayload } = require('./llm');
const { Agent, ProxyAgent } = require('undici');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
@ -147,7 +150,8 @@ class ChatGPTClient extends BaseClient {
return tokenizer;
}
async getCompletion(input, onProgress, abortController = null) {
/** @type {getCompletion} */
async getCompletion(input, onProgress, onTokenProgress, abortController = null) {
if (!abortController) {
abortController = new AbortController();
}
@ -305,6 +309,11 @@ class ChatGPTClient extends BaseClient {
});
}
if (baseURL.startsWith(CohereConstants.API_URL)) {
const payload = createCoherePayload({ modelOptions });
return await this.cohereChatCompletion({ payload, onTokenProgress });
}
if (baseURL.includes('v1') && !baseURL.includes('/completions') && !this.isChatCompletion) {
baseURL = baseURL.split('v1')[0] + 'v1/completions';
} else if (
@ -408,6 +417,35 @@ class ChatGPTClient extends BaseClient {
return response.json();
}
/** @type {cohereChatCompletion} */
async cohereChatCompletion({ payload, onTokenProgress }) {
const cohere = new CohereClient({
token: this.apiKey,
environment: this.completionsUrl,
});
if (!payload.stream) {
const chatResponse = await cohere.chat(payload);
return chatResponse.text;
}
const chatStream = await cohere.chatStream(payload);
let reply = '';
for await (const message of chatStream) {
if (!message) {
continue;
}
if (message.eventType === 'text-generation' && message.text) {
onTokenProgress(message.text);
} else if (message.eventType === 'stream-end' && message.response) {
reply = message.response.text;
}
}
return reply;
}
async generateTitle(userMessage, botMessage) {
const instructionsPayload = {
role: 'system',