front and backend logic for model switching

This commit is contained in:
Daniel Avila 2023-02-13 21:15:28 -05:00
parent a5afd5c48f
commit c00a2c902b
9 changed files with 58 additions and 39 deletions

View file

@ -1,28 +1,24 @@
require('dotenv').config(); require('dotenv').config();
// const store = new Keyv(process.env.MONGODB_URI);
const Keyv = require('keyv'); const Keyv = require('keyv');
const { KeyvFile } = require('keyv-file'); const { KeyvFile } = require('keyv-file');
const clientOptions = { const proxyOptions = {
// (Optional) Support for a reverse proxy for the completions endpoint (private API server).
// Warning: This will expose your `openaiApiKey` to a third-party. Consider the risks before using this.
reverseProxyUrl: 'https://chatgpt.pawan.krd/api/completions', reverseProxyUrl: 'https://chatgpt.pawan.krd/api/completions',
// (Optional) Parameters as described in https://platform.openai.com/docs/api-reference/completions
modelOptions: { modelOptions: {
// You can override the model name and any other parameters here.
model: 'text-davinci-002-render' model: 'text-davinci-002-render'
}, },
// (Optional) Set custom instructions instead of "You are ChatGPT...".
// promptPrefix: 'You are Bob, a cowboy in Western times...',
// (Optional) Set a custom name for the user
// userLabel: 'User',
// (Optional) Set a custom name for ChatGPT
// chatGptLabel: 'ChatGPT',
// (Optional) Set to true to enable `console.debug()` logging
debug: false debug: false
}; };
const askClient = async (question, progressCallback, convo) => { const davinciOptions = {
modelOptions: {
model: 'text-davinci-003'
},
debug: false
};
const askClient = async ({ model, text, progressCallback, convo }) => {
const clientOptions = model === 'chatgpt' ? proxyOptions : davinciOptions;
const ChatGPTClient = (await import('@waylaidwanderer/chatgpt-api')).default; const ChatGPTClient = (await import('@waylaidwanderer/chatgpt-api')).default;
const client = new ChatGPTClient(process.env.CHATGPT_TOKEN, clientOptions, { const client = new ChatGPTClient(process.env.CHATGPT_TOKEN, clientOptions, {
store: new KeyvFile({ filename: 'cache.json' }) store: new KeyvFile({ filename: 'cache.json' })
@ -36,7 +32,7 @@ const askClient = async (question, progressCallback, convo) => {
options = { ...options, ...convo }; options = { ...options, ...convo };
} }
const res = await client.sendMessage(question, options); const res = await client.sendMessage(text, options);
return res; return res;
}; };

File diff suppressed because one or more lines are too long

View file

@ -16,7 +16,7 @@ const sendMessage = (res, message) => {
}; };
router.post('/', async (req, res) => { router.post('/', async (req, res) => {
const { text, parentMessageId, conversationId } = req.body; const { model, text, parentMessageId, conversationId } = req.body;
if (!text.trim().includes(' ') && text.length < 5) { if (!text.trim().includes(' ') && text.length < 5) {
return handleError(res, 'Prompt empty or too short'); return handleError(res, 'Prompt empty or too short');
} }
@ -24,7 +24,7 @@ router.post('/', async (req, res) => {
const userMessageId = crypto.randomUUID(); const userMessageId = crypto.randomUUID();
let userMessage = { id: userMessageId, sender: 'User', text }; let userMessage = { id: userMessageId, sender: 'User', text };
console.log('ask log', { ...userMessage, parentMessageId, conversationId }); console.log('ask log', { model, ...userMessage, parentMessageId, conversationId });
res.writeHead(200, { res.writeHead(200, {
Connection: 'keep-alive', Connection: 'keep-alive',
@ -54,9 +54,14 @@ router.post('/', async (req, res) => {
} }
}; };
let gptResponse = await askClient(text, progressCallback, { let gptResponse = await askClient({
parentMessageId, model,
conversationId text,
progressCallback,
convo: {
parentMessageId,
conversationId
}
}); });
console.log('CLIENT RESPONSE', gptResponse); console.log('CLIENT RESPONSE', gptResponse);
@ -70,7 +75,9 @@ router.post('/', async (req, res) => {
gptResponse.id = gptResponse.messageId; gptResponse.id = gptResponse.messageId;
gptResponse.parentMessageId = gptResponse.messageId; gptResponse.parentMessageId = gptResponse.messageId;
userMessage.parentMessageId = parentMessageId ? parentMessageId : gptResponse.messageId; userMessage.parentMessageId = parentMessageId ? parentMessageId : gptResponse.messageId;
userMessage.conversationId = conversationId ? conversationId : gptResponse.conversationId; userMessage.conversationId = conversationId
? conversationId
: gptResponse.conversationId;
await saveMessage(userMessage); await saveMessage(userMessage);
delete gptResponse.response; delete gptResponse.response;
} }
@ -96,4 +103,4 @@ router.post('/', async (req, res) => {
} }
}); });
module.exports = router; module.exports = router;

View file

@ -1,7 +1,7 @@
import React, { useEffect, useState, useRef } from 'react'; import React, { useEffect, useState, useRef } from 'react';
import Message from './Message';
import ScrollToBottom from './ScrollToBottom';
import { CSSTransition } from 'react-transition-group'; import { CSSTransition } from 'react-transition-group';
import ScrollToBottom from './ScrollToBottom';
import Message from './Message';
const Messages = ({ messages }) => { const Messages = ({ messages }) => {
const [showScrollButton, setShowScrollButton] = useState(false); const [showScrollButton, setShowScrollButton] = useState(false);
@ -13,7 +13,7 @@ const Messages = ({ messages }) => {
const scrollable = scrollableRef.current; const scrollable = scrollableRef.current;
const hasScrollbar = scrollable.scrollHeight > scrollable.clientHeight; const hasScrollbar = scrollable.scrollHeight > scrollable.clientHeight;
setShowScrollButton(hasScrollbar); setShowScrollButton(hasScrollbar);
}, 850); }, 650);
return () => { return () => {
clearTimeout(timeoutId); clearTimeout(timeoutId);
@ -69,7 +69,7 @@ const Messages = ({ messages }) => {
))} ))}
<CSSTransition <CSSTransition
in={showScrollButton} in={showScrollButton}
timeout={650} timeout={400}
classNames="scroll-down" classNames="scroll-down"
unmountOnExit={false} unmountOnExit={false}
appear appear

View file

@ -1,6 +1,6 @@
import React, { useState } from 'react'; import React, { useState } from 'react';
import { useSelector, useDispatch } from 'react-redux'; import { useSelector, useDispatch } from 'react-redux';
import { setConversation, setError } from '~/store/convoSlice'; import { setModel } from '~/store/submitSlice';
import GPTIcon from '../svg/GPTIcon'; import GPTIcon from '../svg/GPTIcon';
import { DropdownMenuCheckboxItemProps } from '@radix-ui/react-dropdown-menu'; import { DropdownMenuCheckboxItemProps } from '@radix-ui/react-dropdown-menu';
@ -17,7 +17,11 @@ import {
} from '../ui/DropdownMenu.tsx'; } from '../ui/DropdownMenu.tsx';
export default function ModelMenu() { export default function ModelMenu() {
const [model, setModel] = useState('chatgpt'); const dispatch = useDispatch();
const { model } = useSelector((state) => state.submit);
const onChange = (value) => {
dispatch(setModel(value));
};
const defaultColorProps = [ const defaultColorProps = [
'text-gray-500', 'text-gray-500',
@ -57,7 +61,7 @@ export default function ModelMenu() {
<DropdownMenuSeparator /> <DropdownMenuSeparator />
<DropdownMenuRadioGroup <DropdownMenuRadioGroup
value={model} value={model}
onValueChange={setModel} onValueChange={onChange}
> >
<DropdownMenuRadioItem value="chatgpt">ChatGPT</DropdownMenuRadioItem> <DropdownMenuRadioItem value="chatgpt">ChatGPT</DropdownMenuRadioItem>
<DropdownMenuRadioItem value="davinci">Davinci</DropdownMenuRadioItem> <DropdownMenuRadioItem value="davinci">Davinci</DropdownMenuRadioItem>

View file

@ -15,7 +15,7 @@ export default function TextChat({ messages }) {
const [errorMessage, setErrorMessage] = useState(''); const [errorMessage, setErrorMessage] = useState('');
const dispatch = useDispatch(); const dispatch = useDispatch();
const convo = useSelector((state) => state.convo); const convo = useSelector((state) => state.convo);
const { isSubmitting } = useSelector((state) => state.submit); const { isSubmitting, model } = useSelector((state) => state.submit);
const { text } = useSelector((state) => state.text); const { text } = useSelector((state) => state.text);
const { error } = convo; const { error } = convo;
@ -28,8 +28,8 @@ export default function TextChat({ messages }) {
return; return;
} }
dispatch(setSubmitState(true)); dispatch(setSubmitState(true));
const payload = text.trim(); const message = text.trim();
const currentMsg = { sender: 'User', text: payload, current: true }; const currentMsg = { sender: 'User', text: message, current: true };
const initialResponse = { sender: 'GPT', text: '' }; const initialResponse = { sender: 'GPT', text: '' };
dispatch(setMessages([...messages, currentMsg, initialResponse])); dispatch(setMessages([...messages, currentMsg, initialResponse]));
dispatch(setText('')); dispatch(setText(''));
@ -59,12 +59,20 @@ export default function TextChat({ messages }) {
setErrorMessage(event.data); setErrorMessage(event.data);
dispatch(setSubmitState(false)); dispatch(setSubmitState(false));
dispatch(setMessages([...messages.slice(0, -2), currentMsg, errorResponse])); dispatch(setMessages([...messages.slice(0, -2), currentMsg, errorResponse]));
dispatch(setText(payload)); dispatch(setText(message));
dispatch(setError(true)); dispatch(setError(true));
return; return;
}; };
console.log('User Input:', payload); const submission = {
handleSubmit({ text: payload, messageHandler, convo, convoHandler, errorHandler }); model,
text: message,
convo,
messageHandler,
convoHandler,
errorHandler
};
console.log('User Input:', message);
handleSubmit(submission);
}; };
const handleKeyDown = (e) => { const handleKeyDown = (e) => {

View file

@ -19,7 +19,6 @@ const currentSlice = createSlice({
}, },
} }
}); });
//
export const { setConversation, setError } = currentSlice.actions; export const { setConversation, setError } = currentSlice.actions;

View file

@ -2,6 +2,7 @@ import { createSlice } from '@reduxjs/toolkit';
const initialState = { const initialState = {
isSubmitting: false, isSubmitting: false,
model: 'chatgpt'
}; };
const currentSlice = createSlice({ const currentSlice = createSlice({
@ -11,9 +12,12 @@ const currentSlice = createSlice({
setSubmitState: (state, action) => { setSubmitState: (state, action) => {
state.isSubmitting = action.payload; state.isSubmitting = action.payload;
}, },
setModel: (state, action) => {
state.model = action.payload;
},
} }
}); });
export const { setSubmitState } = currentSlice.actions; export const { setSubmitState, setModel } = currentSlice.actions;
export default currentSlice.reducer; export default currentSlice.reducer;

View file

@ -1,13 +1,14 @@
import { SSE } from '../../app/sse'; import { SSE } from '../../app/sse';
export default function handleSubmit({ export default function handleSubmit({
model,
text, text,
convo, convo,
messageHandler, messageHandler,
convoHandler, convoHandler,
errorHandler errorHandler
}) { }) {
let payload = { text }; let payload = { model, text };
if (convo.conversationId && convo.parentMessageId) { if (convo.conversationId && convo.parentMessageId) {
payload = { payload = {
...payload, ...payload,