mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-18 17:30:16 +01:00
front and backend logic for model switching
This commit is contained in:
parent
a5afd5c48f
commit
c00a2c902b
9 changed files with 58 additions and 39 deletions
|
|
@ -1,28 +1,24 @@
|
||||||
require('dotenv').config();
|
require('dotenv').config();
|
||||||
// const store = new Keyv(process.env.MONGODB_URI);
|
|
||||||
const Keyv = require('keyv');
|
const Keyv = require('keyv');
|
||||||
const { KeyvFile } = require('keyv-file');
|
const { KeyvFile } = require('keyv-file');
|
||||||
|
|
||||||
const clientOptions = {
|
const proxyOptions = {
|
||||||
// (Optional) Support for a reverse proxy for the completions endpoint (private API server).
|
|
||||||
// Warning: This will expose your `openaiApiKey` to a third-party. Consider the risks before using this.
|
|
||||||
reverseProxyUrl: 'https://chatgpt.pawan.krd/api/completions',
|
reverseProxyUrl: 'https://chatgpt.pawan.krd/api/completions',
|
||||||
// (Optional) Parameters as described in https://platform.openai.com/docs/api-reference/completions
|
|
||||||
modelOptions: {
|
modelOptions: {
|
||||||
// You can override the model name and any other parameters here.
|
|
||||||
model: 'text-davinci-002-render'
|
model: 'text-davinci-002-render'
|
||||||
},
|
},
|
||||||
// (Optional) Set custom instructions instead of "You are ChatGPT...".
|
|
||||||
// promptPrefix: 'You are Bob, a cowboy in Western times...',
|
|
||||||
// (Optional) Set a custom name for the user
|
|
||||||
// userLabel: 'User',
|
|
||||||
// (Optional) Set a custom name for ChatGPT
|
|
||||||
// chatGptLabel: 'ChatGPT',
|
|
||||||
// (Optional) Set to true to enable `console.debug()` logging
|
|
||||||
debug: false
|
debug: false
|
||||||
};
|
};
|
||||||
|
|
||||||
const askClient = async (question, progressCallback, convo) => {
|
const davinciOptions = {
|
||||||
|
modelOptions: {
|
||||||
|
model: 'text-davinci-003'
|
||||||
|
},
|
||||||
|
debug: false
|
||||||
|
};
|
||||||
|
|
||||||
|
const askClient = async ({ model, text, progressCallback, convo }) => {
|
||||||
|
const clientOptions = model === 'chatgpt' ? proxyOptions : davinciOptions;
|
||||||
const ChatGPTClient = (await import('@waylaidwanderer/chatgpt-api')).default;
|
const ChatGPTClient = (await import('@waylaidwanderer/chatgpt-api')).default;
|
||||||
const client = new ChatGPTClient(process.env.CHATGPT_TOKEN, clientOptions, {
|
const client = new ChatGPTClient(process.env.CHATGPT_TOKEN, clientOptions, {
|
||||||
store: new KeyvFile({ filename: 'cache.json' })
|
store: new KeyvFile({ filename: 'cache.json' })
|
||||||
|
|
@ -36,7 +32,7 @@ const askClient = async (question, progressCallback, convo) => {
|
||||||
options = { ...options, ...convo };
|
options = { ...options, ...convo };
|
||||||
}
|
}
|
||||||
|
|
||||||
const res = await client.sendMessage(question, options);
|
const res = await client.sendMessage(text, options);
|
||||||
return res;
|
return res;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
|
|
@ -16,7 +16,7 @@ const sendMessage = (res, message) => {
|
||||||
};
|
};
|
||||||
|
|
||||||
router.post('/', async (req, res) => {
|
router.post('/', async (req, res) => {
|
||||||
const { text, parentMessageId, conversationId } = req.body;
|
const { model, text, parentMessageId, conversationId } = req.body;
|
||||||
if (!text.trim().includes(' ') && text.length < 5) {
|
if (!text.trim().includes(' ') && text.length < 5) {
|
||||||
return handleError(res, 'Prompt empty or too short');
|
return handleError(res, 'Prompt empty or too short');
|
||||||
}
|
}
|
||||||
|
|
@ -24,7 +24,7 @@ router.post('/', async (req, res) => {
|
||||||
const userMessageId = crypto.randomUUID();
|
const userMessageId = crypto.randomUUID();
|
||||||
let userMessage = { id: userMessageId, sender: 'User', text };
|
let userMessage = { id: userMessageId, sender: 'User', text };
|
||||||
|
|
||||||
console.log('ask log', { ...userMessage, parentMessageId, conversationId });
|
console.log('ask log', { model, ...userMessage, parentMessageId, conversationId });
|
||||||
|
|
||||||
res.writeHead(200, {
|
res.writeHead(200, {
|
||||||
Connection: 'keep-alive',
|
Connection: 'keep-alive',
|
||||||
|
|
@ -54,9 +54,14 @@ router.post('/', async (req, res) => {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let gptResponse = await askClient(text, progressCallback, {
|
let gptResponse = await askClient({
|
||||||
|
model,
|
||||||
|
text,
|
||||||
|
progressCallback,
|
||||||
|
convo: {
|
||||||
parentMessageId,
|
parentMessageId,
|
||||||
conversationId
|
conversationId
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
console.log('CLIENT RESPONSE', gptResponse);
|
console.log('CLIENT RESPONSE', gptResponse);
|
||||||
|
|
@ -70,7 +75,9 @@ router.post('/', async (req, res) => {
|
||||||
gptResponse.id = gptResponse.messageId;
|
gptResponse.id = gptResponse.messageId;
|
||||||
gptResponse.parentMessageId = gptResponse.messageId;
|
gptResponse.parentMessageId = gptResponse.messageId;
|
||||||
userMessage.parentMessageId = parentMessageId ? parentMessageId : gptResponse.messageId;
|
userMessage.parentMessageId = parentMessageId ? parentMessageId : gptResponse.messageId;
|
||||||
userMessage.conversationId = conversationId ? conversationId : gptResponse.conversationId;
|
userMessage.conversationId = conversationId
|
||||||
|
? conversationId
|
||||||
|
: gptResponse.conversationId;
|
||||||
await saveMessage(userMessage);
|
await saveMessage(userMessage);
|
||||||
delete gptResponse.response;
|
delete gptResponse.response;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import React, { useEffect, useState, useRef } from 'react';
|
import React, { useEffect, useState, useRef } from 'react';
|
||||||
import Message from './Message';
|
|
||||||
import ScrollToBottom from './ScrollToBottom';
|
|
||||||
import { CSSTransition } from 'react-transition-group';
|
import { CSSTransition } from 'react-transition-group';
|
||||||
|
import ScrollToBottom from './ScrollToBottom';
|
||||||
|
import Message from './Message';
|
||||||
|
|
||||||
const Messages = ({ messages }) => {
|
const Messages = ({ messages }) => {
|
||||||
const [showScrollButton, setShowScrollButton] = useState(false);
|
const [showScrollButton, setShowScrollButton] = useState(false);
|
||||||
|
|
@ -13,7 +13,7 @@ const Messages = ({ messages }) => {
|
||||||
const scrollable = scrollableRef.current;
|
const scrollable = scrollableRef.current;
|
||||||
const hasScrollbar = scrollable.scrollHeight > scrollable.clientHeight;
|
const hasScrollbar = scrollable.scrollHeight > scrollable.clientHeight;
|
||||||
setShowScrollButton(hasScrollbar);
|
setShowScrollButton(hasScrollbar);
|
||||||
}, 850);
|
}, 650);
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
clearTimeout(timeoutId);
|
clearTimeout(timeoutId);
|
||||||
|
|
@ -69,7 +69,7 @@ const Messages = ({ messages }) => {
|
||||||
))}
|
))}
|
||||||
<CSSTransition
|
<CSSTransition
|
||||||
in={showScrollButton}
|
in={showScrollButton}
|
||||||
timeout={650}
|
timeout={400}
|
||||||
classNames="scroll-down"
|
classNames="scroll-down"
|
||||||
unmountOnExit={false}
|
unmountOnExit={false}
|
||||||
appear
|
appear
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import React, { useState } from 'react';
|
import React, { useState } from 'react';
|
||||||
import { useSelector, useDispatch } from 'react-redux';
|
import { useSelector, useDispatch } from 'react-redux';
|
||||||
import { setConversation, setError } from '~/store/convoSlice';
|
import { setModel } from '~/store/submitSlice';
|
||||||
import GPTIcon from '../svg/GPTIcon';
|
import GPTIcon from '../svg/GPTIcon';
|
||||||
import { DropdownMenuCheckboxItemProps } from '@radix-ui/react-dropdown-menu';
|
import { DropdownMenuCheckboxItemProps } from '@radix-ui/react-dropdown-menu';
|
||||||
|
|
||||||
|
|
@ -17,7 +17,11 @@ import {
|
||||||
} from '../ui/DropdownMenu.tsx';
|
} from '../ui/DropdownMenu.tsx';
|
||||||
|
|
||||||
export default function ModelMenu() {
|
export default function ModelMenu() {
|
||||||
const [model, setModel] = useState('chatgpt');
|
const dispatch = useDispatch();
|
||||||
|
const { model } = useSelector((state) => state.submit);
|
||||||
|
const onChange = (value) => {
|
||||||
|
dispatch(setModel(value));
|
||||||
|
};
|
||||||
|
|
||||||
const defaultColorProps = [
|
const defaultColorProps = [
|
||||||
'text-gray-500',
|
'text-gray-500',
|
||||||
|
|
@ -57,7 +61,7 @@ export default function ModelMenu() {
|
||||||
<DropdownMenuSeparator />
|
<DropdownMenuSeparator />
|
||||||
<DropdownMenuRadioGroup
|
<DropdownMenuRadioGroup
|
||||||
value={model}
|
value={model}
|
||||||
onValueChange={setModel}
|
onValueChange={onChange}
|
||||||
>
|
>
|
||||||
<DropdownMenuRadioItem value="chatgpt">ChatGPT</DropdownMenuRadioItem>
|
<DropdownMenuRadioItem value="chatgpt">ChatGPT</DropdownMenuRadioItem>
|
||||||
<DropdownMenuRadioItem value="davinci">Davinci</DropdownMenuRadioItem>
|
<DropdownMenuRadioItem value="davinci">Davinci</DropdownMenuRadioItem>
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ export default function TextChat({ messages }) {
|
||||||
const [errorMessage, setErrorMessage] = useState('');
|
const [errorMessage, setErrorMessage] = useState('');
|
||||||
const dispatch = useDispatch();
|
const dispatch = useDispatch();
|
||||||
const convo = useSelector((state) => state.convo);
|
const convo = useSelector((state) => state.convo);
|
||||||
const { isSubmitting } = useSelector((state) => state.submit);
|
const { isSubmitting, model } = useSelector((state) => state.submit);
|
||||||
const { text } = useSelector((state) => state.text);
|
const { text } = useSelector((state) => state.text);
|
||||||
const { error } = convo;
|
const { error } = convo;
|
||||||
|
|
||||||
|
|
@ -28,8 +28,8 @@ export default function TextChat({ messages }) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(setSubmitState(true));
|
dispatch(setSubmitState(true));
|
||||||
const payload = text.trim();
|
const message = text.trim();
|
||||||
const currentMsg = { sender: 'User', text: payload, current: true };
|
const currentMsg = { sender: 'User', text: message, current: true };
|
||||||
const initialResponse = { sender: 'GPT', text: '' };
|
const initialResponse = { sender: 'GPT', text: '' };
|
||||||
dispatch(setMessages([...messages, currentMsg, initialResponse]));
|
dispatch(setMessages([...messages, currentMsg, initialResponse]));
|
||||||
dispatch(setText(''));
|
dispatch(setText(''));
|
||||||
|
|
@ -59,12 +59,20 @@ export default function TextChat({ messages }) {
|
||||||
setErrorMessage(event.data);
|
setErrorMessage(event.data);
|
||||||
dispatch(setSubmitState(false));
|
dispatch(setSubmitState(false));
|
||||||
dispatch(setMessages([...messages.slice(0, -2), currentMsg, errorResponse]));
|
dispatch(setMessages([...messages.slice(0, -2), currentMsg, errorResponse]));
|
||||||
dispatch(setText(payload));
|
dispatch(setText(message));
|
||||||
dispatch(setError(true));
|
dispatch(setError(true));
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
console.log('User Input:', payload);
|
const submission = {
|
||||||
handleSubmit({ text: payload, messageHandler, convo, convoHandler, errorHandler });
|
model,
|
||||||
|
text: message,
|
||||||
|
convo,
|
||||||
|
messageHandler,
|
||||||
|
convoHandler,
|
||||||
|
errorHandler
|
||||||
|
};
|
||||||
|
console.log('User Input:', message);
|
||||||
|
handleSubmit(submission);
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleKeyDown = (e) => {
|
const handleKeyDown = (e) => {
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@ const currentSlice = createSlice({
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
//
|
|
||||||
|
|
||||||
export const { setConversation, setError } = currentSlice.actions;
|
export const { setConversation, setError } = currentSlice.actions;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import { createSlice } from '@reduxjs/toolkit';
|
||||||
|
|
||||||
const initialState = {
|
const initialState = {
|
||||||
isSubmitting: false,
|
isSubmitting: false,
|
||||||
|
model: 'chatgpt'
|
||||||
};
|
};
|
||||||
|
|
||||||
const currentSlice = createSlice({
|
const currentSlice = createSlice({
|
||||||
|
|
@ -11,9 +12,12 @@ const currentSlice = createSlice({
|
||||||
setSubmitState: (state, action) => {
|
setSubmitState: (state, action) => {
|
||||||
state.isSubmitting = action.payload;
|
state.isSubmitting = action.payload;
|
||||||
},
|
},
|
||||||
|
setModel: (state, action) => {
|
||||||
|
state.model = action.payload;
|
||||||
|
},
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
export const { setSubmitState } = currentSlice.actions;
|
export const { setSubmitState, setModel } = currentSlice.actions;
|
||||||
|
|
||||||
export default currentSlice.reducer;
|
export default currentSlice.reducer;
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,14 @@
|
||||||
import { SSE } from '../../app/sse';
|
import { SSE } from '../../app/sse';
|
||||||
|
|
||||||
export default function handleSubmit({
|
export default function handleSubmit({
|
||||||
|
model,
|
||||||
text,
|
text,
|
||||||
convo,
|
convo,
|
||||||
messageHandler,
|
messageHandler,
|
||||||
convoHandler,
|
convoHandler,
|
||||||
errorHandler
|
errorHandler
|
||||||
}) {
|
}) {
|
||||||
let payload = { text };
|
let payload = { model, text };
|
||||||
if (convo.conversationId && convo.parentMessageId) {
|
if (convo.conversationId && convo.parentMessageId) {
|
||||||
payload = {
|
payload = {
|
||||||
...payload,
|
...payload,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue