diff --git a/.env.example b/.env.example
index 4a25380e1..fb722909b 100644
--- a/.env.example
+++ b/.env.example
@@ -105,10 +105,10 @@ DEBUG_OPENAI=false
 # OPENROUTER_API_KEY=
 
 #============#
-# PaLM       #
+# Google     #
 #============#
 
-PALM_KEY=user_provided
+GOOGLE_KEY=user_provided
 # GOOGLE_REVERSE_PROXY=
 
 #============#
diff --git a/.github/workflows/playwright.yml b/.github/playwright.yml
similarity index 100%
rename from .github/workflows/playwright.yml
rename to .github/playwright.yml
diff --git a/README.md b/README.md
index ea4ea390a..422452fca 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@
   - 🌎 Multilingual UI:
     - English, 中文, Deutsch, Español, Français, Italiano, Polski, Português Brasileiro, Русский
     - 日本語, Svenska, 한국어, Tiếng Việt, 繁體中文, العربية, Türkçe, Nederlands
-  - 🤖 AI model selection: OpenAI API, Azure, BingAI, ChatGPT Browser, PaLM2, Anthropic (Claude), Plugins
+  - 🤖 AI model selection: OpenAI API, Azure, BingAI, ChatGPT, Google Vertex AI, Anthropic (Claude), Plugins
   - 💾 Create, Save, & Share Custom Presets
   - 🔄 Edit, Resubmit, and Continue messages with conversation branching
   - 📤 Export conversations as screenshots, markdown, text, json.
diff --git a/api/app/clients/AnthropicClient.js b/api/app/clients/AnthropicClient.js
index 2d124c96d..595653f17 100644
--- a/api/app/clients/AnthropicClient.js
+++ b/api/app/clients/AnthropicClient.js
@@ -1,6 +1,6 @@
 const Anthropic = require('@anthropic-ai/sdk');
 const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
-const { getResponseSender, EModelEndpoint } = require('~/server/routes/endpoints/schemas');
+const { getResponseSender, EModelEndpoint } = require('~/server/services/Endpoints');
 const { getModelMaxTokens } = require('~/utils');
 const BaseClient = require('./BaseClient');
 
@@ -46,7 +46,8 @@ class AnthropicClient extends BaseClient {
       stop: modelOptions.stop, // no stop method for now
     };
 
-    this.maxContextTokens = getModelMaxTokens(this.modelOptions.model) ?? 100000;
+    this.maxContextTokens =
+      getModelMaxTokens(this.modelOptions.model, EModelEndpoint.anthropic) ?? 100000;
     this.maxResponseTokens = this.modelOptions.maxOutputTokens || 1500;
     this.maxPromptTokens =
       this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js
index 32aef7bf1..626a98888 100644
--- a/api/app/clients/BaseClient.js
+++ b/api/app/clients/BaseClient.js
@@ -445,6 +445,7 @@ class BaseClient {
           amount: promptTokens,
           debug: this.options.debug,
           model: this.modelOptions.model,
+          endpoint: this.options.endpoint,
         },
       });
     }
diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js
index d3f77d069..4b4a02064 100644
--- a/api/app/clients/GoogleClient.js
+++ b/api/app/clients/GoogleClient.js
@@ -1,23 +1,43 @@
-const BaseClient = require('./BaseClient');
 const { google } = require('googleapis');
 const { Agent, ProxyAgent } = require('undici');
+const { GoogleVertexAI } = require('langchain/llms/googlevertexai');
+const { ChatGoogleVertexAI } = require('langchain/chat_models/googlevertexai');
+const { AIMessage, HumanMessage, SystemMessage } = require('langchain/schema');
 const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
+const {
+  getResponseSender,
+  EModelEndpoint,
+  endpointSettings,
+} = require('~/server/services/Endpoints');
+const { getModelMaxTokens } = require('~/utils');
+const { formatMessage } = require('./prompts');
+const BaseClient = require('./BaseClient');
+
+const loc = 'us-central1';
+const publisher = 'google';
+const endpointPrefix = `https://${loc}-aiplatform.googleapis.com`;
+// const apiEndpoint = loc + '-aiplatform.googleapis.com';
 
 const tokenizersCache = {};
 
+const settings = endpointSettings[EModelEndpoint.google];
+
 class GoogleClient extends BaseClient {
   constructor(credentials, options = {}) {
     super('apiKey', options);
+    this.credentials = credentials;
     this.client_email = credentials.client_email;
     this.project_id = credentials.project_id;
     this.private_key = credentials.private_key;
-    this.sender = 'PaLM2';
+    this.access_token = null;
+    if (options.skipSetOptions) {
+      return;
+    }
     this.setOptions(options);
   }
 
-  /* Google/PaLM2 specific methods */
+  /* Google specific methods */
   constructUrl() {
-    return `https://us-central1-aiplatform.googleapis.com/v1/projects/${this.project_id}/locations/us-central1/publishers/google/models/${this.modelOptions.model}:predict`;
+    return `${endpointPrefix}/v1/projects/${this.project_id}/locations/${loc}/publishers/${publisher}/models/${this.modelOptions.model}:serverStreamingPredict`;
   }
 
   async getClient() {
@@ -35,6 +55,24 @@ class GoogleClient extends BaseClient {
     return jwtClient;
   }
 
+  async getAccessToken() {
+    const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
+    const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);
+
+    return new Promise((resolve, reject) => {
+      jwtClient.authorize((err, tokens) => {
+        if (err) {
+          console.error('Error: jwtClient failed to authorize');
+          console.error(err.message);
+          reject(err);
+        } else {
+          console.log('Access Token:', tokens.access_token);
+          resolve(tokens.access_token);
+        }
+      });
+    });
+  }
+
   /* Required Client methods */
   setOptions(options) {
     if (this.options && !this.options.replaceOptions) {
@@ -53,30 +91,33 @@ class GoogleClient extends BaseClient {
       this.options = options;
     }
 
-    this.options.examples = this.options.examples.filter(
-      (obj) => obj.input.content !== '' && obj.output.content !== '',
-    );
+    this.options.examples = this.options.examples
+      .filter((ex) => ex)
+      .filter((obj) => obj.input.content !== '' && obj.output.content !== '');
 
     const modelOptions = this.options.modelOptions || {};
     this.modelOptions = {
       ...modelOptions,
       // set some good defaults (check for undefined in some cases because they may be 0)
-      model: modelOptions.model || 'chat-bison',
-      temperature: typeof modelOptions.temperature === 'undefined' ? 0.2 : modelOptions.temperature, // 0 - 1, 0.2 is recommended
-      topP: typeof modelOptions.topP === 'undefined' ? 0.95 : modelOptions.topP, // 0 - 1, default: 0.95
-      topK: typeof modelOptions.topK === 'undefined' ? 40 : modelOptions.topK, // 1-40, default: 40
+      model: modelOptions.model || settings.model.default,
+      temperature:
+        typeof modelOptions.temperature === 'undefined'
+          ? settings.temperature.default
+          : modelOptions.temperature,
+      topP: typeof modelOptions.topP === 'undefined' ? settings.topP.default : modelOptions.topP,
+      topK: typeof modelOptions.topK === 'undefined' ? settings.topK.default : modelOptions.topK,
       // stop: modelOptions.stop // no stop method for now
     };
 
-    this.isChatModel = this.modelOptions.model.startsWith('chat-');
+    this.isChatModel = this.modelOptions.model.includes('chat');
     const { isChatModel } = this;
-    this.isTextModel = this.modelOptions.model.startsWith('text-');
+    this.isTextModel = !isChatModel && /code|text/.test(this.modelOptions.model);
     const { isTextModel } = this;
 
-    this.maxContextTokens = this.options.maxContextTokens || (isTextModel ? 8000 : 4096);
+    this.maxContextTokens = getModelMaxTokens(this.modelOptions.model, EModelEndpoint.google);
     // The max prompt tokens is determined by the max context tokens minus the max response tokens.
     // Earlier messages will be dropped until the prompt is within the limit.
-    this.maxResponseTokens = this.modelOptions.maxOutputTokens || 1024;
+    this.maxResponseTokens = this.modelOptions.maxOutputTokens || settings.maxOutputTokens.default;
     this.maxPromptTokens =
       this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
 
@@ -88,6 +129,14 @@ class GoogleClient extends BaseClient {
       );
     }
 
+    this.sender =
+      this.options.sender ??
+      getResponseSender({
+        model: this.modelOptions.model,
+        endpoint: EModelEndpoint.google,
+        modelLabel: this.options.modelLabel,
+      });
+
     this.userLabel = this.options.userLabel || 'User';
     this.modelLabel = this.options.modelLabel || 'Assistant';
 
@@ -99,8 +148,8 @@ class GoogleClient extends BaseClient {
       this.endToken = '';
       this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
     } else if (isTextModel) {
-      this.startToken = '<|im_start|>';
-      this.endToken = '<|im_end|>';
+      this.startToken = '||>';
+      this.endToken = '';
       this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
         '<|im_start|>': 100264,
         '<|im_end|>': 100265,
@@ -138,15 +187,18 @@ class GoogleClient extends BaseClient {
     return this;
   }
 
-  getMessageMapMethod() {
+  formatMessages() {
     return ((message) => ({
       author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
       content: message?.content ?? message.text,
     })).bind(this);
   }
 
-  buildMessages(messages = []) {
-    const formattedMessages = messages.map(this.getMessageMapMethod());
+  buildMessages(messages = [], parentMessageId) {
+    if (this.isTextModel) {
+      return this.buildMessagesPrompt(messages, parentMessageId);
+    }
+    const formattedMessages = messages.map(this.formatMessages());
     let payload = {
       instances: [
         {
@@ -164,15 +216,6 @@ class GoogleClient extends BaseClient {
       payload.instances[0].examples = this.options.examples;
     }
 
-    /* TO-DO: text model needs more context since it can't process an array of messages */
-    if (this.isTextModel) {
-      payload.instances = [
-        {
-          prompt: messages[messages.length - 1].content,
-        },
-      ];
-    }
-
     if (this.options.debug) {
       console.debug('GoogleClient buildMessages');
      console.dir(payload, { depth: null });
@@ -181,7 +224,157 @@ class GoogleClient extends BaseClient {
     return { prompt: payload };
   }
 
-  async getCompletion(payload, abortController = null) {
+  async buildMessagesPrompt(messages, parentMessageId) {
+    const orderedMessages = this.constructor.getMessagesForConversation({
+      messages,
+      parentMessageId,
+    });
+    if (this.options.debug) {
+      console.debug('GoogleClient: orderedMessages', orderedMessages, parentMessageId);
+    }
+
+    const formattedMessages = orderedMessages.map((message) => ({
+      author: message.isCreatedByUser ? this.userLabel : this.modelLabel,
+      content: message?.content ?? message.text,
+    }));
+
+    let lastAuthor = '';
+    let groupedMessages = [];
+
+    for (let message of formattedMessages) {
+      // If last author is not same as current author, add to new group
+      if (lastAuthor !== message.author) {
+        groupedMessages.push({
+          author: message.author,
+          content: [message.content],
+        });
+        lastAuthor = message.author;
+        // If same author, append content to the last group
+      } else {
+        groupedMessages[groupedMessages.length - 1].content.push(message.content);
+      }
+    }
+
+    let identityPrefix = '';
+    if (this.options.userLabel) {
+      identityPrefix = `\nHuman's name: ${this.options.userLabel}`;
+    }
+
+    if (this.options.modelLabel) {
+      identityPrefix = `${identityPrefix}\nYou are ${this.options.modelLabel}`;
+    }
+
+    let promptPrefix = (this.options.promptPrefix || '').trim();
+    if (promptPrefix) {
+      // If the prompt prefix doesn't end with the end token, add it.
+      if (!promptPrefix.endsWith(`${this.endToken}`)) {
+        promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
+      }
+      promptPrefix = `\nContext:\n${promptPrefix}`;
+    }
+
+    if (identityPrefix) {
+      promptPrefix = `${identityPrefix}${promptPrefix}`;
+    }
+
+    // Prompt AI to respond, empty if last message was from AI
+    let isEdited = lastAuthor === this.modelLabel;
+    const promptSuffix = isEdited ? '' : `${promptPrefix}\n\n${this.modelLabel}:\n`;
+    let currentTokenCount = isEdited
+      ? this.getTokenCount(promptPrefix)
+      : this.getTokenCount(promptSuffix);
+
+    let promptBody = '';
+    const maxTokenCount = this.maxPromptTokens;
+
+    const context = [];
+
+    // Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
+    // Do this within a recursive async function so that it doesn't block the event loop for too long.
+    // Also, remove the next message when the message that puts us over the token limit is created by the user.
+    // Otherwise, remove only the exceeding message. This is due to Anthropic's strict payload rule to start with "Human:".
+    const nextMessage = {
+      remove: false,
+      tokenCount: 0,
+      messageString: '',
+    };
+
+    const buildPromptBody = async () => {
+      if (currentTokenCount < maxTokenCount && groupedMessages.length > 0) {
+        const message = groupedMessages.pop();
+        const isCreatedByUser = message.author === this.userLabel;
+        // Use promptPrefix if the message is an edited assistant message
+        const messagePrefix =
+          isCreatedByUser || !isEdited
+            ? `\n\n${message.author}:`
+            : `${promptPrefix}\n\n${message.author}:`;
+        const messageString = `${messagePrefix}\n${message.content}${this.endToken}\n`;
+        let newPromptBody = `${messageString}${promptBody}`;
+
+        context.unshift(message);
+
+        const tokenCountForMessage = this.getTokenCount(messageString);
+        const newTokenCount = currentTokenCount + tokenCountForMessage;
+
+        if (!isCreatedByUser) {
+          nextMessage.messageString = messageString;
+          nextMessage.tokenCount = tokenCountForMessage;
+        }
+
+        if (newTokenCount > maxTokenCount) {
+          if (!promptBody) {
+            // This is the first message, so we can't add it. Just throw an error.
+            throw new Error(
+              `Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
+            );
+          }
+
+          // Otherwise, this message would put us over the token limit, so don't add it.
+          // If created by user, remove the next message; otherwise remove only this message.
+          if (isCreatedByUser) {
+            nextMessage.remove = true;
+          }
+
+          return false;
+        }
+        promptBody = newPromptBody;
+        currentTokenCount = newTokenCount;
+
+        // Switch off isEdited after using it for the first time
+        if (isEdited) {
+          isEdited = false;
+        }
+
+        // wait for next tick to avoid blocking the event loop
+        await new Promise((resolve) => setImmediate(resolve));
+        return buildPromptBody();
+      }
+      return true;
+    };
+
+    await buildPromptBody();
+
+    if (nextMessage.remove) {
+      promptBody = promptBody.replace(nextMessage.messageString, '');
+      currentTokenCount -= nextMessage.tokenCount;
+      context.shift();
+    }
+
+    let prompt = `${promptBody}${promptSuffix}`;
+
+    // Add 2 tokens for metadata after all messages have been counted.
+    currentTokenCount += 2;
+
+    // Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxResponseTokens` tokens for the response.
+    this.modelOptions.maxOutputTokens = Math.min(
+      this.maxContextTokens - currentTokenCount,
+      this.maxResponseTokens,
+    );
+
+    return { prompt, context };
+  }
+
+  async _getCompletion(payload, abortController = null) {
     if (!abortController) {
       abortController = new AbortController();
     }
@@ -212,6 +405,72 @@ class GoogleClient extends BaseClient {
     return res.data;
   }
 
+  async getCompletion(_payload, options = {}) {
+    const { onProgress, abortController } = options;
+    const { parameters, instances } = _payload;
+    const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {};
+
+    let examples;
+
+    let clientOptions = {
+      authOptions: {
+        credentials: {
+          ...this.credentials,
+        },
+        projectId: this.project_id,
+      },
+      ...parameters,
+    };
+
+    if (!parameters) {
+      clientOptions = { ...clientOptions, ...this.modelOptions };
+    }
+
+    if (_examples && _examples.length) {
+      examples = _examples
+        .map((ex) => {
+          const { input, output } = ex;
+          if (!input || !output) {
+            return undefined;
+          }
+          return {
+            input: new HumanMessage(input.content),
+            output: new AIMessage(output.content),
+          };
+        })
+        .filter((ex) => ex);
+
+      clientOptions.examples = examples;
+    }
+
+    const model = this.isTextModel
+      ? new GoogleVertexAI(clientOptions)
+      : new ChatGoogleVertexAI(clientOptions);
+
+    let reply = '';
+    const messages = this.isTextModel
+      ? _payload.trim()
+      : _messages
+        .map((msg) => ({ ...msg, role: msg.author === 'User' ? 'user' : 'assistant' }))
+        .map((message) => formatMessage({ message, langChain: true }));
+
+    if (context && messages?.length > 0) {
+      messages.unshift(new SystemMessage(context));
+    }
+
+    const stream = await model.stream(messages, {
+      signal: abortController.signal,
+      timeout: 7000,
+    });
+
+    for await (const chunk of stream) {
+      await this.generateTextStream(chunk?.content ?? chunk, onProgress, { delay: 7 });
+      reply += chunk?.content ?? chunk;
+    }
+
+    return reply;
+  }
+
   getSaveOptions() {
     return {
       promptPrefix: this.options.promptPrefix,
@@ -225,34 +484,18 @@ class GoogleClient extends BaseClient {
   }
 
   async sendCompletion(payload, opts = {}) {
-    console.log('GoogleClient: sendcompletion', payload, opts);
     let reply = '';
-    let blocked = false;
     try {
-      const result = await this.getCompletion(payload, opts.abortController);
-      blocked = result?.predictions?.[0]?.safetyAttributes?.blocked;
-      reply =
-        result?.predictions?.[0]?.candidates?.[0]?.content ||
-        result?.predictions?.[0]?.content ||
-        '';
-      if (blocked === true) {
-        reply = `Google blocked a proper response to your message:\n${JSON.stringify(
-          result.predictions[0].safetyAttributes,
-        )}${reply.length > 0 ? `\nAI Response:\n${reply}` : ''}`;
-      }
+      reply = await this.getCompletion(payload, opts);
       if (this.options.debug) {
        console.debug('result');
        console.debug(reply);
      }
    } catch (err) {
      console.error('Error: failed to send completion to Google');
+      console.error(err);
      console.error(err.message);
    }
-
-    if (!blocked) {
-      await this.generateTextStream(reply, opts.onProgress, { delay: 0.5 });
-    }
-
    return reply.trim();
  }
diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js
index cb0abdf99..0c7f508f8 100644
--- a/api/app/clients/OpenAIClient.js
+++ b/api/app/clients/OpenAIClient.js
@@ -1,10 +1,10 @@
 const OpenAI = require('openai');
 const { HttpsProxyAgent } = require('https-proxy-agent');
 const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
+const { getResponseSender, EModelEndpoint } = require('~/server/services/Endpoints');
 const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images');
 const { getModelMaxTokens, genAzureChatCompletion, extractBaseURL } = require('~/utils');
 const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts');
-const { getResponseSender, EModelEndpoint } = require('~/server/routes/endpoints/schemas');
 const { handleOpenAIErrors } = require('./tools/util');
 const spendTokens = require('~/models/spendTokens');
 const { createLLM, RunManager } = require('./llm');
diff --git a/api/app/clients/PluginsClient.js b/api/app/clients/PluginsClient.js
index 4b4919e33..058c22212 100644
--- a/api/app/clients/PluginsClient.js
+++ b/api/app/clients/PluginsClient.js
@@ -3,11 +3,12 @@ const { CallbackManager } = require('langchain/callbacks');
 const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
 const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
 const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
-const checkBalance = require('../../models/checkBalance');
+const { EModelEndpoint } = require('~/server/services/Endpoints');
 const { formatLangChainMessages } = require('./prompts');
-const { isEnabled } = require('../../server/utils');
-const { extractBaseURL } = require('../../utils');
+const checkBalance = require('~/models/checkBalance');
 const { SelfReflectionTool } = require('./tools');
+const { isEnabled } = require('~/server/utils');
+const { extractBaseURL } = require('~/utils');
 const { loadTools } = require('./tools/util');
 
 class PluginsClient extends OpenAIClient {
@@ -304,6 +305,7 @@ class PluginsClient extends OpenAIClient {
         amount: promptTokens,
         debug: this.options.debug,
         model: this.modelOptions.model,
+        endpoint: EModelEndpoint.openAI,
       },
     });
   }
diff --git a/api/app/clients/callbacks/createStartHandler.js b/api/app/clients/callbacks/createStartHandler.js
index e7137abfc..caf351a65 100644
--- a/api/app/clients/callbacks/createStartHandler.js
+++ b/api/app/clients/callbacks/createStartHandler.js
@@ -1,7 +1,8 @@
 const { promptTokensEstimate } = require('openai-chat-tokens');
-const checkBalance = require('../../../models/checkBalance');
-const { isEnabled } = require('../../../server/utils');
-const { formatFromLangChain } = require('../prompts');
+const { EModelEndpoint } = require('~/server/services/Endpoints');
+const { formatFromLangChain } = require('~/app/clients/prompts');
+const checkBalance = require('~/models/checkBalance');
+const { isEnabled } = require('~/server/utils');
 
 const createStartHandler = ({
   context,
@@ -55,6 +56,7 @@ const createStartHandler = ({
           debug: manager.debug,
           generations,
           model,
+          endpoint: EModelEndpoint.openAI,
         },
       });
     }
diff --git a/api/app/clients/prompts/formatGoogleInputs.js b/api/app/clients/prompts/formatGoogleInputs.js
new file mode 100644
index 000000000..c929df8b5
--- /dev/null
+++ b/api/app/clients/prompts/formatGoogleInputs.js
@@ -0,0 +1,42 @@
+/**
+ * Formats an object to match the struct_val, list_val, string_val, float_val, and int_val format.
+ *
+ * @param {Object} obj - The object to be formatted.
+ * @returns {Object} The formatted object.
+ *
+ * Handles different types:
+ * - Arrays are wrapped in list_val and each element is processed.
+ * - Objects are recursively processed.
+ * - Strings are wrapped in string_val.
+ * - Numbers are wrapped in float_val or int_val depending on whether they are floating-point or integers.
+ */
+function formatGoogleInputs(obj) {
+  const formattedObj = {};
+
+  for (const key in obj) {
+    if (Object.prototype.hasOwnProperty.call(obj, key)) {
+      const value = obj[key];
+
+      // Handle arrays
+      if (Array.isArray(value)) {
+        formattedObj[key] = { list_val: value.map((item) => formatGoogleInputs(item)) };
+      }
+      // Handle objects
+      else if (typeof value === 'object' && value !== null) {
+        formattedObj[key] = formatGoogleInputs(value);
+      }
+      // Handle numbers
+      else if (typeof value === 'number') {
+        formattedObj[key] = Number.isInteger(value) ? { int_val: value } : { float_val: value };
+      }
+      // Handle other types (e.g., strings)
+      else {
+        formattedObj[key] = { string_val: [value] };
+      }
+    }
+  }
+
+  return { struct_val: formattedObj };
+}
+
+module.exports = formatGoogleInputs;
diff --git a/api/app/clients/prompts/formatGoogleInputs.spec.js b/api/app/clients/prompts/formatGoogleInputs.spec.js
new file mode 100644
index 000000000..8fef9dfb5
--- /dev/null
+++ b/api/app/clients/prompts/formatGoogleInputs.spec.js
@@ -0,0 +1,274 @@
+const formatGoogleInputs = require('./formatGoogleInputs');
+
+describe('formatGoogleInputs', () => {
+  it('formats message correctly', () => {
+    const input = {
+      messages: [
+        {
+          content: 'hi',
+          author: 'user',
+        },
+      ],
+      context: 'context',
+      examples: [
+        {
+          input: {
+            author: 'user',
+            content: 'user input',
+          },
+          output: {
+            author: 'bot',
+            content: 'bot output',
+          },
+        },
+      ],
+      parameters: {
+        temperature: 0.2,
+        topP: 0.8,
+        topK: 40,
+        maxOutputTokens: 1024,
+      },
+    };
+
+    const expectedOutput = {
+      struct_val: {
+        messages: {
+          list_val: [
+            {
+              struct_val: {
+                content: {
+                  string_val: ['hi'],
+                },
+                author: {
+                  string_val: ['user'],
+                },
+              },
+            },
+          ],
+        },
+        context: {
+          string_val: ['context'],
+        },
+        examples: {
+          list_val: [
+            {
+              struct_val: {
+                input: {
+                  struct_val: {
+                    author: {
+                      string_val: ['user'],
+                    },
+                    content: {
+                      string_val: ['user input'],
+                    },
+                  },
+                },
+                output: {
+                  struct_val: {
+                    author: {
+                      string_val: ['bot'],
+                    },
+                    content: {
+                      string_val: ['bot output'],
+                    },
+                  },
+                },
+              },
+            },
+          ],
+        },
+        parameters: {
+          struct_val: {
+            temperature: {
+              float_val: 0.2,
+            },
+            topP: {
+              float_val: 0.8,
+            },
+            topK: {
+              int_val: 40,
+            },
+            maxOutputTokens: {
+              int_val: 1024,
+            },
+          },
+        },
+      },
+    };
+
+    const result = formatGoogleInputs(input);
+    expect(JSON.stringify(result)).toEqual(JSON.stringify(expectedOutput));
+  });
+
+  it('formats real payload parts', () => {
+    const input = {
+      instances: [
+        {
+          context: 'context',
+          examples: [
+            {
+              input: {
+                author: 'user',
+                content: 'user input',
+              },
+              output: {
+                author: 'bot',
+                content: 'user output',
+              },
+            },
+          ],
+          messages: [
+            {
+              author: 'user',
+              content: 'hi',
+            },
+          ],
+        },
+      ],
+      parameters: {
+        candidateCount: 1,
+        maxOutputTokens: 1024,
+        temperature: 0.2,
+        topP: 0.8,
+        topK: 40,
+      },
+    };
+    const expectedOutput = {
+      struct_val: {
+        instances: {
+          list_val: [
+            {
+              struct_val: {
+                context: { string_val: ['context'] },
+                examples: {
+                  list_val: [
+                    {
+                      struct_val: {
+                        input: {
+                          struct_val: {
+                            author: { string_val: ['user'] },
+                            content: { string_val: ['user input'] },
+                          },
+                        },
+                        output: {
+                          struct_val: {
+                            author: { string_val: ['bot'] },
+                            content: { string_val: ['user output'] },
+                          },
+                        },
+                      },
+                    },
+                  ],
+                },
+                messages: {
+                  list_val: [
+                    {
+                      struct_val: {
+                        author: { string_val: ['user'] },
+                        content: { string_val: ['hi'] },
+                      },
+                    },
+                  ],
+                },
+              },
+            },
+          ],
+        },
+        parameters: {
+          struct_val: {
+            candidateCount: { int_val: 1 },
+            maxOutputTokens: { int_val: 1024 },
+            temperature: { float_val: 0.2 },
+            topP: { float_val: 0.8 },
+            topK: { int_val: 40 },
+          },
+        },
+      },
+    };
+
+    const result = formatGoogleInputs(input);
+    expect(JSON.stringify(result)).toEqual(JSON.stringify(expectedOutput));
+  });
+
+  it('helps create valid payload parts', () => {
+    const instances = {
+      context: 'context',
+      examples: [
+        {
+          input: {
+            author: 'user',
+            content: 'user input',
+          },
+          output: {
+            author: 'bot',
+            content: 'user output',
+          },
+        },
+      ],
+      messages: [
+        {
+          author: 'user',
+          content: 'hi',
+        },
+      ],
+    };
+
+    const expectedInstances = {
+      struct_val: {
+        context: { string_val: ['context'] },
+        examples: {
+          list_val: [
+            {
+              struct_val: {
+                input: {
+                  struct_val: {
+                    author: { string_val: ['user'] },
+                    content: { string_val: ['user input'] },
+                  },
+                },
+                output: {
+                  struct_val: {
+                    author: { string_val: ['bot'] },
+                    content: { string_val: ['user output'] },
+                  },
+                },
+              },
+            },
+          ],
+        },
+        messages: {
+          list_val: [
+            {
+              struct_val: {
+                author: { string_val: ['user'] },
+                content: { string_val: ['hi'] },
+              },
+            },
+          ],
+        },
+      },
+    };
+
+    const parameters = {
+      candidateCount: 1,
+      maxOutputTokens: 1024,
+      temperature: 0.2,
+      topP: 0.8,
+      topK: 40,
+    };
+    const expectedParameters = {
+      struct_val: {
+        candidateCount: { int_val: 1 },
+        maxOutputTokens: { int_val: 1024 },
+        temperature: { float_val: 0.2 },
+        topP: { float_val: 0.8 },
+        topK: { int_val: 40 },
+      },
+    };
+
+    const instancesResult = formatGoogleInputs(instances);
+    const parametersResult = formatGoogleInputs(parameters);
+    expect(JSON.stringify(instancesResult)).toEqual(JSON.stringify(expectedInstances));
+    expect(JSON.stringify(parametersResult)).toEqual(JSON.stringify(expectedParameters));
+  });
+});
diff --git a/api/models/Balance.js b/api/models/Balance.js
index 3d94aa013..f0e6d73d1 100644
--- a/api/models/Balance.js
+++ b/api/models/Balance.js
@@ -2,8 +2,16 @@ const mongoose = require('mongoose');
 const balanceSchema = require('./schema/balance');
 const { getMultiplier } = require('./tx');
 
-balanceSchema.statics.check = async function ({ user, model, valueKey, tokenType, amount, debug }) {
-  const multiplier = getMultiplier({ valueKey, tokenType, model });
+balanceSchema.statics.check = async function ({
+  user,
+  model,
+  endpoint,
+  valueKey,
+  tokenType,
+  amount,
+  debug,
+}) {
+  const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint });
   const tokenCost = amount * multiplier;
   const { tokenCredits: balance } = (await this.findOne({ user }, 'tokenCredits').lean()) ?? {};
 
@@ -11,6 +19,7 @@ balanceSchema.statics.check = async function ({ user, model, valueKey, tokenType
     console.log('balance check', {
       user,
       model,
+      endpoint,
       valueKey,
       tokenType,
       amount,
diff --git a/api/models/tx.js b/api/models/tx.js
index 339c1e340..f6f3b7f55 100644
--- a/api/models/tx.js
+++ b/api/models/tx.js
@@ -18,10 +18,11 @@ const tokenValues = {
 * Retrieves the key associated with a given model name.
 *
 * @param {string} model - The model name to match.
+ * @param {string} endpoint - The endpoint name to match.
 * @returns {string|undefined} The key corresponding to the model name, or undefined if no match is found.
 */
-const getValueKey = (model) => {
-  const modelName = matchModelName(model);
+const getValueKey = (model, endpoint) => {
+  const modelName = matchModelName(model, endpoint);
   if (!modelName) {
     return undefined;
   }
@@ -51,9 +52,10 @@ const getValueKey = (model) => {
 * @param {string} [params.valueKey] - The key corresponding to the model name.
 * @param {string} [params.tokenType] - The type of token (e.g., 'prompt' or 'completion').
 * @param {string} [params.model] - The model name to derive the value key from if not provided.
+ * @param {string} [params.endpoint] - The endpoint name to derive the value key from if not provided.
 * @returns {number} The multiplier for the given parameters, or a default value if not found.
 */
-const getMultiplier = ({ valueKey, tokenType, model }) => {
+const getMultiplier = ({ valueKey, tokenType, model, endpoint }) => {
   if (valueKey && tokenType) {
     return tokenValues[valueKey][tokenType] ?? defaultRate;
   }
@@ -62,7 +64,7 @@ const getMultiplier = ({ valueKey, tokenType, model }) => {
     return 1;
   }
 
-  valueKey = getValueKey(model);
+  valueKey = getValueKey(model, endpoint);
   if (!valueKey) {
     return defaultRate;
   }
diff --git a/api/server/controllers/AskController.js b/api/server/controllers/AskController.js
new file mode 100644
index 000000000..30487a277
--- /dev/null
+++ b/api/server/controllers/AskController.js
@@ -0,0 +1,132 @@
+const { sendMessage, createOnProgress } = require('~/server/utils');
+const { saveMessage, getConvoTitle, getConvo } = require('~/models');
+const { getResponseSender } = require('~/server/services/Endpoints');
+const { createAbortController, handleAbortError } = require('~/server/middleware');
+
+const AskController = async (req, res, next, initializeClient) => {
+  let {
+    text,
+    endpointOption,
+    conversationId,
+    parentMessageId = null,
+    overrideParentMessageId = null,
+  } = req.body;
+  console.log('ask log');
+  console.dir({ text, conversationId, endpointOption }, { depth: null });
+  let metadata;
+  let userMessage;
+  let promptTokens;
+  let userMessageId;
+  let responseMessageId;
+  let lastSavedTimestamp = 0;
+  let saveDelay = 100;
+  const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
+  const user = req.user.id;
+
+  const getReqData = (data = {}) => {
+    for (let key in data) {
+      if (key === 'userMessage') {
+        userMessage = data[key];
+        userMessageId = data[key].messageId;
+      } else if (key === 'responseMessageId') {
+        responseMessageId = data[key];
+      } else if (key === 'promptTokens') {
+        promptTokens = data[key];
+      } else if (!conversationId && key === 'conversationId') {
+        conversationId = data[key];
+      }
+    }
+  };
+
+  const { onProgress: progressCallback, getPartialText } = createOnProgress({
+    onProgress: ({ text: partialText }) => {
+      const currentTimestamp = Date.now();
+
+      if (currentTimestamp - lastSavedTimestamp > saveDelay) {
+        lastSavedTimestamp = currentTimestamp;
+        saveMessage({
+          messageId: responseMessageId,
+          sender,
+          conversationId,
+          parentMessageId: overrideParentMessageId ?? userMessageId,
+          text: partialText,
+          unfinished: true,
+          cancelled: false,
+          error: false,
+          user,
+        });
+      }
+
+      if (saveDelay < 500) {
+        saveDelay = 500;
+      }
+    },
+  });
+  try {
+    const addMetadata = (data) => (metadata = data);
+    const getAbortData = () => ({
+      conversationId,
+      messageId: responseMessageId,
+      sender,
+      parentMessageId: overrideParentMessageId ?? userMessageId,
+      text: getPartialText(),
+      userMessage,
+      promptTokens,
+    });
+
+    const { abortController, onStart } = createAbortController(req, res, getAbortData);
+
+    const { client } = await initializeClient({ req, res, endpointOption });
+
+    let response = await client.sendMessage(text, {
+      // debug: true,
+      user,
+      conversationId,
+      parentMessageId,
+      overrideParentMessageId,
+      ...endpointOption,
+      onProgress: progressCallback.call(null, {
+        res,
+        text,
+        parentMessageId: overrideParentMessageId ?? userMessageId,
+      }),
+      onStart,
+      getReqData,
+      addMetadata,
+      abortController,
+    });
+
+    if (metadata) {
+      response = { ...response, ...metadata };
+    }
+
+    if (overrideParentMessageId) {
+      response.parentMessageId = overrideParentMessageId;
+    }
+
+    sendMessage(res, {
+      title: await getConvoTitle(user, conversationId),
+      final: true,
+      conversation: await getConvo(user, conversationId),
+      requestMessage: userMessage,
+      responseMessage: response,
+    });
+    res.end();
+
+    await saveMessage({ ...response, user });
+    await saveMessage(userMessage);
+
+    // TODO: add title service
+  } catch (error) {
+    const partialText = getPartialText();
+    handleAbortError(res, req, error, {
+      partialText,
+      conversationId,
+      sender,
+      messageId: responseMessageId,
+      parentMessageId: userMessageId ?? parentMessageId,
+    });
+  }
+};
+
+module.exports = AskController;
diff --git a/api/server/controllers/EditController.js b/api/server/controllers/EditController.js
new file mode 100644
index 000000000..6b6b86428
--- /dev/null
+++ b/api/server/controllers/EditController.js
@@ -0,0 +1,135 @@
+const { sendMessage, createOnProgress } = require('~/server/utils');
+const { saveMessage, getConvoTitle, getConvo } = require('~/models');
+const { getResponseSender } = require('~/server/services/Endpoints');
+const { createAbortController, handleAbortError } = require('~/server/middleware');
+
+const EditController = async (req, res, next, initializeClient) => {
+  let {
+    text,
+    generation,
+    endpointOption,
+    conversationId,
+    responseMessageId,
+    isContinued = false,
+    parentMessageId = null,
+    overrideParentMessageId = null,
+  } = req.body;
+  console.log('edit log');
+  console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null });
+  let metadata;
+  let userMessage;
+  let promptTokens;
+  let lastSavedTimestamp = 0;
+  let saveDelay = 100;
+  const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
+  const userMessageId = parentMessageId;
+  const user = req.user.id;
+
+  const addMetadata = (data) => (metadata = data);
+  const getReqData = (data = {}) => {
+    for (let key in data) {
+      if (key === 'userMessage') {
+        userMessage = data[key];
+      } else if (key === 'responseMessageId') {
+        responseMessageId = data[key];
+      } else if (key === 'promptTokens') {
+        promptTokens = data[key];
+      }
+    }
+  };
+
+  const { onProgress: progressCallback, getPartialText } = createOnProgress({
+    generation,
+    onProgress: ({ text: partialText }) => {
+      const currentTimestamp = Date.now();
+      if (currentTimestamp - lastSavedTimestamp > saveDelay) {
+        lastSavedTimestamp = currentTimestamp;
+        saveMessage({
+          messageId: responseMessageId,
+          sender,
+          conversationId,
+          parentMessageId: overrideParentMessageId ?? userMessageId,
+          text: partialText,
+          unfinished: true,
+          cancelled: false,
+          isEdited: true,
+          error: false,
+          user,
+        });
+      }
+
+      if (saveDelay < 500) {
+        saveDelay = 500;
+      }
+    },
+  });
+  try {
+    const getAbortData = () => ({
+      conversationId,
+      messageId: responseMessageId,
+      sender,
+      parentMessageId: overrideParentMessageId ?? userMessageId,
+      text: getPartialText(),
+      userMessage,
+      promptTokens,
+    });
+
+    const { abortController, onStart } = createAbortController(req, res, getAbortData);
+
+    const { client } = await initializeClient({ req, res, endpointOption });
+
+    let response = await client.sendMessage(text, {
+      user,
+      generation,
+      isContinued,
+      isEdited: true,
+      conversationId,
+      parentMessageId,
+      responseMessageId,
+      overrideParentMessageId,
+      ...endpointOption,
+      onProgress: progressCallback.call(null, {
+        res,
+        text,
+        parentMessageId: overrideParentMessageId ?? userMessageId,
+      }),
+      getReqData,
+      onStart,
+      addMetadata,
+      abortController,
+    });
+
+    if (metadata) {
+      response = { ...response, ...metadata };
+    }
+
+    if (overrideParentMessageId) {
+      response.parentMessageId = overrideParentMessageId;
+    }
+
+    sendMessage(res, {
+      title: await getConvoTitle(user, conversationId),
+      final: true,
+      conversation: await getConvo(user, conversationId),
+      requestMessage: userMessage,
+      responseMessage: response,
+    });
+    res.end();
+
+    await saveMessage({ ...response, user });
+    await saveMessage(userMessage);
+
+    // TODO: add title service
+  } catch (error) {
+    const partialText = getPartialText();
+    handleAbortError(res, req, error, {
+      partialText,
+      conversationId,
+      sender,
+      messageId: responseMessageId,
+      parentMessageId: userMessageId ?? parentMessageId,
+    });
+  }
+};
+
+module.exports = EditController;
diff --git a/api/server/middleware/buildEndpointOption.js b/api/server/middleware/buildEndpointOption.js
index b50b0f42c..1eec09bb9 100644
--- a/api/server/middleware/buildEndpointOption.js
+++ b/api/server/middleware/buildEndpointOption.js
@@ -1,14 +1,16 @@
-const openAI = require('~/server/routes/endpoints/openAI');
-const gptPlugins = require('~/server/routes/endpoints/gptPlugins');
-const anthropic = require('~/server/routes/endpoints/anthropic');
-const { parseConvo, EModelEndpoint } = require('~/server/routes/endpoints/schemas');
 const { processFiles } = require('~/server/services/Files');
+const openAI = require('~/server/services/Endpoints/openAI');
+const google = require('~/server/services/Endpoints/google');
+const anthropic = require('~/server/services/Endpoints/anthropic');
+const gptPlugins = require('~/server/services/Endpoints/gptPlugins');
+const { parseConvo, EModelEndpoint } = require('~/server/services/Endpoints');
 
 const buildFunction = {
   [EModelEndpoint.openAI]: openAI.buildOptions,
+  [EModelEndpoint.google]: google.buildOptions,
   [EModelEndpoint.azureOpenAI]: openAI.buildOptions,
-  [EModelEndpoint.gptPlugins]: gptPlugins.buildOptions,
   [EModelEndpoint.anthropic]: anthropic.buildOptions,
+  [EModelEndpoint.gptPlugins]: gptPlugins.buildOptions,
 };
 
 function buildEndpointOption(req, res, next) {
diff --git a/api/server/middleware/denyRequest.js b/api/server/middleware/denyRequest.js
index 1f44e2974..f51e4c2c7 100644
--- a/api/server/middleware/denyRequest.js
+++ b/api/server/middleware/denyRequest.js
@@ -1,7 +1,7 @@
 const crypto = require('crypto');
-const { sendMessage, sendError } = require('../utils');
-const { getResponseSender } = require('../routes/endpoints/schemas');
-const { saveMessage } = require('../../models');
+const { saveMessage } = require('~/models');
+const { sendMessage, sendError } = require('~/server/utils');
+const { getResponseSender } = require('~/server/services/Endpoints');
 
 /**
 * Denies a request by sending an error message and optionally saves the user's message.
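For context on the `formatGoogleInputs` helper added above: the `:serverStreamingPredict` URL built in `GoogleClient.constructUrl()` appears to expect protobuf-style typed values rather than plain JSON, and this helper produces that shape. A minimal usage sketch follows — it is not part of the patch, and is grounded only in the behavior asserted by `formatGoogleInputs.spec.js`:

const formatGoogleInputs = require('~/app/clients/prompts/formatGoogleInputs');

// A plain parameters object like the one GoogleClient builds from modelOptions.
const parameters = { temperature: 0.2, topK: 40, context: 'context' };

// Floats become float_val, integers int_val, and strings a one-element
// string_val array; the whole object is nested under struct_val.
console.log(JSON.stringify(formatGoogleInputs(parameters)));
// {"struct_val":{"temperature":{"float_val":0.2},"topK":{"int_val":40},"context":{"string_val":["context"]}}}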
diff --git a/api/server/routes/ask/anthropic.js b/api/server/routes/ask/anthropic.js
index 85d47972e..e0ea0f985 100644
--- a/api/server/routes/ask/anthropic.js
+++ b/api/server/routes/ask/anthropic.js
@@ -1,137 +1,19 @@
 const express = require('express');
-const router = express.Router();
-const { getResponseSender } = require('../endpoints/schemas');
-const { initializeClient } = require('../endpoints/anthropic');
+const AskController = require('~/server/controllers/AskController');
+const { initializeClient } = require('~/server/services/Endpoints/anthropic');
 const {
-  handleAbort,
-  createAbortController,
-  handleAbortError,
   setHeaders,
+  handleAbort,
   validateEndpoint,
   buildEndpointOption,
 } = require('~/server/middleware');
-const { saveMessage, getConvoTitle, getConvo } = require('~/models');
-const { sendMessage, createOnProgress } = require('~/server/utils');
+
+const router = express.Router();
 
 router.post('/abort', handleAbort());
 
-router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
-  let {
-    text,
-    endpointOption,
-    conversationId,
-    parentMessageId = null,
-    overrideParentMessageId = null,
-  } = req.body;
-  console.log('ask log');
-  console.dir({ text, conversationId, endpointOption }, { depth: null });
-  let userMessage;
-  let promptTokens;
-  let userMessageId;
-  let responseMessageId;
-  let lastSavedTimestamp = 0;
-  let saveDelay = 100;
-  const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
-  const user = req.user.id;
-
-  const getReqData = (data = {}) => {
-    for (let key in data) {
-      if (key === 'userMessage') {
-        userMessage = data[key];
-        userMessageId = data[key].messageId;
-      } else if (key === 'responseMessageId') {
-        responseMessageId = data[key];
-      } else if (key === 'promptTokens') {
-        promptTokens = data[key];
-      } else if (!conversationId && key === 'conversationId') {
-        conversationId = data[key];
-      }
-    }
-  };
-
-  const { onProgress: progressCallback, getPartialText } = createOnProgress({
-    onProgress: ({ text: partialText }) => {
-      const currentTimestamp = Date.now();
-
-      if (currentTimestamp - lastSavedTimestamp > saveDelay) {
-        lastSavedTimestamp = currentTimestamp;
-        saveMessage({
-          messageId: responseMessageId,
-          sender,
-          conversationId,
-          parentMessageId: overrideParentMessageId ?? userMessageId,
-          text: partialText,
-          unfinished: true,
-          cancelled: false,
-          error: false,
-          user,
-        });
-      }
-
-      if (saveDelay < 500) {
-        saveDelay = 500;
-      }
-    },
-  });
-  try {
-    const getAbortData = () => ({
-      conversationId,
-      messageId: responseMessageId,
-      sender,
-      parentMessageId: overrideParentMessageId ?? userMessageId,
-      text: getPartialText(),
-      userMessage,
-      promptTokens,
-    });
-
-    const { abortController, onStart } = createAbortController(req, res, getAbortData);
-
-    const { client } = await initializeClient({ req, res, endpointOption });
-
-    let response = await client.sendMessage(text, {
-      getReqData,
-      // debug: true,
-      user,
-      conversationId,
-      parentMessageId,
-      overrideParentMessageId,
-      ...endpointOption,
-      onProgress: progressCallback.call(null, {
-        res,
-        text,
-        parentMessageId: overrideParentMessageId ?? userMessageId,
-      }),
-      onStart,
-      abortController,
-    });
-
-    if (overrideParentMessageId) {
-      response.parentMessageId = overrideParentMessageId;
-    }
-
-    sendMessage(res, {
-      title: await getConvoTitle(user, conversationId),
-      final: true,
-      conversation: await getConvo(user, conversationId),
-      requestMessage: userMessage,
-      responseMessage: response,
-    });
-    res.end();
-
-    await saveMessage({ ...response, user });
-    await saveMessage(userMessage);
-
-    // TODO: add anthropic titling
-  } catch (error) {
-    const partialText = getPartialText();
-    handleAbortError(res, req, error, {
-      partialText,
-      conversationId,
-      sender,
-      messageId: responseMessageId,
-      parentMessageId: userMessageId ?? parentMessageId,
-    });
-  }
+router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
+  await AskController(req, res, next, initializeClient);
 });
 
 module.exports = router;
diff --git a/api/server/routes/ask/google.js b/api/server/routes/ask/google.js
index 1011e173e..78c648495 100644
--- a/api/server/routes/ask/google.js
+++ b/api/server/routes/ask/google.js
@@ -1,181 +1,19 @@
 const express = require('express');
+const AskController = require('~/server/controllers/AskController');
+const { initializeClient } = require('~/server/services/Endpoints/google');
+const {
+  setHeaders,
+  handleAbort,
+  validateEndpoint,
+  buildEndpointOption,
+} = require('~/server/middleware');
+
 const router = express.Router();
-const crypto = require('crypto');
-const { GoogleClient } = require('../../../app');
-const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
-const { handleError, sendMessage, createOnProgress } = require('../../utils');
-const { getUserKey, checkUserKeyExpiry } = require('../../services/UserService');
-const { setHeaders } = require('../../middleware');
 
-router.post('/', setHeaders, async (req, res) => {
-  const { endpoint, text, parentMessageId, conversationId: oldConversationId } = req.body;
-  if (text.length === 0) {
-    return handleError(res, { text: 'Prompt empty or too short' });
-  }
-  if (endpoint !== 'google') {
-    return handleError(res, { text: 'Illegal request' });
-  }
+router.post('/abort', handleAbort());
 
-  // build endpoint option
-  const endpointOption = {
-    examples: req.body?.examples ?? [{ input: { content: '' }, output: { content: '' } }],
-    promptPrefix: req.body?.promptPrefix ?? null,
-    key: req.body?.key ?? null,
-    modelOptions: {
-      model: req.body?.model ?? 'chat-bison',
-      modelLabel: req.body?.modelLabel ?? null,
-      temperature: req.body?.temperature ?? 0.2,
-      maxOutputTokens: req.body?.maxOutputTokens ?? 1024,
-      topP: req.body?.topP ?? 0.95,
-      topK: req.body?.topK ?? 40,
-    },
-  };
-
-  const availableModels = ['chat-bison', 'text-bison', 'codechat-bison'];
-  if (availableModels.find((model) => model === endpointOption.modelOptions.model) === undefined) {
-    return handleError(res, { text: 'Illegal request: model' });
-  }
-
-  const conversationId = oldConversationId || crypto.randomUUID();
-
-  // eslint-disable-next-line no-use-before-define
-  return await ask({
-    text,
-    endpointOption,
-    conversationId,
-    parentMessageId,
-    req,
-    res,
-  });
+router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
+  await AskController(req, res, next, initializeClient);
 });
 
-const ask = async ({ text, endpointOption, parentMessageId = null, conversationId, req, res }) => {
-  let userMessage;
-  let userMessageId;
-  // let promptTokens;
-  let responseMessageId;
-  let lastSavedTimestamp = 0;
-  const { overrideParentMessageId = null } = req.body;
-  const user = req.user.id;
-
-  try {
-    const getReqData = (data = {}) => {
-      for (let key in data) {
-        if (key === 'userMessage') {
-          userMessage = data[key];
-          userMessageId = data[key].messageId;
-        } else if (key === 'responseMessageId') {
-          responseMessageId = data[key];
-          // } else if (key === 'promptTokens') {
-          //   promptTokens = data[key];
-        } else if (!conversationId && key === 'conversationId') {
-          conversationId = data[key];
-        }
-      }
-
-      sendMessage(res, { message: userMessage, created: true });
-    };
-
-    const { onProgress: progressCallback } = createOnProgress({
-      onProgress: ({ text: partialText }) => {
-        const currentTimestamp = Date.now();
-        if (currentTimestamp - lastSavedTimestamp > 500) {
-          lastSavedTimestamp = currentTimestamp;
-          saveMessage({
-            messageId: responseMessageId,
-            sender: 'PaLM2',
-            conversationId,
-            parentMessageId: overrideParentMessageId || userMessageId,
-            text: partialText,
-            unfinished: true,
-            cancelled: false,
-            error: false,
-            user,
-          });
-        }
-      },
-    });
-
-    const abortController = new AbortController();
-
-    const isUserProvided = process.env.PALM_KEY === 'user_provided';
-
-    let key;
-    if (endpointOption.key && isUserProvided) {
-      checkUserKeyExpiry(
-        endpointOption.key,
-        'Your GOOGLE_TOKEN has expired. Please provide your token again.',
-      );
-      key = await getUserKey({ userId: user, name: 'google' });
-      key = JSON.parse(key);
-      delete endpointOption.key;
-      console.log('Using service account key provided by User for PaLM models');
-    }
-
-    try {
-      key = require('../../../data/auth.json');
-    } catch (e) {
-      console.log('No \'auth.json\' file (service account key) found in /api/data/ for PaLM models');
-    }
-
-    const clientOptions = {
-      // debug: true, // for testing
-      reverseProxyUrl: process.env.GOOGLE_REVERSE_PROXY || null,
-      proxy: process.env.PROXY || null,
-      ...endpointOption,
-    };
-
-    const client = new GoogleClient(key, clientOptions);
-
-    let response = await client.sendMessage(text, {
-      getReqData,
-      user,
-      conversationId,
-      parentMessageId,
-      overrideParentMessageId,
-      onProgress: progressCallback.call(null, {
-        res,
-        text,
-        parentMessageId: overrideParentMessageId || userMessageId,
-      }),
-      abortController,
-    });
-
-    if (overrideParentMessageId) {
-      response.parentMessageId = overrideParentMessageId;
-    }
-
-    await saveConvo(user, {
-      ...endpointOption,
-      ...endpointOption.modelOptions,
-      conversationId,
-      endpoint: 'google',
-    });
-
-    await saveMessage({ ...response, user });
-    sendMessage(res, {
-      title: await getConvoTitle(user, conversationId),
-      final: true,
-      conversation: await getConvo(user, conversationId),
-      requestMessage: userMessage,
-      responseMessage: response,
-    });
-    res.end();
-  } catch (error) {
-    console.error(error);
-    const errorMessage = {
-      messageId: responseMessageId,
-      sender: 'PaLM2',
-      conversationId,
-      parentMessageId,
-      unfinished: false,
-      cancelled: false,
-      error: true,
-      text: error.message,
-    };
-    await saveMessage({ ...errorMessage, user });
-    handleError(res, errorMessage);
-  }
-};
-
 module.exports = router;
diff --git a/api/server/routes/ask/gptPlugins.js b/api/server/routes/ask/gptPlugins.js
index 448558817..b54b516c4 100644
--- a/api/server/routes/ask/gptPlugins.js
+++ b/api/server/routes/ask/gptPlugins.js
@@ -1,11 +1,11 @@
 const express = require('express');
 const router = express.Router();
-const { getResponseSender } = require('../endpoints/schemas');
-const { validateTools } = require('../../../app');
-const { addTitle } = require('../endpoints/openAI');
-const { initializeClient } = require('../endpoints/gptPlugins');
-const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
-const { sendMessage, createOnProgress } = require('../../utils');
+const { getResponseSender } = require('~/server/services/Endpoints');
+const { validateTools } = require('~/app');
+const { addTitle } = require('~/server/services/Endpoints/openAI');
+const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
+const { saveMessage, getConvoTitle, getConvo } = require('~/models');
+const { sendMessage, createOnProgress } = require('~/server/utils');
 const {
   handleAbort,
   createAbortController,
@@ -13,7 +13,7 @@ const {
   setHeaders,
   validateEndpoint,
   buildEndpointOption,
-} = require('../../middleware');
+} = require('~/server/middleware');
 
 router.post('/abort', handleAbort());
diff --git a/api/server/routes/ask/index.js b/api/server/routes/ask/index.js
index e13f20195..66da4edc3 100644
--- a/api/server/routes/ask/index.js
+++ b/api/server/routes/ask/index.js
@@ -1,11 +1,12 @@
 const express = require('express');
-const router = express.Router();
 const openAI = require('./openAI');
 const google = require('./google');
 const bingAI = require('./bingAI');
+const anthropic = require('./anthropic');
 const gptPlugins = require('./gptPlugins');
 const askChatGPTBrowser = require('./askChatGPTBrowser');
-const anthropic = require('./anthropic');
+const { isEnabled } = require('~/server/utils');
+const { EModelEndpoint } = require('~/server/services/Endpoints');
 const {
   uaParser,
   checkBan,
@@ -13,12 +14,12 @@ const {
   concurrentLimiter,
   messageIpLimiter,
   messageUserLimiter,
-} = require('../../middleware');
-const { isEnabled } = require('../../utils');
-const { EModelEndpoint } = require('../endpoints/schemas');
+} = require('~/server/middleware');
 
 const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
 
+const router = express.Router();
+
 router.use(requireJwtAuth);
 router.use(checkBan);
 router.use(uaParser);
@@ -36,10 +37,10 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
 }
 
 router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
-router.use(`/${EModelEndpoint.google}`, google);
-router.use(`/${EModelEndpoint.bingAI}`, bingAI);
 router.use(`/${EModelEndpoint.chatGPTBrowser}`, askChatGPTBrowser);
 router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
 router.use(`/${EModelEndpoint.anthropic}`, anthropic);
+router.use(`/${EModelEndpoint.google}`, google);
+router.use(`/${EModelEndpoint.bingAI}`, bingAI);
 
 module.exports = router;
diff --git a/api/server/routes/ask/openAI.js b/api/server/routes/ask/openAI.js
index 43c145b52..6eb8f6615 100644
--- a/api/server/routes/ask/openAI.js
+++ b/api/server/routes/ask/openAI.js
@@ -2,8 +2,8 @@ const express = require('express');
 const router = express.Router();
 const { sendMessage, createOnProgress } = require('~/server/utils');
 const { saveMessage, getConvoTitle, getConvo } = require('~/models');
-const { getResponseSender } = require('~/server/routes/endpoints/schemas');
-const { addTitle, initializeClient } = require('~/server/routes/endpoints/openAI');
+const { getResponseSender } = require('~/server/services/Endpoints');
+const { addTitle, initializeClient } = require('~/server/services/Endpoints/openAI');
 const {
   handleAbort,
   createAbortController,
diff --git a/api/server/routes/edit/anthropic.js b/api/server/routes/edit/anthropic.js
index 34c659cc7..34dd9d6df 100644
--- a/api/server/routes/edit/anthropic.js
+++ b/api/server/routes/edit/anthropic.js
@@ -1,147 +1,19 @@
 const express = require('express');
-const router = express.Router();
-const { getResponseSender } = require('../endpoints/schemas');
-const { initializeClient } = require('../endpoints/anthropic');
+const EditController = require('~/server/controllers/EditController');
+const { initializeClient } = require('~/server/services/Endpoints/anthropic');
 const {
-  handleAbort,
-  createAbortController,
-  handleAbortError,
   setHeaders,
+  handleAbort,
   validateEndpoint,
   buildEndpointOption,
 } = require('~/server/middleware');
-const { saveMessage, getConvoTitle, getConvo } = require('~/models');
-const { sendMessage, createOnProgress } = require('~/server/utils');
+
+const router = express.Router();
 
 router.post('/abort', handleAbort());
 
-router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
-  let {
-    text,
-    generation,
-    endpointOption,
-    conversationId,
-    responseMessageId,
-    isContinued = false,
-    parentMessageId = null,
-    overrideParentMessageId = null,
-  } = req.body;
-  console.log('edit log');
-  console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null });
-  let metadata;
-  let userMessage;
-  let promptTokens;
-  let lastSavedTimestamp = 0;
-  let saveDelay = 100;
-  const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
-  const userMessageId = parentMessageId;
-  const user = req.user.id;
-
-  const addMetadata = (data) => (metadata = data);
-  const getReqData = (data = {}) => {
-    for (let key in data) {
-      if (key === 'userMessage') {
-        userMessage = data[key];
-      } else if (key === 'responseMessageId') {
-        responseMessageId = data[key];
-      } else if (key === 'promptTokens') {
-        promptTokens = data[key];
-      }
-    }
-  };
-
-  const { onProgress: progressCallback, getPartialText } = createOnProgress({
-    generation,
-    onProgress: ({ text: partialText }) => {
-      const currentTimestamp = Date.now();
-      if (currentTimestamp - lastSavedTimestamp > saveDelay) {
-        lastSavedTimestamp = currentTimestamp;
-        saveMessage({
-          messageId: responseMessageId,
-          sender,
-          conversationId,
-          parentMessageId: overrideParentMessageId ?? userMessageId,
-          text: partialText,
-          unfinished: true,
-          cancelled: false,
-          isEdited: true,
-          error: false,
-          user,
-        });
-      }
-
-      if (saveDelay < 500) {
-        saveDelay = 500;
-      }
-    },
-  });
-  try {
-    const getAbortData = () => ({
-      conversationId,
-      messageId: responseMessageId,
-      sender,
-      parentMessageId: overrideParentMessageId ?? userMessageId,
-      text: getPartialText(),
-      userMessage,
-      promptTokens,
-    });
-
-    const { abortController, onStart } = createAbortController(req, res, getAbortData);
-
-    const { client } = await initializeClient({ req, res, endpointOption });
-
-    let response = await client.sendMessage(text, {
-      user,
-      generation,
-      isContinued,
-      isEdited: true,
-      conversationId,
-      parentMessageId,
-      responseMessageId,
-      overrideParentMessageId,
-      ...endpointOption,
-      onProgress: progressCallback.call(null, {
-        res,
-        text,
-        parentMessageId: overrideParentMessageId ?? userMessageId,
-      }),
-      getReqData,
-      onStart,
-      addMetadata,
-      abortController,
-    });
-
-    if (metadata) {
-      response = { ...response, ...metadata };
-    }
-
-    if (overrideParentMessageId) {
-      response.parentMessageId = overrideParentMessageId;
-    }
-
-    sendMessage(res, {
-      title: await getConvoTitle(user, conversationId),
-      final: true,
-      conversation: await getConvo(user, conversationId),
-      requestMessage: userMessage,
-      responseMessage: response,
-    });
-    res.end();
-
-    await saveMessage({ ...response, user });
-    await saveMessage(userMessage);
-
-    // TODO: add anthropic titling
-  } catch (error) {
-    const partialText = getPartialText();
-    handleAbortError(res, req, error, {
-      partialText,
-      conversationId,
-      sender,
-      messageId: responseMessageId,
-      parentMessageId: userMessageId ?? parentMessageId,
-    });
-  }
+router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
+  await EditController(req, res, next, initializeClient);
 });
 
 module.exports = router;
diff --git a/api/server/routes/edit/google.js b/api/server/routes/edit/google.js
new file mode 100644
index 000000000..e4dfbcd14
--- /dev/null
+++ b/api/server/routes/edit/google.js
@@ -0,0 +1,19 @@
+const express = require('express');
+const EditController = require('~/server/controllers/EditController');
+const { initializeClient } = require('~/server/services/Endpoints/google');
+const {
+  setHeaders,
+  handleAbort,
+  validateEndpoint,
+  buildEndpointOption,
+} = require('~/server/middleware');
+
+const router = express.Router();
+
+router.post('/abort', handleAbort());
+
+router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
+  await EditController(req, res, next, initializeClient);
+});
+
+module.exports = router;
diff --git a/api/server/routes/edit/gptPlugins.js b/api/server/routes/edit/gptPlugins.js
index 4f1843d3e..451b3a5b5 100644
--- a/api/server/routes/edit/gptPlugins.js
+++ b/api/server/routes/edit/gptPlugins.js
@@ -1,10 +1,10 @@
 const express = require('express');
 const router = express.Router();
-const { getResponseSender } = require('../endpoints/schemas');
-const { validateTools } = require('../../../app');
-const { initializeClient } = require('../endpoints/gptPlugins');
-const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
-const { sendMessage, createOnProgress, formatSteps, formatAction } = require('../../utils');
+const { validateTools } = require('~/app');
+const { saveMessage, getConvoTitle, getConvo } = require('~/models');
+const { getResponseSender } = require('~/server/services/Endpoints');
+const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
+const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
 const {
   handleAbort,
   createAbortController,
@@ -12,7 +12,7 @@ const {
   setHeaders,
   validateEndpoint,
   buildEndpointOption,
-} = require('../../middleware');
+} = require('~/server/middleware');
 
 router.post('/abort', handleAbort());
diff --git a/api/server/routes/edit/index.js b/api/server/routes/edit/index.js
index dcf5ff553..09598f70c 100644
--- a/api/server/routes/edit/index.js
+++ b/api/server/routes/edit/index.js
@@ -1,20 +1,23 @@
 const express = require('express');
-const router = express.Router();
 const openAI = require('./openAI');
-const gptPlugins = require('./gptPlugins');
+const google = require('./google');
 const anthropic = require('./anthropic');
+const gptPlugins = require('./gptPlugins');
+const { isEnabled } = require('~/server/utils');
+const { EModelEndpoint } = require('~/server/services/Endpoints');
 const {
   checkBan,
   uaParser,
   requireJwtAuth,
-  concurrentLimiter,
   messageIpLimiter,
+  concurrentLimiter,
   messageUserLimiter,
-} = require('../../middleware');
-const { isEnabled } = require('../../utils');
+} = require('~/server/middleware');
 
 const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
{}; +const router = express.Router(); + router.use(requireJwtAuth); router.use(checkBan); router.use(uaParser); @@ -31,8 +34,9 @@ if (isEnabled(LIMIT_MESSAGE_USER)) { router.use(messageUserLimiter); } -router.use(['/azureOpenAI', '/openAI'], openAI); -router.use('/gptPlugins', gptPlugins); -router.use('/anthropic', anthropic); +router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI); +router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins); +router.use(`/${EModelEndpoint.anthropic}`, anthropic); +router.use(`/${EModelEndpoint.google}`, google); module.exports = router; diff --git a/api/server/routes/edit/openAI.js b/api/server/routes/edit/openAI.js index c701369a3..0c2d8e7d8 100644 --- a/api/server/routes/edit/openAI.js +++ b/api/server/routes/edit/openAI.js @@ -1,9 +1,9 @@ const express = require('express'); const router = express.Router(); -const { getResponseSender } = require('../endpoints/schemas'); -const { initializeClient } = require('../endpoints/openAI'); -const { saveMessage, getConvoTitle, getConvo } = require('../../../models'); -const { sendMessage, createOnProgress } = require('../../utils'); +const { getResponseSender } = require('~/server/services/Endpoints'); +const { initializeClient } = require('~/server/services/Endpoints/openAI'); +const { saveMessage, getConvoTitle, getConvo } = require('~/models'); +const { sendMessage, createOnProgress } = require('~/server/utils'); const { handleAbort, createAbortController, @@ -11,7 +11,7 @@ const { setHeaders, validateEndpoint, buildEndpointOption, -} = require('../../middleware'); +} = require('~/server/middleware'); router.post('/abort', handleAbort()); diff --git a/api/server/services/Config/EndpointService.js b/api/server/services/Config/EndpointService.js index a4fb93da4..d06ce8ed0 100644 --- a/api/server/services/Config/EndpointService.js +++ b/api/server/services/Config/EndpointService.js @@ -1,4 +1,4 @@ -const { EModelEndpoint } = require('~/server/routes/endpoints/schemas'); +const { EModelEndpoint } = require('~/server/services/Endpoints'); const { OPENAI_API_KEY: openAIApiKey, @@ -7,7 +7,7 @@ const { CHATGPT_TOKEN: chatGPTToken, BINGAI_TOKEN: bingToken, PLUGINS_USE_AZURE, - PALM_KEY: palmKey, + GOOGLE_KEY: googleKey, } = process.env ?? 
{}; const useAzurePlugins = !!PLUGINS_USE_AZURE; @@ -26,7 +26,7 @@ module.exports = { azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, - palmKey, + googleKey, [EModelEndpoint.openAI]: isUserProvided(openAIApiKey), [EModelEndpoint.assistant]: isUserProvided(openAIApiKey), [EModelEndpoint.azureOpenAI]: isUserProvided(azureOpenAIApiKey), diff --git a/api/server/services/Config/loadAsyncEndpoints.js b/api/server/services/Config/loadAsyncEndpoints.js index fc5449749..c06e5f1c9 100644 --- a/api/server/services/Config/loadAsyncEndpoints.js +++ b/api/server/services/Config/loadAsyncEndpoints.js @@ -1,6 +1,6 @@ const { availableTools } = require('~/app/clients/tools'); const { addOpenAPISpecs } = require('~/app/clients/tools/util/addOpenAPISpecs'); -const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, palmKey } = +const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, googleKey } = require('./EndpointService').config; /** @@ -8,7 +8,7 @@ const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, pa */ async function loadAsyncEndpoints() { let i = 0; - let key, palmUser; + let key, googleUserProvides; try { key = require('~/data/auth.json'); } catch (e) { @@ -17,8 +17,8 @@ async function loadAsyncEndpoints() { } } - if (palmKey === 'user_provided') { - palmUser = true; + if (googleKey === 'user_provided') { + googleUserProvides = true; if (i <= 1) { i++; } @@ -33,7 +33,7 @@ async function loadAsyncEndpoints() { } const plugins = transformToolsToMap(tools); - const google = key || palmUser ? { userProvide: palmUser } : false; + const google = key || googleUserProvides ? { userProvide: googleUserProvides } : false; const gptPlugins = openAIApiKey || azureOpenAIApiKey diff --git a/api/server/services/Config/loadDefaultEConfig.js b/api/server/services/Config/loadDefaultEConfig.js index 833fcf34b..b4b113b8d 100644 --- a/api/server/services/Config/loadDefaultEConfig.js +++ b/api/server/services/Config/loadDefaultEConfig.js @@ -1,4 +1,4 @@ -const { EModelEndpoint } = require('~/server/routes/endpoints/schemas'); +const { EModelEndpoint } = require('~/server/services/Endpoints'); const loadAsyncEndpoints = require('./loadAsyncEndpoints'); const { config } = require('./EndpointService'); diff --git a/api/server/services/Config/loadDefaultModels.js b/api/server/services/Config/loadDefaultModels.js index 41a1bac68..7907dd490 100644 --- a/api/server/services/Config/loadDefaultModels.js +++ b/api/server/services/Config/loadDefaultModels.js @@ -3,7 +3,7 @@ const { getChatGPTBrowserModels, getAnthropicModels, } = require('~/server/services/ModelService'); -const { EModelEndpoint } = require('~/server/routes/endpoints/schemas'); +const { EModelEndpoint } = require('~/server/services/Endpoints'); const { useAzurePlugins } = require('~/server/services/Config/EndpointService').config; const fitlerAssistantModels = (str) => { @@ -21,7 +21,18 @@ async function loadDefaultModels() { [EModelEndpoint.openAI]: openAI, [EModelEndpoint.azureOpenAI]: azureOpenAI, [EModelEndpoint.assistant]: openAI.filter(fitlerAssistantModels), - [EModelEndpoint.google]: ['chat-bison', 'text-bison', 'codechat-bison'], + [EModelEndpoint.google]: [ + 'chat-bison', + 'chat-bison-32k', + 'codechat-bison', + 'codechat-bison-32k', + 'text-bison', + 'text-bison-32k', + 'text-unicorn', + 'code-gecko', + 'code-bison', + 'code-bison-32k', + ], [EModelEndpoint.bingAI]: ['BingAI', 'Sydney'], [EModelEndpoint.chatGPTBrowser]: chatGPTBrowser, [EModelEndpoint.gptPlugins]: gptPlugins, 
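A note on how the availability pieces above fit together: `EndpointService` only records whether `GOOGLE_KEY=user_provided` is set, while `loadAsyncEndpoints` decides whether the Google endpoint is offered at all, based on either a service-account `auth.json` or a user-provided key. A minimal sketch of that decision, condensed from the hunks above (the logging and counter wiring are elided, so treat this as illustrative rather than a drop-in module):

```js
// Sketch only: condensed from loadAsyncEndpoints above.
const { googleKey } = require('./EndpointService').config;

let key;
try {
  key = require('~/data/auth.json'); // service-account credentials, if present
} catch (e) {
  key = null; // no JSON credentials on the server
}

const googleUserProvides = googleKey === 'user_provided';
// The endpoint is advertised if either credential source exists;
// `userProvide: true` tells the client to prompt the user for a key.
const google = key || googleUserProvides ? { userProvide: googleUserProvides } : false;
```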
diff --git a/api/server/routes/endpoints/anthropic/buildOptions.js b/api/server/services/Endpoints/anthropic/buildOptions.js similarity index 100% rename from api/server/routes/endpoints/anthropic/buildOptions.js rename to api/server/services/Endpoints/anthropic/buildOptions.js diff --git a/api/server/routes/endpoints/anthropic/index.js b/api/server/services/Endpoints/anthropic/index.js similarity index 100% rename from api/server/routes/endpoints/anthropic/index.js rename to api/server/services/Endpoints/anthropic/index.js diff --git a/api/server/routes/endpoints/anthropic/initializeClient.js b/api/server/services/Endpoints/anthropic/initializeClient.js similarity index 90% rename from api/server/routes/endpoints/anthropic/initializeClient.js rename to api/server/services/Endpoints/anthropic/initializeClient.js index 4700da6c8..575a21699 100644 --- a/api/server/routes/endpoints/anthropic/initializeClient.js +++ b/api/server/services/Endpoints/anthropic/initializeClient.js @@ -2,7 +2,7 @@ const { AnthropicClient } = require('~/app'); const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); const initializeClient = async ({ req, res, endpointOption }) => { - const { ANTHROPIC_API_KEY, ANTHROPIC_REVERSE_PROXY } = process.env; + const { ANTHROPIC_API_KEY, ANTHROPIC_REVERSE_PROXY, PROXY } = process.env; const expiresAt = req.body.key; const isUserProvided = ANTHROPIC_API_KEY === 'user_provided'; @@ -21,6 +21,7 @@ const initializeClient = async ({ req, res, endpointOption }) => { req, res, reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null, + proxy: PROXY ?? null, ...endpointOption, }); diff --git a/api/server/services/Endpoints/google/buildOptions.js b/api/server/services/Endpoints/google/buildOptions.js new file mode 100644 index 000000000..0f00bf82d --- /dev/null +++ b/api/server/services/Endpoints/google/buildOptions.js @@ -0,0 +1,16 @@ +const buildOptions = (endpoint, parsedBody) => { + const { examples, modelLabel, promptPrefix, ...rest } = parsedBody; + const endpointOption = { + examples, + endpoint, + modelLabel, + promptPrefix, + modelOptions: { + ...rest, + }, + }; + + return endpointOption; +}; + +module.exports = buildOptions; diff --git a/api/server/services/Endpoints/google/index.js b/api/server/services/Endpoints/google/index.js new file mode 100644 index 000000000..84e4bd597 --- /dev/null +++ b/api/server/services/Endpoints/google/index.js @@ -0,0 +1,8 @@ +const buildOptions = require('./buildOptions'); +const initializeClient = require('./initializeClient'); + +module.exports = { + // addTitle, // todo + buildOptions, + initializeClient, +}; diff --git a/api/server/services/Endpoints/google/initializeClient.js b/api/server/services/Endpoints/google/initializeClient.js new file mode 100644 index 000000000..27eb3e8ef --- /dev/null +++ b/api/server/services/Endpoints/google/initializeClient.js @@ -0,0 +1,35 @@ +const { GoogleClient } = require('~/app'); +const { EModelEndpoint } = require('~/server/services/Endpoints'); +const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); + +const initializeClient = async ({ req, res, endpointOption }) => { + const { GOOGLE_KEY, GOOGLE_REVERSE_PROXY, PROXY } = process.env; + const isUserProvided = GOOGLE_KEY === 'user_provided'; + const { key: expiresAt } = req.body; + + let userKey = null; + if (expiresAt && isUserProvided) { + checkUserKeyExpiry( + expiresAt, + 'Your Google key has expired. 
Please provide your JSON credentials again.', + ); + userKey = await getUserKey({ userId: req.user.id, name: EModelEndpoint.google }); + } + + const apiKey = isUserProvided ? userKey : require('~/data/auth.json'); + + const client = new GoogleClient(apiKey, { + req, + res, + reverseProxyUrl: GOOGLE_REVERSE_PROXY ?? null, + proxy: PROXY ?? null, + ...endpointOption, + }); + + return { + client, + apiKey, + }; +}; + +module.exports = initializeClient; diff --git a/api/server/routes/endpoints/gptPlugins/buildOptions.js b/api/server/services/Endpoints/gptPlugins/buildOptions.js similarity index 100% rename from api/server/routes/endpoints/gptPlugins/buildOptions.js rename to api/server/services/Endpoints/gptPlugins/buildOptions.js diff --git a/api/server/routes/endpoints/gptPlugins/index.js b/api/server/services/Endpoints/gptPlugins/index.js similarity index 100% rename from api/server/routes/endpoints/gptPlugins/index.js rename to api/server/services/Endpoints/gptPlugins/index.js diff --git a/api/server/routes/endpoints/gptPlugins/initializeClient.js b/api/server/services/Endpoints/gptPlugins/initializeClient.js similarity index 87% rename from api/server/routes/endpoints/gptPlugins/initializeClient.js rename to api/server/services/Endpoints/gptPlugins/initializeClient.js index 2ab04ec09..4abb2d2de 100644 --- a/api/server/routes/endpoints/gptPlugins/initializeClient.js +++ b/api/server/services/Endpoints/gptPlugins/initializeClient.js @@ -1,7 +1,7 @@ -const { PluginsClient } = require('../../../../app'); -const { isEnabled } = require('../../../utils'); -const { getAzureCredentials } = require('../../../../utils'); -const { getUserKey, checkUserKeyExpiry } = require('../../../services/UserService'); +const { PluginsClient } = require('~/app'); +const { isEnabled } = require('~/server/utils'); +const { getAzureCredentials } = require('~/utils'); +const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); const initializeClient = async ({ req, res, endpointOption }) => { const { diff --git a/api/server/routes/endpoints/gptPlugins/initializeClient.spec.js b/api/server/services/Endpoints/gptPlugins/initializeClient.spec.js similarity index 96% rename from api/server/routes/endpoints/gptPlugins/initializeClient.spec.js rename to api/server/services/Endpoints/gptPlugins/initializeClient.spec.js index b8d76ced2..5b772209c 100644 --- a/api/server/routes/endpoints/gptPlugins/initializeClient.spec.js +++ b/api/server/services/Endpoints/gptPlugins/initializeClient.spec.js @@ -1,12 +1,12 @@ // gptPlugins/initializeClient.spec.js +const { PluginsClient } = require('~/app'); const initializeClient = require('./initializeClient'); -const { PluginsClient } = require('../../../../app'); -const { getUserKey } = require('../../../services/UserService'); +const { getUserKey } = require('../../UserService'); // Mock getUserKey since it's the only function we want to mock -jest.mock('../../../services/UserService', () => ({ +jest.mock('~/server/services/UserService', () => ({ getUserKey: jest.fn(), - checkUserKeyExpiry: jest.requireActual('../../../services/UserService').checkUserKeyExpiry, + checkUserKeyExpiry: jest.requireActual('~/server/services/UserService').checkUserKeyExpiry, })); describe('gptPlugins/initializeClient', () => { diff --git a/api/server/services/Endpoints/index.js b/api/server/services/Endpoints/index.js new file mode 100644 index 000000000..bdb884c7f --- /dev/null +++ b/api/server/services/Endpoints/index.js @@ -0,0 +1,5 @@ +const schemas = require('./schemas'); + 
+module.exports = { + ...schemas, +}; diff --git a/api/server/routes/endpoints/openAI/addTitle.js b/api/server/services/Endpoints/openAI/addTitle.js similarity index 100% rename from api/server/routes/endpoints/openAI/addTitle.js rename to api/server/services/Endpoints/openAI/addTitle.js diff --git a/api/server/routes/endpoints/openAI/buildOptions.js b/api/server/services/Endpoints/openAI/buildOptions.js similarity index 100% rename from api/server/routes/endpoints/openAI/buildOptions.js rename to api/server/services/Endpoints/openAI/buildOptions.js diff --git a/api/server/routes/endpoints/openAI/index.js b/api/server/services/Endpoints/openAI/index.js similarity index 100% rename from api/server/routes/endpoints/openAI/index.js rename to api/server/services/Endpoints/openAI/index.js diff --git a/api/server/routes/endpoints/openAI/initializeClient.js b/api/server/services/Endpoints/openAI/initializeClient.js similarity index 100% rename from api/server/routes/endpoints/openAI/initializeClient.js rename to api/server/services/Endpoints/openAI/initializeClient.js diff --git a/api/server/routes/endpoints/openAI/initializeClient.spec.js b/api/server/services/Endpoints/openAI/initializeClient.spec.js similarity index 96% rename from api/server/routes/endpoints/openAI/initializeClient.spec.js rename to api/server/services/Endpoints/openAI/initializeClient.spec.js index 731d42e06..03f567744 100644 --- a/api/server/routes/endpoints/openAI/initializeClient.spec.js +++ b/api/server/services/Endpoints/openAI/initializeClient.spec.js @@ -1,11 +1,11 @@ +const { OpenAIClient } = require('~/app'); const initializeClient = require('./initializeClient'); -const { OpenAIClient } = require('../../../../app'); -const { getUserKey } = require('../../../services/UserService'); +const { getUserKey } = require('~/server/services/UserService'); // Mock getUserKey since it's the only function we want to mock -jest.mock('../../../services/UserService', () => ({ +jest.mock('~/server/services/UserService', () => ({ getUserKey: jest.fn(), - checkUserKeyExpiry: jest.requireActual('../../../services/UserService').checkUserKeyExpiry, + checkUserKeyExpiry: jest.requireActual('~/server/services/UserService').checkUserKeyExpiry, })); describe('initializeClient', () => { diff --git a/api/server/routes/endpoints/schemas.js b/api/server/services/Endpoints/schemas.js similarity index 87% rename from api/server/routes/endpoints/schemas.js rename to api/server/services/Endpoints/schemas.js index 84ad36b33..4f786feab 100644 --- a/api/server/routes/endpoints/schemas.js +++ b/api/server/services/Endpoints/schemas.js @@ -18,10 +18,44 @@ const alternateName = { [EModelEndpoint.bingAI]: 'Bing', [EModelEndpoint.chatGPTBrowser]: 'ChatGPT', [EModelEndpoint.gptPlugins]: 'Plugins', - [EModelEndpoint.google]: 'PaLM', + [EModelEndpoint.google]: 'Google', [EModelEndpoint.anthropic]: 'Anthropic', }; +const endpointSettings = { + [EModelEndpoint.google]: { + model: { + default: 'chat-bison', + }, + maxOutputTokens: { + min: 1, + max: 2048, + step: 1, + default: 1024, + }, + temperature: { + min: 0, + max: 1, + step: 0.01, + default: 0.2, + }, + topP: { + min: 0, + max: 1, + step: 0.01, + default: 0.8, + }, + topK: { + min: 1, + max: 40, + step: 0.01, + default: 40, + }, + }, +}; + +const google = endpointSettings[EModelEndpoint.google]; + const supportsFiles = { [EModelEndpoint.openAI]: true, [EModelEndpoint.assistant]: true, @@ -158,22 +192,24 @@ const googleSchema = tConversationSchema }) .transform((obj) => ({ ...obj, - model: obj.model ?? 
'chat-bison', + model: obj.model ?? google.model.default, + modelLabel: obj.modelLabel ?? null, promptPrefix: obj.promptPrefix ?? null, - temperature: obj.temperature ?? 0.2, - maxOutputTokens: obj.maxOutputTokens ?? 1024, - topP: obj.topP ?? 0.95, - topK: obj.topK ?? 40, + examples: obj.examples ?? [{ input: { content: '' }, output: { content: '' } }], + temperature: obj.temperature ?? google.temperature.default, + maxOutputTokens: obj.maxOutputTokens ?? google.maxOutputTokens.default, + topP: obj.topP ?? google.topP.default, + topK: obj.topK ?? google.topK.default, })) .catch(() => ({ - model: 'chat-bison', + model: google.model.default, modelLabel: null, promptPrefix: null, - temperature: 0.2, - maxOutputTokens: 1024, - topP: 0.95, - topK: 40, + examples: [{ input: { content: '' }, output: { content: '' } }], + temperature: google.temperature.default, + maxOutputTokens: google.maxOutputTokens.default, + topP: google.topP.default, + topK: google.topK.default, })); const bingAISchema = tConversationSchema @@ -385,7 +421,13 @@ const getResponseSender = (endpointOption) => { } if (endpoint === EModelEndpoint.google) { - return modelLabel ?? 'PaLM2'; + if (modelLabel) { + return modelLabel; + } else if (model && model.includes('code')) { + return 'Codey'; + } + + return 'PaLM2'; } return ''; @@ -399,4 +441,5 @@ module.exports = { openAIModels, visionModels, alternateName, + endpointSettings, }; diff --git a/api/server/services/Files/images/validate.js b/api/server/services/Files/images/validate.js index acffedd60..0e1965749 100644 --- a/api/server/services/Files/images/validate.js +++ b/api/server/services/Files/images/validate.js @@ -1,4 +1,4 @@ -const { visionModels } = require('~/server/routes/endpoints/schemas'); +const { visionModels } = require('~/server/services/Endpoints'); function validateVisionModel(model) { if (!model) { diff --git a/api/typedefs.js b/api/typedefs.js index d1796f805..c12254e32 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -247,7 +247,7 @@ * @property {string} azureOpenAIApiKey - The API key for Azure OpenAI. * @property {boolean} useAzurePlugins - Flag to indicate if Azure plugins are used. * @property {boolean} userProvidedOpenAI - Flag to indicate if OpenAI API key is user provided. - * @property {string} palmKey - The Palm key. + * @property {string} googleKey - The Google key. * @property {boolean|{userProvide: boolean}} [openAI] - Flag to indicate if OpenAI endpoint is user provided, or its configuration. * @property {boolean|{userProvide: boolean}} [assistant] - Flag to indicate if Assistant endpoint is user provided, or its configuration. * @property {boolean|{userProvide: boolean}} [azureOpenAI] - Flag to indicate if Azure OpenAI endpoint is user provided, or its configuration.
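With the schema hunks above applied, the Google parameter ranges live in one `endpointSettings` object instead of being hard-coded in the Zod schema and the client sliders, and the response sender is derived from the model family. A quick usage sketch (return values taken directly from the `endpointSettings` map and the `getResponseSender` branch above):

```js
const {
  getResponseSender,
  endpointSettings,
  EModelEndpoint,
} = require('~/server/services/Endpoints');

const google = endpointSettings[EModelEndpoint.google];
google.model.default; // 'chat-bison'
google.maxOutputTokens; // { min: 1, max: 2048, step: 1, default: 1024 }

// Sender naming: a custom label wins, then 'Codey' for code models, else 'PaLM2'.
getResponseSender({ endpoint: EModelEndpoint.google, modelLabel: 'My Bot' }); // 'My Bot'
getResponseSender({ endpoint: EModelEndpoint.google, model: 'codechat-bison' }); // 'Codey'
getResponseSender({ endpoint: EModelEndpoint.google, model: 'chat-bison' }); // 'PaLM2'
```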
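The `api/utils/tokens.js` hunks that follow namespace `maxTokensMap` by endpoint, so context-window lookups now take an optional endpoint argument that defaults to `openAI`. Expected results under that refactor (values from the map below; prefix keys like `'chat-'` match by substring when no exact key exists):

```js
const { EModelEndpoint } = require('~/server/services/Endpoints');
const { getModelMaxTokens, matchModelName } = require('~/utils');

getModelMaxTokens('codechat-bison-32k', EModelEndpoint.google); // 31000 (exact match)
getModelMaxTokens('chat-bison', EModelEndpoint.google); // 8187 (partial match on 'chat-')
getModelMaxTokens('claude-2.1', EModelEndpoint.anthropic); // 200000
getModelMaxTokens('gpt-4-32k-unknown'); // 32767 (endpoint defaults to openAI)

matchModelName('gpt-4-1106-preview'); // 'gpt-4-1106'
```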
diff --git a/api/utils/tokens.js b/api/utils/tokens.js index c956d53ce..eeca06639 100644 --- a/api/utils/tokens.js +++ b/api/utils/tokens.js @@ -1,3 +1,5 @@ +const { EModelEndpoint } = require('~/server/services/Endpoints'); + const models = [ 'text-davinci-003', 'text-davinci-002', @@ -39,20 +41,37 @@ const models = [ // Order is important here: by model series and context size (gpt-4 then gpt-3, ascending) const maxTokensMap = { - 'gpt-4': 8191, - 'gpt-4-0613': 8191, - 'gpt-4-32k': 32767, - 'gpt-4-32k-0314': 32767, - 'gpt-4-32k-0613': 32767, - 'gpt-3.5-turbo': 4095, - 'gpt-3.5-turbo-0613': 4095, - 'gpt-3.5-turbo-0301': 4095, - 'gpt-3.5-turbo-16k': 15999, - 'gpt-3.5-turbo-16k-0613': 15999, - 'gpt-3.5-turbo-1106': 16380, // -5 from max - 'gpt-4-1106': 127995, // -5 from max - 'claude-2.1': 200000, - 'claude-': 100000, + [EModelEndpoint.openAI]: { + 'gpt-4': 8191, + 'gpt-4-0613': 8191, + 'gpt-4-32k': 32767, + 'gpt-4-32k-0314': 32767, + 'gpt-4-32k-0613': 32767, + 'gpt-3.5-turbo': 4095, + 'gpt-3.5-turbo-0613': 4095, + 'gpt-3.5-turbo-0301': 4095, + 'gpt-3.5-turbo-16k': 15999, + 'gpt-3.5-turbo-16k-0613': 15999, + 'gpt-3.5-turbo-1106': 16380, // -5 from max + 'gpt-4-1106': 127995, // -5 from max + }, + [EModelEndpoint.google]: { + /* Max I/O is 32k combined, so -1000 to leave room for response */ + 'text-bison-32k': 31000, + 'chat-bison-32k': 31000, + 'code-bison-32k': 31000, + 'codechat-bison-32k': 31000, + /* Codey, -5 from max: 6144 */ + 'code-': 6139, + 'codechat-': 6139, + /* PaLM2, -5 from max: 8192 */ + 'text-': 8187, + 'chat-': 8187, + }, + [EModelEndpoint.anthropic]: { + 'claude-2.1': 200000, + 'claude-': 100000, + }, }; /** @@ -60,6 +79,7 @@ const maxTokensMap = { * it searches for partial matches within the model name, checking keys in reverse order. * * @param {string} modelName - The name of the model to look up. + * @param {string} endpoint - The endpoint (default is 'openAI'). * @returns {number|undefined} The maximum tokens for the given model or undefined if no match is found. * * @example @@ -67,19 +87,24 @@ const maxTokensMap = { * getModelMaxTokens('gpt-4-32k-unknown'); // Returns 32767 * getModelMaxTokens('unknown-model'); // Returns undefined */ -function getModelMaxTokens(modelName) { +function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI) { if (typeof modelName !== 'string') { return undefined; } - if (maxTokensMap[modelName]) { - return maxTokensMap[modelName]; + const tokensMap = maxTokensMap[endpoint]; + if (!tokensMap) { + return undefined; } - const keys = Object.keys(maxTokensMap); + if (tokensMap[modelName]) { + return tokensMap[modelName]; + } + + const keys = Object.keys(tokensMap); for (let i = keys.length - 1; i >= 0; i--) { if (modelName.includes(keys[i])) { - return maxTokensMap[keys[i]]; + return tokensMap[keys[i]]; } } @@ -91,6 +116,7 @@ function getModelMaxTokens(modelName) { * it searches for partial matches within the model name, checking keys in reverse order. * * @param {string} modelName - The name of the model to look up. + * @param {string} endpoint - The endpoint (default is 'openAI'). * @returns {string|undefined} The model name key for the given model; returns input if no match is found and is string. 
* * @example @@ -98,16 +124,21 @@ function getModelMaxTokens(modelName) { * matchModelName('gpt-4-32k-unknown'); // Returns 'gpt-4-32k' * matchModelName('unknown-model'); // Returns undefined */ -function matchModelName(modelName) { +function matchModelName(modelName, endpoint = EModelEndpoint.openAI) { if (typeof modelName !== 'string') { return undefined; } - if (maxTokensMap[modelName]) { + const tokensMap = maxTokensMap[endpoint]; + if (!tokensMap) { return modelName; } - const keys = Object.keys(maxTokensMap); + if (tokensMap[modelName]) { + return modelName; + } + + const keys = Object.keys(tokensMap); for (let i = keys.length - 1; i >= 0; i--) { if (modelName.includes(keys[i])) { return keys[i]; diff --git a/api/utils/tokens.spec.js b/api/utils/tokens.spec.js index 8b5150770..2430590c3 100644 --- a/api/utils/tokens.spec.js +++ b/api/utils/tokens.spec.js @@ -1,16 +1,23 @@ +const { EModelEndpoint } = require('~/server/services/Endpoints'); const { getModelMaxTokens, matchModelName, maxTokensMap } = require('./tokens'); describe('getModelMaxTokens', () => { test('should return correct tokens for exact match', () => { - expect(getModelMaxTokens('gpt-4-32k-0613')).toBe(maxTokensMap['gpt-4-32k-0613']); + expect(getModelMaxTokens('gpt-4-32k-0613')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4-32k-0613'], + ); }); test('should return correct tokens for partial match', () => { - expect(getModelMaxTokens('gpt-4-32k-unknown')).toBe(maxTokensMap['gpt-4-32k']); + expect(getModelMaxTokens('gpt-4-32k-unknown')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4-32k'], + ); }); test('should return correct tokens for partial match (OpenRouter)', () => { - expect(getModelMaxTokens('openai/gpt-4-32k')).toBe(maxTokensMap['gpt-4-32k']); + expect(getModelMaxTokens('openai/gpt-4-32k')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4-32k'], + ); }); test('should return undefined for no match', () => { @@ -19,12 +26,14 @@ describe('getModelMaxTokens', () => { test('should return correct tokens for another exact match', () => { expect(getModelMaxTokens('gpt-3.5-turbo-16k-0613')).toBe( - maxTokensMap['gpt-3.5-turbo-16k-0613'], + maxTokensMap[EModelEndpoint.openAI]['gpt-3.5-turbo-16k-0613'], ); }); test('should return correct tokens for another partial match', () => { - expect(getModelMaxTokens('gpt-3.5-turbo-unknown')).toBe(maxTokensMap['gpt-3.5-turbo']); + expect(getModelMaxTokens('gpt-3.5-turbo-unknown')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-3.5-turbo'], + ); }); test('should return undefined for undefined input', () => { @@ -41,26 +50,34 @@ describe('getModelMaxTokens', () => { // 11/06 Update test('should return correct tokens for gpt-3.5-turbo-1106 exact match', () => { - expect(getModelMaxTokens('gpt-3.5-turbo-1106')).toBe(maxTokensMap['gpt-3.5-turbo-1106']); + expect(getModelMaxTokens('gpt-3.5-turbo-1106')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-3.5-turbo-1106'], + ); }); test('should return correct tokens for gpt-4-1106 exact match', () => { - expect(getModelMaxTokens('gpt-4-1106')).toBe(maxTokensMap['gpt-4-1106']); + expect(getModelMaxTokens('gpt-4-1106')).toBe(maxTokensMap[EModelEndpoint.openAI]['gpt-4-1106']); }); test('should return correct tokens for gpt-3.5-turbo-1106 partial match', () => { expect(getModelMaxTokens('something-/gpt-3.5-turbo-1106')).toBe( - maxTokensMap['gpt-3.5-turbo-1106'], + maxTokensMap[EModelEndpoint.openAI]['gpt-3.5-turbo-1106'], ); expect(getModelMaxTokens('gpt-3.5-turbo-1106/something-/')).toBe( - maxTokensMap['gpt-3.5-turbo-1106'], + 
maxTokensMap[EModelEndpoint.openAI]['gpt-3.5-turbo-1106'], ); }); test('should return correct tokens for gpt-4-1106 partial match', () => { - expect(getModelMaxTokens('gpt-4-1106/something')).toBe(maxTokensMap['gpt-4-1106']); - expect(getModelMaxTokens('gpt-4-1106-preview')).toBe(maxTokensMap['gpt-4-1106']); - expect(getModelMaxTokens('gpt-4-1106-vision-preview')).toBe(maxTokensMap['gpt-4-1106']); + expect(getModelMaxTokens('gpt-4-1106/something')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4-1106'], + ); + expect(getModelMaxTokens('gpt-4-1106-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4-1106'], + ); + expect(getModelMaxTokens('gpt-4-1106-vision-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4-1106'], + ); }); test('should return correct tokens for Anthropic models', () => { @@ -74,13 +91,36 @@ describe('getModelMaxTokens', () => { 'claude-instant-1-100k', ]; - const claude21MaxTokens = maxTokensMap['claude-2.1']; - const claudeMaxTokens = maxTokensMap['claude-']; + const claudeMaxTokens = maxTokensMap[EModelEndpoint.anthropic]['claude-']; + const claude21MaxTokens = maxTokensMap[EModelEndpoint.anthropic]['claude-2.1']; models.forEach((model) => { const expectedTokens = model === 'claude-2.1' ? claude21MaxTokens : claudeMaxTokens; - expect(getModelMaxTokens(model)).toEqual(expectedTokens); + expect(getModelMaxTokens(model, EModelEndpoint.anthropic)).toEqual(expectedTokens); }); }); + + // Tests for Google models + test('should return correct tokens for exact match - Google models', () => { + expect(getModelMaxTokens('text-bison-32k', EModelEndpoint.google)).toBe( + maxTokensMap[EModelEndpoint.google]['text-bison-32k'], + ); + expect(getModelMaxTokens('codechat-bison-32k', EModelEndpoint.google)).toBe( + maxTokensMap[EModelEndpoint.google]['codechat-bison-32k'], + ); + }); + + test('should return undefined for no match - Google models', () => { + expect(getModelMaxTokens('unknown-google-model', EModelEndpoint.google)).toBeUndefined(); + }); + + test('should return correct tokens for partial match - Google models', () => { + expect(getModelMaxTokens('code-', EModelEndpoint.google)).toBe( + maxTokensMap[EModelEndpoint.google]['code-'], + ); + expect(getModelMaxTokens('chat-', EModelEndpoint.google)).toBe( + maxTokensMap[EModelEndpoint.google]['chat-'], + ); + }); }); describe('matchModelName', () => { @@ -122,4 +162,21 @@ describe('matchModelName', () => { expect(matchModelName('gpt-4-1106-preview')).toBe('gpt-4-1106'); expect(matchModelName('gpt-4-1106-vision-preview')).toBe('gpt-4-1106'); }); + + // Tests for Google models + it('should return the exact model name if it exists in maxTokensMap - Google models', () => { + expect(matchModelName('text-bison-32k', EModelEndpoint.google)).toBe('text-bison-32k'); + expect(matchModelName('codechat-bison-32k', EModelEndpoint.google)).toBe('codechat-bison-32k'); + }); + + it('should return the input model name if no match is found - Google models', () => { + expect(matchModelName('unknown-google-model', EModelEndpoint.google)).toBe( + 'unknown-google-model', + ); + }); + + it('should return the closest matching key for partial matches - Google models', () => { + expect(matchModelName('code-', EModelEndpoint.google)).toBe('code-'); + expect(matchModelName('chat-', EModelEndpoint.google)).toBe('chat-'); + }); }); diff --git a/client/src/components/Chat/Menus/Endpoints/Icons.tsx b/client/src/components/Chat/Menus/Endpoints/Icons.tsx index 90780f3bd..56aed93f1 100644 --- 
a/client/src/components/Chat/Menus/Endpoints/Icons.tsx +++ b/client/src/components/Chat/Menus/Endpoints/Icons.tsx @@ -5,7 +5,7 @@ import { AnthropicIcon, AzureMinimalIcon, BingAIMinimalIcon, - PaLMinimalIcon, + GoogleMinimalIcon, LightningIcon, } from '~/components/svg'; import { cn } from '~/utils'; @@ -16,7 +16,7 @@ export const icons = { [EModelEndpoint.gptPlugins]: MinimalPlugin, [EModelEndpoint.anthropic]: AnthropicIcon, [EModelEndpoint.chatGPTBrowser]: LightningIcon, - [EModelEndpoint.google]: PaLMinimalIcon, + [EModelEndpoint.google]: GoogleMinimalIcon, [EModelEndpoint.bingAI]: BingAIMinimalIcon, [EModelEndpoint.assistant]: ({ className = '' }) => ( diff --git a/client/src/components/Chat/Messages/Content/EditMessage.tsx b/client/src/components/Chat/Messages/Content/EditMessage.tsx index 141fbce44..1a965e0a0 100644 --- a/client/src/components/Chat/Messages/Content/EditMessage.tsx +++ b/client/src/components/Chat/Messages/Content/EditMessage.tsx @@ -1,5 +1,5 @@ import { useRef } from 'react'; -import { useUpdateMessageMutation } from 'librechat-data-provider'; +import { useUpdateMessageMutation, EModelEndpoint } from 'librechat-data-provider'; import Container from '~/components/Messages/Content/Container'; import { useChatContext } from '~/Providers'; import type { TEditProps } from '~/common'; @@ -18,6 +18,7 @@ const EditMessage = ({ const textEditor = useRef(null); const { conversationId, parentMessageId, messageId } = message; + const { endpoint } = conversation ?? { endpoint: null }; const updateMessageMutation = useUpdateMessageMutation(conversationId ?? ''); const localize = useLocalize(); @@ -94,7 +95,9 @@ const EditMessage = ({
@@ -92,16 +90,18 @@ export default function Settings({ conversation, setOption, models, readonly }:
setTemperature(value ?? 0.2)} - max={1} - min={0} - step={0.01} + onChange={(value) => setTemperature(value ?? google.temperature.default)} + max={google.temperature.max} + min={google.temperature.min} + step={google.temperature.step} controls={false} className={cn( defaultTextProps, @@ -114,18 +114,18 @@ export default function Settings({ conversation, setOption, models, readonly }:
setTemperature(value[0])} - doubleClickHandler={() => setTemperature(0.2)} - max={1} - min={0} - step={0.01} + doubleClickHandler={() => setTemperature(google.temperature.default)} + max={google.temperature.max} + min={google.temperature.min} + step={google.temperature.step} className="flex h-4 w-full" />
- {!codeChat && ( + {!isTextModel && ( <> @@ -133,17 +133,17 @@ export default function Settings({ conversation, setOption, models, readonly }: setTopP(value ?? '0.95')} - max={1} - min={0} - step={0.01} + onChange={(value) => setTopP(value ?? google.topP.default)} + max={google.topP.max} + min={google.topP.min} + step={google.topP.step} controls={false} className={cn( defaultTextProps, @@ -156,12 +156,12 @@ export default function Settings({ conversation, setOption, models, readonly }:
setTopP(value[0])} - doubleClickHandler={() => setTopP(0.95)} - max={1} - min={0} - step={0.01} + doubleClickHandler={() => setTopP(google.topP.default)} + max={google.topP.max} + min={google.topP.min} + step={google.topP.step} className="flex h-4 w-full" /> @@ -174,17 +174,17 @@ export default function Settings({ conversation, setOption, models, readonly }: setTopK(value ?? 40)} - max={40} - min={1} - step={0.01} + onChange={(value) => setTopK(value ?? google.topK.default)} + max={google.topK.max} + min={google.topK.min} + step={google.topK.step} controls={false} className={cn( defaultTextProps, @@ -197,12 +197,12 @@ export default function Settings({ conversation, setOption, models, readonly }: setTopK(value[0])} - doubleClickHandler={() => setTopK(40)} - max={40} - min={1} - step={0.01} + doubleClickHandler={() => setTopK(google.topK.default)} + max={google.topK.max} + min={google.topK.min} + step={google.topK.step} className="flex h-4 w-full" /> @@ -216,17 +216,17 @@ export default function Settings({ conversation, setOption, models, readonly }: setMaxOutputTokens(value ?? 1024)} - max={1024} - min={1} - step={1} + onChange={(value) => setMaxOutputTokens(value ?? google.maxOutputTokens.default)} + max={google.maxOutputTokens.max} + min={google.maxOutputTokens.min} + step={google.maxOutputTokens.step} controls={false} className={cn( defaultTextProps, @@ -239,12 +239,12 @@ export default function Settings({ conversation, setOption, models, readonly }: setMaxOutputTokens(value[0])} - doubleClickHandler={() => setMaxOutputTokens(1024)} - max={1024} - min={1} - step={1} + doubleClickHandler={() => setMaxOutputTokens(google.maxOutputTokens.default)} + max={google.maxOutputTokens.max} + min={google.maxOutputTokens.min} + step={google.maxOutputTokens.step} className="flex h-4 w-full" /> diff --git a/client/src/components/Endpoints/Settings/MultiView/GoogleSettings.tsx b/client/src/components/Endpoints/Settings/MultiView/GoogleSettings.tsx index 27b0e7623..19cce9bd7 100644 --- a/client/src/components/Endpoints/Settings/MultiView/GoogleSettings.tsx +++ b/client/src/components/Endpoints/Settings/MultiView/GoogleSettings.tsx @@ -16,7 +16,7 @@ export default function GoogleView({ conversation, models, isPreset = false }) { const { showExamples, isCodeChat } = optionSettings; return showExamples && !isCodeChat ? 
( + + + ); +} diff --git a/client/src/components/svg/GoogleMinimalIcon.tsx b/client/src/components/svg/GoogleMinimalIcon.tsx new file mode 100644 index 000000000..903ea044f --- /dev/null +++ b/client/src/components/svg/GoogleMinimalIcon.tsx @@ -0,0 +1,15 @@ +import { cn } from '~/utils'; +export default function GoogleMinimalIcon({ className = '' }: { className?: string }) { + return ( + + + + ); +} diff --git a/client/src/components/svg/PaLMIcon.tsx b/client/src/components/svg/PaLMIcon.tsx new file mode 100644 index 000000000..18dd93ce3 --- /dev/null +++ b/client/src/components/svg/PaLMIcon.tsx @@ -0,0 +1,50 @@ +export default function PaLMIcon({ + size = 25, + className = '', +}: { + size?: number; + className?: string; +}) { + return ( + + + + + + + + + + ); +} diff --git a/client/src/components/svg/PaLMinimalIcon.tsx b/client/src/components/svg/PaLMinimalIcon.tsx index d69d24cc4..1156c0594 100644 --- a/client/src/components/svg/PaLMinimalIcon.tsx +++ b/client/src/components/svg/PaLMinimalIcon.tsx @@ -1,6 +1,5 @@ -import React from 'react'; - -export default function PaLMinimalIcon() { +import { cn } from '~/utils'; +export default function PaLMinimalIcon({ className = '' }: { className?: string }) { return ( e === endpoint); const continueSupported = diff --git a/client/src/hooks/useGenerationsByLatest.ts b/client/src/hooks/useGenerationsByLatest.ts index 35eaa87e6..acbb3baa3 100644 --- a/client/src/hooks/useGenerationsByLatest.ts +++ b/client/src/hooks/useGenerationsByLatest.ts @@ -18,12 +18,12 @@ export default function useGenerationsByLatest({ }: TUseGenerations) { const { error, messageId, searchResult, finish_reason, isCreatedByUser } = message ?? {}; const isEditableEndpoint = !![ - EModelEndpoint.azureOpenAI, EModelEndpoint.openAI, + EModelEndpoint.google, EModelEndpoint.assistant, + EModelEndpoint.anthropic, EModelEndpoint.gptPlugins, - EModelEndpoint.anthropic, - EModelEndpoint.anthropic, + EModelEndpoint.azureOpenAI, ].find((e) => e === endpoint); const continueSupported = diff --git a/client/src/localization/languages/Ar.tsx b/client/src/localization/languages/Ar.tsx index 5bda94c6a..85d8f5de5 100644 --- a/client/src/localization/languages/Ar.tsx +++ b/client/src/localization/languages/Ar.tsx @@ -133,7 +133,7 @@ export default { 'Top-k يغير كيفية اختيار النموذج للرموز للإخراج. top-k من 1 يعني أن الرمز المحدد هو الأكثر احتمالية بين جميع الرموز في مفردات النموذج (يسمى أيضًا الترميز الجشعي)، بينما top-k من 3 يعني أن الرمز التالي يتم اختياره من بين الرموز الثلاثة الأكثر احتمالية (باستخدام الحرارة).', com_endpoint_google_maxoutputtokens: 'الحد الأقصى لعدد الرموز التي يمكن إنشاؤها في الرد. حدد قيمة أقل للردود الأقصر وقيمة أعلى للردود الأطول.', - com_endpoint_google_custom_name_placeholder: 'قم بتعيين اسم مخصص لـ PaLM2', + com_endpoint_google_custom_name_placeholder: 'قم بتعيين اسم مخصص لـ Google', com_endpoint_prompt_prefix_placeholder: 'قم بتعيين تعليمات مخصصة أو سياق. يتم تجاهله إذا كان فارغًا.', com_endpoint_custom_name: 'اسم مخصص', diff --git a/client/src/localization/languages/Br.tsx b/client/src/localization/languages/Br.tsx index 658508800..830b9e8fc 100644 --- a/client/src/localization/languages/Br.tsx +++ b/client/src/localization/languages/Br.tsx @@ -133,7 +133,7 @@ export default { 'Top-k muda como o modelo seleciona tokens para a saída. 
Um top-k de 1 significa que o token selecionado é o mais provável entre todos os tokens no vocabulário do modelo (também chamado de decodificação gananciosa), enquanto um top-k de 3 significa que o próximo token é selecionado entre os 3 tokens mais prováveis (usando temperatura).', com_endpoint_google_maxoutputtokens: 'Número máximo de tokens que podem ser gerados na resposta. Especifique um valor menor para respostas mais curtas e um valor maior para respostas mais longas.', - com_endpoint_google_custom_name_placeholder: 'Defina um nome personalizado para o PaLM2', + com_endpoint_google_custom_name_placeholder: 'Defina um nome personalizado para o Google', com_endpoint_prompt_prefix_placeholder: 'Defina instruções ou contexto personalizados. Ignorado se vazio.', com_endpoint_custom_name: 'Nome Personalizado', diff --git a/client/src/localization/languages/De.tsx b/client/src/localization/languages/De.tsx index a472c0869..7abf98e87 100644 --- a/client/src/localization/languages/De.tsx +++ b/client/src/localization/languages/De.tsx @@ -107,7 +107,7 @@ export default { 'Top-k ändert, wie das Modell Token für die Ausgabe auswählt. Ein Top-k von 1 bedeutet, dass das ausgewählte Token das wahrscheinlichste unter allen Token im Vokabular des Modells ist (auch gierige Dekodierung genannt), während ein Top-k von 3 bedeutet, dass das nächste Token aus den drei wahrscheinlichsten Token ausgewählt wird (unter Verwendung der Temperatur).', com_endpoint_google_maxoutputtokens: 'Maximale Anzahl von Token, die in der Antwort erzeugt werden können. Geben Sie einen niedrigeren Wert für kürzere Antworten und einen höheren Wert für längere Antworten an.', - com_endpoint_google_custom_name_placeholder: 'Benutzerdefinierter Name für PaLM2', + com_endpoint_google_custom_name_placeholder: 'Benutzerdefinierter Name für Google', com_endpoint_google_prompt_prefix_placeholder: 'Benutzerdefinierte Anweisungen oder Kontext festlegen. Wird ignoriert, wenn leer.', com_endpoint_custom_name: 'Benutzerdefinierter Name', diff --git a/client/src/localization/languages/Eng.tsx b/client/src/localization/languages/Eng.tsx index 2056c8f9f..8b3db4a56 100644 --- a/client/src/localization/languages/Eng.tsx +++ b/client/src/localization/languages/Eng.tsx @@ -137,7 +137,7 @@ export default { 'Top-k changes how the model selects tokens for output. A top-k of 1 means the selected token is the most probable among all tokens in the model\'s vocabulary (also called greedy decoding), while a top-k of 3 means that the next token is selected from among the 3 most probable tokens (using temperature).', com_endpoint_google_maxoutputtokens: ' Maximum number of tokens that can be generated in the response. Specify a lower value for shorter responses and a higher value for longer responses.', - com_endpoint_google_custom_name_placeholder: 'Set a custom name for PaLM2', + com_endpoint_google_custom_name_placeholder: 'Set a custom name for Google', com_endpoint_prompt_prefix_placeholder: 'Set custom instructions or context. Ignored if empty.', com_endpoint_custom_name: 'Custom Name', com_endpoint_prompt_prefix: 'Prompt Prefix', diff --git a/client/src/localization/languages/Es.tsx b/client/src/localization/languages/Es.tsx index 95aa75828..0fafba05a 100644 --- a/client/src/localization/languages/Es.tsx +++ b/client/src/localization/languages/Es.tsx @@ -138,7 +138,7 @@ export default { 'Establece instrucciones o contexto personalizado. 
Ignorado si está vacío.', com_endpoint_google_maxoutputtokens: 'Número máximo de tokens que se pueden generar en la respuesta. Especifica un valor menor para respuestas más cortas y un valor mayor para respuestas más largas.', - com_endpoint_google_custom_name_placeholder: 'Establece un nombre personalizado para PaLM2', + com_endpoint_google_custom_name_placeholder: 'Establece un nombre personalizado para Google', com_endpoint_prompt_prefix_placeholder: 'Establece instrucciones o contexto personalizados. Se ignora si está vacío.', com_endpoint_custom_name: 'Nombre personalizado', diff --git a/client/src/localization/languages/Fr.tsx b/client/src/localization/languages/Fr.tsx index 3d46067c8..a97aea8c4 100644 --- a/client/src/localization/languages/Fr.tsx +++ b/client/src/localization/languages/Fr.tsx @@ -140,7 +140,7 @@ export default { 'Top-k change la façon dont le modèle sélectionne les jetons pour la sortie. Un top-k de 1 signifie que le jeton sélectionné est le plus probable parmi tous les jetons du vocabulaire du modèle (également appelé décodage glouton), tandis qu\'un top-k de 3 signifie que le jeton suivant est sélectionné parmi les 3 jetons les plus probables (en utilisant la température).', com_endpoint_google_maxoutputtokens: 'Nombre maximum de jetons qui peuvent être générés dans la réponse. Spécifiez une valeur plus faible pour des réponses plus courtes et une valeur plus élevée pour des réponses plus longues.', - com_endpoint_google_custom_name_placeholder: 'Définir un nom personnalisé pour PaLM2', + com_endpoint_google_custom_name_placeholder: 'Définir un nom personnalisé pour Google', com_endpoint_google_prompt_prefix_placeholder: 'Définir des instructions ou un contexte personnalisés. Ignoré si vide.', com_endpoint_custom_name: 'Nom personnalisé', diff --git a/client/src/localization/languages/It.tsx b/client/src/localization/languages/It.tsx index a57f8bdc8..9c274f541 100644 --- a/client/src/localization/languages/It.tsx +++ b/client/src/localization/languages/It.tsx @@ -135,7 +135,7 @@ export default { 'Top-k cambia come il modello seleziona i token per l\'output. Un top-k di 1 significa che il token selezionato è il più probabile tra tutti i token nel vocabolario del modello (anche chiamato decodifica greedy), mentre un top-k di 3 significa che il token successivo è selezionato tra i 3 token più probabili (usando temperature).', com_endpoint_google_maxoutputtokens: 'Numero massimo di token che possono essere generati nella risposta. Specifica un valore più basso per risposte più corte e un valore più alto per risposte più lunghe.', - com_endpoint_google_custom_name_placeholder: 'Imposta un nome personalizzato per PaLM2', + com_endpoint_google_custom_name_placeholder: 'Imposta un nome personalizzato per Google', com_endpoint_prompt_prefix_placeholder: 'Imposta istruzioni o contesto personalizzati. 
Ignorato se vuoto.', com_endpoint_custom_name: 'Nome personalizzato', diff --git a/client/src/localization/languages/Jp.tsx b/client/src/localization/languages/Jp.tsx index b6de73da4..fca0eaa32 100644 --- a/client/src/localization/languages/Jp.tsx +++ b/client/src/localization/languages/Jp.tsx @@ -133,7 +133,7 @@ export default { 'Top-k はモデルがトークンをどのように選択して出力するかを変更します。top-kが1の場合はモデルの語彙に含まれるすべてのトークンの中で最も確率が高い1つが選択されます(greedy decodingと呼ばれている)。top-kが3の場合は上位3つのトークンの中から選択されます。(temperatureを使用)', com_endpoint_google_maxoutputtokens: ' 生成されるレスポンスの最大トークン数。短いレスポンスには低い値を、長いレスポンスには高い値を指定します。', - com_endpoint_google_custom_name_placeholder: 'PaLM2のカスタム名を設定する', + com_endpoint_google_custom_name_placeholder: 'Googleのカスタム名を設定する', com_endpoint_prompt_prefix_placeholder: 'custom instructions か context を設定する。空の場合は無視されます。', com_endpoint_custom_name: 'プリセット名', diff --git a/client/src/localization/languages/Ko.tsx b/client/src/localization/languages/Ko.tsx index 3ad10a1c8..490ca4008 100644 --- a/client/src/localization/languages/Ko.tsx +++ b/client/src/localization/languages/Ko.tsx @@ -124,7 +124,7 @@ export default { 'Top-k는 모델이 출력에 사용할 토큰을 선택하는 방식을 변경합니다. top-k가 1인 경우 모델의 어휘 중 가장 확률이 높은 토큰이 선택됩니다(greedy decoding). top-k가 3인 경우 다음 토큰은 가장 확률이 높은 3개의 토큰 중에서 선택됩니다(temperature 사용).', com_endpoint_google_maxoutputtokens: '응답에서 생성할 수 있는 최대 토큰 수입니다. 짧은 응답에는 낮은 값을, 긴 응답에는 높은 값을 지정하세요.', - com_endpoint_google_custom_name_placeholder: 'PaLM2에 대한 사용자 정의 이름 설정', + com_endpoint_google_custom_name_placeholder: 'Google에 대한 사용자 정의 이름 설정', com_endpoint_prompt_prefix_placeholder: '사용자 정의 지시사항 또는 컨텍스트를 설정하세요. 비어 있으면 무시됩니다.', com_endpoint_custom_name: '사용자 정의 이름', diff --git a/client/src/localization/languages/Nl.tsx b/client/src/localization/languages/Nl.tsx index c0b53dd6e..73e776bec 100644 --- a/client/src/localization/languages/Nl.tsx +++ b/client/src/localization/languages/Nl.tsx @@ -134,7 +134,7 @@ export default { 'Top-k verandert hoe het model tokens selecteert voor uitvoer. Een top-k van 1 betekent dat het geselecteerde token het meest waarschijnlijk is van alle tokens in de vocabulaire van het model (ook wel \'greedy decoding\' genoemd), terwijl een top-k van 3 betekent dat het volgende token wordt geselecteerd uit de 3 meest waarschijnlijke tokens (met behulp van temperatuur).', com_endpoint_google_maxoutputtokens: ' Maximum aantal tokens dat kan worden gegenereerd in de reactie. Geef een lagere waarde op voor kortere reacties en een hogere waarde voor langere reacties.', - com_endpoint_google_custom_name_placeholder: 'Stel een aangepaste naam in voor PaLM2', + com_endpoint_google_custom_name_placeholder: 'Stel een aangepaste naam in voor Google', com_endpoint_prompt_prefix_placeholder: 'Stel aangepaste instructies of context in. Wordt genegeerd indien leeg.', com_endpoint_custom_name: 'Aangepaste naam', diff --git a/client/src/localization/languages/Pl.tsx b/client/src/localization/languages/Pl.tsx index 6b8df0336..f66e9221d 100644 --- a/client/src/localization/languages/Pl.tsx +++ b/client/src/localization/languages/Pl.tsx @@ -106,7 +106,7 @@ export default { 'Top-k wpływa na sposób, w jaki model wybiera tokeny do wygenerowania odpowiedzi. 
Top-k 1 oznacza, że wybrany token jest najbardziej prawdopodobny spośród wszystkich tokenów w słowniku modelu (nazywane też dekodowaniem zachłannym), podczas gdy top-k 3 oznacza, że następny token jest wybierany spośród 3 najbardziej prawdopodobnych tokenów (z uwzględnieniem temperatury).', com_endpoint_google_maxoutputtokens: 'Maksymalna liczba tokenów, które mogą być wygenerowane w odpowiedzi. Wybierz niższą wartość dla krótszych odpowiedzi i wyższą wartość dla dłuższych odpowiedzi.', - com_endpoint_google_custom_name_placeholder: 'Ustaw niestandardową nazwę dla PaLM2', + com_endpoint_google_custom_name_placeholder: 'Ustaw niestandardową nazwę dla Google', com_endpoint_google_prompt_prefix_placeholder: 'Ustaw niestandardowe instrukcje lub kontekst. Jeśli puste, zostanie zignorowane.', com_endpoint_custom_name: 'Niestandardowa nazwa', diff --git a/client/src/localization/languages/Ru.tsx b/client/src/localization/languages/Ru.tsx index 4030fce19..f852f4405 100644 --- a/client/src/localization/languages/Ru.tsx +++ b/client/src/localization/languages/Ru.tsx @@ -120,7 +120,7 @@ export default { 'Top K изменяет то, как модель выбирает токены для вывода. Top K равное 1 означает, что выбирается наиболее вероятный токен из всего словаря модели (так называемое жадное декодирование), а Top K равное 3 означает, что следующий токен выбирается из трех наиболее вероятных токенов (с использованием температуры).', com_endpoint_google_maxoutputtokens: 'Максимальное количество токенов, которые могут быть сгенерированы в ответе. Укажите меньшее значение для более коротких ответов и большее значение для более длинных ответов.', - com_endpoint_google_custom_name_placeholder: 'Установите пользовательское имя для PaLM2', + com_endpoint_google_custom_name_placeholder: 'Установите пользовательское имя для Google', com_endpoint_google_prompt_prefix_placeholder: 'Установите пользовательские инструкции или контекст. Игнорируется, если пусто.', com_endpoint_custom_name: 'Пользовательское имя', @@ -180,7 +180,7 @@ export default { com_endpoint_view_options: 'Просмотреть параметры', com_endpoint_save_convo_as_preset: 'Сохранить разговор как предустановку', com_endpoint_presets_clear_warning: - 'Вы уверены, что хотите очистить все предустановки? Эти действия необратимы, и восстановление невозможно.', + 'Вы уверены, что хотите очистить все предустановки? Эти действия необратимы, и восстановление невозможно.', com_endpoint_presets: 'предустановки', com_endpoint_my_preset: 'Моя предустановка', com_endpoint_config_key: 'Установить ключ API', diff --git a/client/src/localization/languages/Sv.tsx b/client/src/localization/languages/Sv.tsx index 3cdf587ed..5ecbaf861 100644 --- a/client/src/localization/languages/Sv.tsx +++ b/client/src/localization/languages/Sv.tsx @@ -129,7 +129,7 @@ export default { 'Top-k ändrar hur modellen väljer tokens för utdata. Ett top-k av 1 innebär att den valda token är den mest sannolika bland alla tokens i modellens vokabulär (kallas också girig avkodning), medan ett top-k av 3 innebär att nästa token väljs bland de 3 mest sannolika tokens (med temperatur).', // Top-k changes how the model selects tokens for output. A top-k of 1 means the selected token is the most probable among all tokens in the model's vocabulary (also called greedy decoding), while a top-k of 3 means that the next token is selected from among the 3 most probable tokens (using temperature). com_endpoint_google_maxoutputtokens: 'Maximalt antal tokens som kan genereras i svaret. 
Ange ett lägre värde för kortare svar och ett högre värde för längre svar.', // Maximum number of tokens that can be generated in the response. Specify a lower value for shorter responses and a higher value for longer responses. - com_endpoint_google_custom_name_placeholder: 'Ange ett anpassat namn för PaLM2', // Set a custom name for PaLM2 + com_endpoint_google_custom_name_placeholder: 'Ange ett anpassat namn för Google', // Set a custom name for Google com_endpoint_prompt_prefix_placeholder: 'Ange anpassade instruktioner eller kontext. Ignoreras om tom.', // Set custom instructions or context. Ignored if empty. com_endpoint_custom_name: 'Anpassat namn', // Custom Name diff --git a/client/src/localization/languages/Tr.tsx b/client/src/localization/languages/Tr.tsx index af76630db..688de7556 100644 --- a/client/src/localization/languages/Tr.tsx +++ b/client/src/localization/languages/Tr.tsx @@ -136,7 +136,7 @@ export default { 'Top-k, modelin çıkış için token seçme şeklini değiştirir. 1 top-k, seçilen tokenın modelin kelime dağarcığındaki tüm tokenlar arasında en olası olduğu anlamına gelir (ayrıca aç gözlü kod çözme denir), 3 top-k ise bir sonraki tokenın 3 en olası token arasından seçildiği anlamına gelir (sıcaklık kullanılarak).', com_endpoint_google_maxoutputtokens: 'Yanıtta üretilebilecek maksimum token sayısı. Daha kısa yanıtlar için daha düşük bir değer belirtin ve daha uzun yanıtlar için daha yüksek bir değer belirtin.', - com_endpoint_google_custom_name_placeholder: 'PaLM2 için özel bir ad belirleyin', + com_endpoint_google_custom_name_placeholder: 'Google için özel bir ad belirleyin', com_endpoint_prompt_prefix_placeholder: 'Özel talimatları veya bağlamı ayarlayın. Boşsa göz ardı edilir.', com_endpoint_custom_name: 'Özel Ad', @@ -175,15 +175,18 @@ export default { com_endpoint_disabled_with_tools: 'araçlarla devre dışı bırakıldı', com_endpoint_disabled_with_tools_placeholder: 'Araçlar Seçiliyken Devre Dışı Bırakıldı', com_endpoint_plug_set_custom_instructions_for_gpt_placeholder: - 'Sistem Mesajı\'na dahil edilecek özel talimatları ayarlayın. Varsayılan: hiçbiri', + 'Sistem Mesajı\'na dahil edilecek özel talimatları ayarlayın. Varsayılan: hiçbiri', com_endpoint_import: 'İçe Aktar', com_endpoint_set_custom_name: 'Bu ön ayarı bulabilmeniz için özel bir ad belirleyin', com_endpoint_preset_delete_confirm: 'Bu ön ayarı silmek istediğinizden emin misiniz?', com_endpoint_preset_clear_all_confirm: 'Tüm ön ayarlarınızı silmek istediğinizden emin misiniz?', com_endpoint_preset_import: 'Ön Ayar İçe Aktarıldı!', - com_endpoint_preset_import_error: 'Ön ayarınız içe aktarılırken bir hata oluştu. Lütfen tekrar deneyin.', - com_endpoint_preset_save_error: 'Ön ayarınız kaydedilirken bir hata oluştu. Lütfen tekrar deneyin.', - com_endpoint_preset_delete_error: 'Ön ayarınız silinirken bir hata oluştu. Lütfen tekrar deneyin.', + com_endpoint_preset_import_error: + 'Ön ayarınız içe aktarılırken bir hata oluştu. Lütfen tekrar deneyin.', + com_endpoint_preset_save_error: + 'Ön ayarınız kaydedilirken bir hata oluştu. Lütfen tekrar deneyin.', + com_endpoint_preset_delete_error: + 'Ön ayarınız silinirken bir hata oluştu. 
Lütfen tekrar deneyin.', com_endpoint_preset_default_removed: 'artık varsayılan ön ayar değildir.', com_endpoint_preset_default_item: 'Varsayılan:', com_endpoint_preset_default_none: 'Varsayılan ön ayar etkin değil.', diff --git a/client/src/localization/languages/Vi.tsx b/client/src/localization/languages/Vi.tsx index ea9e2f578..60fb18b38 100644 --- a/client/src/localization/languages/Vi.tsx +++ b/client/src/localization/languages/Vi.tsx @@ -134,7 +134,7 @@ export default { 'Top-k thay đổi cách mô hình chọn mã thông báo để xuất. Top-k là 1 có nghĩa là mã thông báo được chọn là phổ biến nhất trong tất cả các mã thông báo trong bảng từ vựng của mô hình (còn được gọi là giải mã tham lam), trong khi top-k là 3 có nghĩa là mã thông báo tiếp theo được chọn từ giữa 3 mã thông báo phổ biến nhất (sử dụng nhiệt độ).', com_endpoint_google_maxoutputtokens: 'Số mã thông báo tối đa có thể được tạo ra trong phản hồi. Chỉ định một giá trị thấp hơn cho các phản hồi ngắn hơn và một giá trị cao hơn cho các phản hồi dài hơn.', - com_endpoint_google_custom_name_placeholder: 'Đặt tên tùy chỉnh cho PaLM2', + com_endpoint_google_custom_name_placeholder: 'Đặt tên tùy chỉnh cho Google', com_endpoint_prompt_prefix_placeholder: 'Đặt hướng dẫn hoặc ngữ cảnh tùy chỉnh. Bỏ qua nếu trống.', com_endpoint_custom_name: 'Tên tùy chỉnh', diff --git a/client/src/localization/languages/Zh.tsx b/client/src/localization/languages/Zh.tsx index cbd593b50..01ef650af 100644 --- a/client/src/localization/languages/Zh.tsx +++ b/client/src/localization/languages/Zh.tsx @@ -126,7 +126,7 @@ export default { 'Top-k 会改变模型选择输出词的方式。top-k为1意味着所选词是模型词汇中概率最大的(也称为贪心解码),而top-k为3意味着下一个词是从3个概率最大的词中选出的(使用随机性)。', com_endpoint_google_maxoutputtokens: ' 响应生成中可以使用的最大词元数。指定较低的值会得到更短的响应,而指定较高的值则会得到更长的响应。', - com_endpoint_google_custom_name_placeholder: '为PaLM2设置一个名称', + com_endpoint_google_custom_name_placeholder: '为Google设置一个名称', com_endpoint_prompt_prefix_placeholder: '自定义提示词和上下文,默认为空', com_endpoint_custom_name: '自定义名称', com_endpoint_prompt_prefix: '对话前缀', diff --git a/client/src/localization/languages/ZhTraditional.tsx b/client/src/localization/languages/ZhTraditional.tsx index 0e7fb8a76..7705dc878 100644 --- a/client/src/localization/languages/ZhTraditional.tsx +++ b/client/src/localization/languages/ZhTraditional.tsx @@ -127,7 +127,7 @@ export default { 'Top-k 調整模型如何選取輸出的 token。當 Top-k 設為 1 時,模型會選取在其詞彙庫中機率最高的 token 進行輸出(這也被稱為貪婪解碼)。相對地,當 Top-k 設為 3 時,模型會從機率最高的三個 token 中選取下一個輸出 token(這會涉及到所謂的「溫度」調整)', com_endpoint_google_maxoutputtokens: '設定回應中可生成的最大 token 數。若希望回應簡短,請設定較低的數值;若需較長的回應,則設定較高的數值。', - com_endpoint_google_custom_name_placeholder: '為 PaLM2 設定自定義名稱', + com_endpoint_google_custom_name_placeholder: '為 Google 設定自定義名稱', com_endpoint_prompt_prefix_placeholder: '設定自定義提示或前後文。如果為空則忽略。', com_endpoint_custom_name: '自定義名稱', com_endpoint_prompt_prefix: '提示起始字串', diff --git a/docs/deployment/huggingface.md b/docs/deployment/huggingface.md index faa0d5ad1..11bea15e5 100644 --- a/docs/deployment/huggingface.md +++ b/docs/deployment/huggingface.md @@ -36,7 +36,7 @@ You will need to fill these values: | BINGAI_TOKEN | `user_provided` | | CHATGPT_TOKEN | `user_provided` | | ANTHROPIC_API_KEY | `user_provided` | -| PALM_KEY | `user_provided` | +| GOOGLE_KEY | `user_provided` | | CREDS_KEY | * see bellow | | CREDS_IV | * see bellow | | JWT_SECRET | * see bellow | diff --git a/docs/deployment/render.md b/docs/deployment/render.md index 3a3e49ce9..768e91656 100644 --- a/docs/deployment/render.md +++ b/docs/deployment/render.md @@ -53,7 +53,7 @@ Also: | 
JWT_REFRESH_SECRET | secret | | JWT_SECRET | secret | | OPENAI_API_KEY | user_provided | -| PALM_KEY | user_provided | +| GOOGLE_KEY | user_provided | | PORT | 3080 | | SESSION_EXPIRY | (1000 * 60 * 60 * 24) * 7 | diff --git a/docs/features/plugins/introduction.md b/docs/features/plugins/introduction.md index 9133dfcdf..3f1bd376c 100644 --- a/docs/features/plugins/introduction.md +++ b/docs/features/plugins/introduction.md @@ -6,7 +6,7 @@ The plugins endpoint opens the door to prompting LLMs in new ways other than tra The first step is using chain-of-thought prompting & ["agency"](https://zapier.com/blog/ai-agent/) for using plugins/tools in a fashion mimicking the official ChatGPT Plugins feature. -More than this, you can use this endpoint for changing your conversation settings mid-conversation. Unlike the official ChatGPT site and all other endpoints, you can switch models, presets, and settings mid-convo, even when you have no plugins selected. This is useful if you first want a creative response from GPT-4, and then a deterministic, lower cost response from GPT-3. Soon, you will be able to use PaLM2 and HuggingFace models, all in this endpoint in the same modular manner. +More than this, you can use this endpoint for changing your conversation settings mid-conversation. Unlike the official ChatGPT site and all other endpoints, you can switch models, presets, and settings mid-convo, even when you have no plugins selected. This is useful if you first want a creative response from GPT-4, and then a deterministic, lower-cost response from GPT-3. Soon, you will be able to use Google, HuggingFace, and local models, all in this or a similar endpoint in the same modular manner. ### Roadmap: - More plugins and advanced plugin usage (ongoing) diff --git a/docs/index.md b/docs/index.md index e2602ed94..7dbbd55c8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -36,7 +36,7 @@ - 🌎 Multilingual UI: - English, 中文, Deutsch, Español, Français, Italiano, Polski, Português Brasileiro, Русский - 日本語, Svenska, 한국어, Tiếng Việt, 繁體中文, العربية, Türkçe, Nederlands - - 🤖 AI model selection: OpenAI API, Azure, BingAI, ChatGPT Browser, PaLM2, Anthropic (Claude), Plugins + - 🤖 AI model selection: OpenAI API, Azure, BingAI, ChatGPT, Google Vertex AI, Anthropic (Claude), Plugins - 💾 Create, Save, & Share Custom Presets - 🔄 Edit, Resubmit, and Continue messages with conversation branching - 📤 Export conversations as screenshots, markdown, text, json. diff --git a/docs/install/apis_and_tokens.md b/docs/install/apis_and_tokens.md index c3c7262b1..d03b377a7 100644 --- a/docs/install/apis_and_tokens.md +++ b/docs/install/apis_and_tokens.md @@ -46,27 +46,35 @@ To get your Bing Access Token, you have a few options: - Go to [https://console.anthropic.com/account/keys](https://console.anthropic.com/account/keys) and get your api key - add it to `ANTHROPIC_API_KEY=` in the `.env` file -## Google's PaLM 2 +## Google LLMs -To setup PaLM 2 (via Google Cloud Vertex AI API), you need to: +To set up Google LLMs (via Google Cloud Vertex AI), first sign up for Google Cloud: https://cloud.google.com/ -### Enable the Vertex AI API on Google Cloud: - - Go to [https://console.cloud.google.com/vertex-ai](https://console.cloud.google.com/vertex-ai) +You can usually get a $300 starting credit, which makes this option free for 90 days. 
+ +### Once signed up, enable the Vertex AI API on Google Cloud: + - Go to [the Vertex AI page in the Google Cloud console](https://console.cloud.google.com/vertex-ai) + - Click on "Enable API" if prompted -### Create a Service Account: - - Go to [https://console.cloud.google.com/projectselector/iam-admin/serviceaccounts/create?walkthrough_id=iam--create-service-account#step_index=1](https://console.cloud.google.com/projectselector/iam-admin/serviceaccounts/create?walkthrough_id=iam--create-service-account#step_index=1) - - Select or create a project - - Enter a service account name and description - - Click on "Create and Continue" to give at least the "Vertex AI User" role - - Click on "Done" -### Create a JSON key, rename as 'auth.json' and save it in /api/data/: - - Go back to [https://console.cloud.google.com/projectselector/iam-admin/serviceaccounts](https://console.cloud.google.com/projectselector/iam-admin/serviceaccounts) - - Select your service account - - Click on "Keys" - - Click on "Add Key" and then "Create new key" - - Choose JSON as the key type and click on "Create" - - Download the key file and rename it as 'auth.json' - - Save it in `/api/data/` +### Create a Service Account with Vertex AI role: + - **[Click here to create a Service Account](https://console.cloud.google.com/projectselector/iam-admin/serviceaccounts/create?walkthrough_id=iam--create-service-account#step_index=1)** + - **Select or create a project** + - ### Enter a service account ID (required); the name and description are optional + - ![image](https://github.com/danny-avila/LibreChat/assets/110412045/0c5cd177-029b-44fa-a398-a794aeb09de6) + - ### Click on "Create and Continue" to give at least the "Vertex AI User" role + - ![image](https://github.com/danny-avila/LibreChat/assets/110412045/22d3a080-e71e-446e-8485-bcc5bf558dbb) + - **Click on "Continue/Done"** +### Create a JSON key and save it in the project directory: + - **Go back to [the Service Accounts page](https://console.cloud.google.com/projectselector/iam-admin/serviceaccounts)** + - **Select your service account** + - ### Click on "Keys" + - ![image](https://github.com/danny-avila/LibreChat/assets/110412045/735a7bbe-25a6-4b4c-9bb5-e0d8aa91be3d) + - ### Click on "Add Key" and then "Create new key" + - ![image](https://github.com/danny-avila/LibreChat/assets/110412045/cfbb20d3-94a8-4cd1-ac39-f9cd8c2fceaa) + - **Choose JSON as the key type and click on "Create"** + - **Download the key file and rename it to 'auth.json'** + - **Save it within the project directory, in `/api/data/`** + - ![image](https://github.com/danny-avila/LibreChat/assets/110412045/f5b8bcb5-1b20-4751-81a1-d3757a4b3f2f) + ## Azure OpenAI diff --git a/docs/install/dotenv.md b/docs/install/dotenv.md index 48e64dd8d..7319e24ff 100644 --- a/docs/install/dotenv.md +++ b/docs/install/dotenv.md @@ -249,7 +249,7 @@ OPENROUTER_API_KEY= -Follow these instruction to setup: [Google PaLM 2](./apis_and_tokens.md#googles-palm-2) +Follow these instructions to set up: [Google LLMs](./apis_and_tokens.md#google-llms) ```bash -PALM_KEY=user_provided +GOOGLE_KEY=user_provided GOOGLE_REVERSE_PROXY= ``` diff --git a/packages/data-provider/src/schemas.ts b/packages/data-provider/src/schemas.ts index f8904b8ea..05c9ef7ab 100644 --- a/packages/data-provider/src/schemas.ts +++ b/packages/data-provider/src/schemas.ts @@ -29,10 +29,44 @@ export const alternateName = { [EModelEndpoint.bingAI]: 'Bing', [EModelEndpoint.chatGPTBrowser]: 'ChatGPT', [EModelEndpoint.gptPlugins]: 'Plugins', - [EModelEndpoint.google]: 'PaLM', + [EModelEndpoint.google]: 'Google', [EModelEndpoint.anthropic]: 'Anthropic', }; +export const
endpointSettings = { + [EModelEndpoint.google]: { + model: { + default: 'chat-bison', + }, + maxOutputTokens: { + min: 1, + max: 2048, + step: 1, + default: 1024, + }, + temperature: { + min: 0, + max: 1, + step: 0.01, + default: 0.2, + }, + topP: { + min: 0, + max: 1, + step: 0.01, + default: 0.8, + }, + topK: { + min: 1, + max: 40, + step: 1, + default: 40, + }, + }, +}; + +const google = endpointSettings[EModelEndpoint.google]; + export const EndpointURLs: { [key in EModelEndpoint]: string } = { [EModelEndpoint.azureOpenAI]: '/api/ask/azureOpenAI', [EModelEndpoint.openAI]: '/api/ask/openAI', @@ -275,22 +309,24 @@ export const googleSchema = tConversationSchema }) .transform((obj) => ({ ...obj, - model: obj.model ?? 'chat-bison', + model: obj.model ?? google.model.default, modelLabel: obj.modelLabel ?? null, promptPrefix: obj.promptPrefix ?? null, - temperature: obj.temperature ?? 0.2, - maxOutputTokens: obj.maxOutputTokens ?? 1024, - topP: obj.topP ?? 0.95, - topK: obj.topK ?? 40, + examples: obj.examples ?? [{ input: { content: '' }, output: { content: '' } }], + temperature: obj.temperature ?? google.temperature.default, + maxOutputTokens: obj.maxOutputTokens ?? google.maxOutputTokens.default, + topP: obj.topP ?? google.topP.default, + topK: obj.topK ?? google.topK.default, })) .catch(() => ({ - model: 'chat-bison', + model: google.model.default, modelLabel: null, promptPrefix: null, - temperature: 0.2, - maxOutputTokens: 1024, - topP: 0.95, - topK: 40, + examples: [{ input: { content: '' }, output: { content: '' } }], + temperature: google.temperature.default, + maxOutputTokens: google.maxOutputTokens.default, + topP: google.topP.default, + topK: google.topK.default, })); export const bingAISchema = tConversationSchema @@ -539,7 +575,13 @@ export const getResponseSender = (endpointOption: TEndpointOption): string => { } if (endpoint === EModelEndpoint.google) { - return modelLabel ?? 'PaLM2'; + if (modelLabel) { + return modelLabel; + } else if (model && model.includes('code')) { + return 'Codey'; + } + + return 'PaLM2'; } return ''; @@ -590,19 +632,19 @@ export const compactGoogleSchema = tConversationSchema }) .transform((obj) => { const newObj: Partial<TConversation> = { ...obj }; - if (newObj.model === 'chat-bison') { + if (newObj.model === google.model.default) { delete newObj.model; } - if (newObj.temperature === 0.2) { + if (newObj.temperature === google.temperature.default) { delete newObj.temperature; } - if (newObj.maxOutputTokens === 1024) { + if (newObj.maxOutputTokens === google.maxOutputTokens.default) { delete newObj.maxOutputTokens; } - if (newObj.topP === 0.95) { + if (newObj.topP === google.topP.default) { delete newObj.topP; } - if (newObj.topK === 40) { + if (newObj.topK === google.topK.default) { delete newObj.topK; }
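For the `docs/install/apis_and_tokens.md` flow above, the `auth.json` service-account key is what the backend exchanges for a Vertex AI access token. Below is a minimal sketch of that exchange using the `googleapis` JWT client; the file path and the standalone-script shape are illustrative assumptions, not the project's actual wiring.

```ts
import { readFileSync } from 'fs';
import { google } from 'googleapis';

// Load the service-account key created in the steps above.
// The path mirrors the docs' suggested location (assumption for this sketch).
const creds = JSON.parse(readFileSync('./api/data/auth.json', 'utf8'));

// Vertex AI requests are authorized with the broad cloud-platform scope.
const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
const jwtClient = new google.auth.JWT(creds.client_email, undefined, creds.private_key, scopes);

// authorize() signs a JWT with the private key and exchanges it
// for a short-lived OAuth2 access token.
jwtClient.authorize((err, tokens) => {
  if (err) {
    console.error('Service account authorization failed:', err.message);
    return;
  }
  console.log('Got access token:', Boolean(tokens?.access_token));
});
```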
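The `endpointSettings` map added to `packages/data-provider/src/schemas.ts` centralizes the Google parameter ranges that were previously hard-coded throughout the schemas. A sketch of how a consumer such as a settings slider might use those ranges; the `clamp` helper and the import path are assumptions for illustration:

```ts
import { EModelEndpoint, endpointSettings } from './schemas'; // import path assumed

const google = endpointSettings[EModelEndpoint.google];

// Clamp a user-supplied value into the range declared for the endpoint,
// e.g. before rendering a slider or building a request.
const clamp = (value: number, range: { min: number; max: number }): number =>
  Math.min(Math.max(value, range.min), range.max);

console.log(clamp(5000, google.maxOutputTokens)); // 2048
console.log(clamp(-0.5, google.temperature)); // 0
console.log(clamp(99, google.topK)); // 40
```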
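Finally, the reworked `googleSchema`/`compactGoogleSchema` pair round-trips conversation settings through those shared defaults: parsing fills in missing fields, and compacting strips any field still equal to its default so stored presets stay small. A hedged sketch of the expected behavior, with exports assumed and the extra conversation fields elided:

```ts
import { googleSchema, compactGoogleSchema, getResponseSender, EModelEndpoint } from './schemas'; // exports assumed

// Parsing an empty object falls back to the shared Google defaults
// (the .catch() branch guarantees parse never throws).
const full = googleSchema.parse({});
// e.g. { model: 'chat-bison', temperature: 0.2, maxOutputTokens: 1024, topP: 0.8, topK: 40, ... }

// Compacting drops every field that still matches its default,
// so only user-modified settings are persisted.
const compact = compactGoogleSchema.parse({ ...full, temperature: 0.7 });
// e.g. { temperature: 0.7, ... } with model/maxOutputTokens/topP/topK removed

// The response sender now special-cases Codey models per the branch above.
getResponseSender({ endpoint: EModelEndpoint.google, model: 'codechat-bison' }); // 'Codey'
getResponseSender({ endpoint: EModelEndpoint.google, model: 'chat-bison' }); // 'PaLM2'
getResponseSender({ endpoint: EModelEndpoint.google, modelLabel: 'My Assistant' }); // 'My Assistant'
```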