feat: Add WebSocket functionality and integrate call features in the chat component

Marco Beretta 2024-12-21 14:36:01 +01:00
parent ea5cb4bc2b
commit cf4b73b5e3
No known key found for this signature in database
GPG key ID: D918033D8E74CC11
21 changed files with 588 additions and 142 deletions

View file

@@ -0,0 +1,52 @@
import { useRecoilState } from 'recoil';
import { Phone, PhoneOff } from 'lucide-react';
import { OGDialog, OGDialogContent, Button } from '~/components';
import { useWebSocket, useCall } from '~/hooks';
import store from '~/store';
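/**
 * Voice call dialog: shows the WebSocket connection status and exposes
 * start/end call controls. Visibility is driven by the `callDialogOpen` atom.
 */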
export const Call: React.FC = () => {
const { isConnected } = useWebSocket();
const { isCalling, startCall, hangUp } = useCall();
const [open, setOpen] = useRecoilState(store.callDialogOpen(0));
return (
<OGDialog open={open} onOpenChange={setOpen}>
<OGDialogContent className="w-96 p-8">
<div className="flex flex-col items-center gap-6">
<div
className={`flex items-center gap-2 rounded-full px-4 py-2 ${
isConnected ? 'bg-green-100 text-green-700' : 'bg-red-100 text-red-700'
}`}
>
<div
className={`h-2 w-2 rounded-full ${isConnected ? 'bg-green-500' : 'bg-red-500'}`}
/>
<span className="text-sm font-medium">
{isConnected ? 'Connected' : 'Disconnected'}
</span>
</div>
{!isCalling ? (
<Button
onClick={startCall}
disabled={!isConnected}
className="flex items-center gap-2 rounded-full bg-green-500 px-6 py-3 text-white hover:bg-green-600 disabled:opacity-50"
>
<Phone size={20} />
<span>Start Call</span>
</Button>
) : (
<Button
onClick={hangUp}
className="flex items-center gap-2 rounded-full bg-red-500 px-6 py-3 text-white hover:bg-red-600"
>
<PhoneOff size={20} />
<span>End Call</span>
</Button>
)}
</div>
</OGDialogContent>
</OGDialog>
);
};

View file

@@ -26,12 +26,12 @@ import PromptsCommand from './PromptsCommand';
import AudioRecorder from './AudioRecorder';
import CollapseChat from './CollapseChat';
import StreamAudio from './StreamAudio';
import CallButton from './CallButton';
import StopButton from './StopButton';
import SendButton from './SendButton';
import { BadgeRow } from './BadgeRow';
import EditBadges from './EditBadges';
import Mention from './Mention';
import { Call } from './Call';
import store from '~/store';
const ChatForm = memo(({ index = 0 }: { index?: number }) => {
@@ -190,140 +190,145 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
);
return (
<form
onSubmit={methods.handleSubmit(submitMessage)}
className={cn(
'mx-auto flex flex-row gap-3 sm:px-2',
maximizeChatSpace ? 'w-full max-w-full' : 'md:max-w-3xl xl:max-w-4xl',
centerFormOnLanding &&
(!conversation?.conversationId || conversation?.conversationId === Constants.NEW_CONVO) &&
!isSubmitting
? 'transition-all duration-200 sm:mb-28'
: 'sm:mb-10',
)}
>
<div className="relative flex h-full flex-1 items-stretch md:flex-col">
<div className={cn('flex w-full items-center', isRTL && 'flex-row-reverse')}>
{showPlusPopover && !isAssistantsEndpoint(endpoint) && (
<Mention
setShowMentionPopover={setShowPlusPopover}
newConversation={generateConversation}
textAreaRef={textAreaRef}
commandChar="+"
placeholder="com_ui_add_model_preset"
includeAssistants={false}
/>
)}
{showMentionPopover && (
<Mention
setShowMentionPopover={setShowMentionPopover}
newConversation={newConversation}
textAreaRef={textAreaRef}
/>
)}
<PromptsCommand index={index} textAreaRef={textAreaRef} submitPrompt={submitPrompt} />
<div
onClick={handleContainerClick}
className={cn(
'relative flex w-full flex-grow flex-col overflow-hidden rounded-t-3xl border pb-4 text-text-primary transition-all duration-200 sm:rounded-3xl sm:pb-0',
isTextAreaFocused ? 'shadow-lg' : 'shadow-md',
isTemporary
? 'border-violet-800/60 bg-violet-950/10'
: 'border-border-light bg-surface-chat',
<>
<form
onSubmit={methods.handleSubmit(submitMessage)}
className={cn(
'mx-auto flex flex-row gap-3 sm:px-2',
maximizeChatSpace ? 'w-full max-w-full' : 'md:max-w-3xl xl:max-w-4xl',
centerFormOnLanding &&
(!conversation?.conversationId ||
conversation?.conversationId === Constants.NEW_CONVO) &&
!isSubmitting
? 'transition-all duration-200 sm:mb-28'
: 'sm:mb-10',
)}
>
<div className="relative flex h-full flex-1 items-stretch md:flex-col">
<div className={cn('flex w-full items-center', isRTL && 'flex-row-reverse')}>
{showPlusPopover && !isAssistantsEndpoint(endpoint) && (
<Mention
setShowMentionPopover={setShowPlusPopover}
newConversation={generateConversation}
textAreaRef={textAreaRef}
commandChar="+"
placeholder="com_ui_add_model_preset"
includeAssistants={false}
/>
)}
>
<TextareaHeader addedConvo={addedConvo} setAddedConvo={setAddedConvo} />
<EditBadges
isEditingChatBadges={isEditingBadges}
handleCancelBadges={handleCancelBadges}
handleSaveBadges={handleSaveBadges}
setBadges={setBadges}
/>
<FileFormChat disableInputs={disableInputs} />
{endpoint && (
<div className={cn('flex', isRTL ? 'flex-row-reverse' : 'flex-row')}>
<TextareaAutosize
{...registerProps}
ref={(e) => {
ref(e);
(textAreaRef as React.MutableRefObject<HTMLTextAreaElement | null>).current = e;
}}
disabled={disableInputs}
onPaste={handlePaste}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
onCompositionStart={handleCompositionStart}
onCompositionEnd={handleCompositionEnd}
id={mainTextareaId}
tabIndex={0}
data-testid="text-input"
rows={1}
onFocus={() => {
handleFocusOrClick();
setIsTextAreaFocused(true);
}}
onBlur={setIsTextAreaFocused.bind(null, false)}
onClick={handleFocusOrClick}
style={{ height: 44, overflowY: 'auto' }}
className={cn(
baseClasses,
removeFocusRings,
'transition-[max-height] duration-200',
)}
/>
<div className="flex flex-col items-start justify-start pt-1.5">
<CollapseChat
isCollapsed={isCollapsed}
isScrollable={isMoreThanThreeRows}
setIsCollapsed={setIsCollapsed}
/>
</div>
</div>
{showMentionPopover && (
<Mention
setShowMentionPopover={setShowMentionPopover}
newConversation={newConversation}
textAreaRef={textAreaRef}
/>
)}
<PromptsCommand index={index} textAreaRef={textAreaRef} submitPrompt={submitPrompt} />
<div
onClick={handleContainerClick}
className={cn(
'items-between flex gap-2 pb-2',
isRTL ? 'flex-row-reverse' : 'flex-row',
'relative flex w-full flex-grow flex-col overflow-hidden rounded-t-3xl border pb-4 text-text-primary transition-all duration-200 sm:rounded-3xl sm:pb-0',
isTextAreaFocused ? 'shadow-lg' : 'shadow-md',
isTemporary
? 'border-violet-800/60 bg-violet-950/10'
: 'border-border-light bg-surface-chat',
)}
>
<div className={`${isRTL ? 'mr-2' : 'ml-2'}`}>
<AttachFileChat disableInputs={disableInputs} />
</div>
<BadgeRow
onChange={(newBadges) => setBadges(newBadges)}
isInChat={
Array.isArray(conversation?.messages) && conversation.messages.length >= 1
}
<TextareaHeader addedConvo={addedConvo} setAddedConvo={setAddedConvo} />
<EditBadges
isEditingChatBadges={isEditingBadges}
handleCancelBadges={handleCancelBadges}
handleSaveBadges={handleSaveBadges}
setBadges={setBadges}
/>
<div className="mx-auto flex" />
{SpeechToText && (
<AudioRecorder
methods={methods}
ask={submitMessage}
textAreaRef={textAreaRef}
disabled={disableInputs}
isSubmitting={isSubmitting}
/>
)}
<div className={`${isRTL ? 'ml-2' : 'mr-2'}`}>
{(isSubmitting || isSubmittingAdded) && (showStopButton || showStopAdded) ? (
<StopButton stop={handleStopGenerating} setShowStopButton={setShowStopButton} />
) : (
endpoint && (
<SendButton
ref={submitButtonRef}
control={methods.control}
disabled={filesLoading || isSubmitting || disableInputs}
<FileFormChat disableInputs={disableInputs} />
{endpoint && (
<div className={cn('flex', isRTL ? 'flex-row-reverse' : 'flex-row')}>
<TextareaAutosize
{...registerProps}
ref={(e) => {
ref(e);
(textAreaRef as React.MutableRefObject<HTMLTextAreaElement | null>).current =
e;
}}
disabled={disableInputs}
onPaste={handlePaste}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
onCompositionStart={handleCompositionStart}
onCompositionEnd={handleCompositionEnd}
id={mainTextareaId}
tabIndex={0}
data-testid="text-input"
rows={1}
onFocus={() => {
handleFocusOrClick();
setIsTextAreaFocused(true);
}}
onBlur={setIsTextAreaFocused.bind(null, false)}
onClick={handleFocusOrClick}
style={{ height: 44, overflowY: 'auto' }}
className={cn(
baseClasses,
removeFocusRings,
'transition-[max-height] duration-200',
)}
/>
<div className="flex flex-col items-start justify-start pt-1.5">
<CollapseChat
isCollapsed={isCollapsed}
isScrollable={isMoreThanThreeRows}
setIsCollapsed={setIsCollapsed}
/>
)
</div>
</div>
)}
<div
className={cn(
'items-between flex gap-2 pb-2',
isRTL ? 'flex-row-reverse' : 'flex-row',
)}
>
<div className={`${isRTL ? 'mr-2' : 'ml-2'}`}>
<AttachFileChat disableInputs={disableInputs} />
</div>
<BadgeRow
onChange={(newBadges) => setBadges(newBadges)}
isInChat={
Array.isArray(conversation?.messages) && conversation.messages.length >= 1
}
/>
<div className="mx-auto flex" />
{SpeechToText && (
<AudioRecorder
methods={methods}
ask={submitMessage}
textAreaRef={textAreaRef}
disabled={disableInputs}
isSubmitting={isSubmitting}
/>
)}
<div className={`${isRTL ? 'ml-2' : 'mr-2'}`}>
{(isSubmitting || isSubmittingAdded) && (showStopButton || showStopAdded) ? (
<StopButton stop={handleStopGenerating} setShowStopButton={setShowStopButton} />
) : (
endpoint && (
<SendButton
ref={submitButtonRef}
control={methods.control}
disabled={filesLoading || isSubmitting || disableInputs}
/>
)
)}
</div>
</div>
{TextToSpeech && automaticPlayback && <StreamAudio index={index} />}
</div>
{TextToSpeech && automaticPlayback && <StreamAudio index={index} />}
</div>
</div>
</div>
</form>
</form>
<Call />
</>
);
});

View file

@@ -1,11 +1,13 @@
import React, { forwardRef } from 'react';
import { useWatch } from 'react-hook-form';
import { useSetRecoilState } from 'recoil';
import type { TRealtimeEphemeralTokenResponse } from 'librechat-data-provider';
import type { Control } from 'react-hook-form';
import { useRealtimeEphemeralTokenMutation } from '~/data-provider';
import { TooltipAnchor, SendIcon, CallIcon } from '~/components';
import { useToastContext } from '~/Providers/ToastContext';
import { useLocalize } from '~/hooks';
import store from '~/store';
import { cn } from '~/utils';
type ButtonProps = {
@@ -56,25 +58,27 @@ const SendButton = forwardRef((props: ButtonProps, ref: React.ForwardedRef<HTMLB
const localize = useLocalize();
const { showToast } = useToastContext();
const { text = '' } = useWatch({ control: props.control });
const setCallOpen = useSetRecoilState(store.callDialogOpen(0));
const { mutate: startCall, isLoading: isProcessing } = useRealtimeEphemeralTokenMutation({
onSuccess: async (data: TRealtimeEphemeralTokenResponse) => {
showToast({
message: 'IT WORKS!!',
status: 'success',
});
},
onError: (error: unknown) => {
showToast({
message: localize('com_nav_audio_process_error', (error as Error).message),
status: 'error',
});
},
});
// const { mutate: startCall, isLoading: isProcessing } = useRealtimeEphemeralTokenMutation({
// onSuccess: async (data: TRealtimeEphemeralTokenResponse) => {
// showToast({
// message: 'IT WORKS!!',
// status: 'success',
// });
// },
// onError: (error: unknown) => {
// showToast({
// message: localize('com_nav_audio_process_error', (error as Error).message),
// status: 'error',
// });
// },
// });
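// With an empty input, the send button acts as a call button: it only
// opens the call dialog; the ephemeral-token flow above is disabled for now.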
const handleClick = () => {
if (text.trim() === '') {
startCall({ voice: 'verse' });
setCallOpen(true);
// startCall({ voice: 'verse' });
}
};

View file

@@ -21,10 +21,13 @@ export * from './Endpoint';
export type { TranslationKeys } from './useLocalize';
export { default as useCall } from './useCall';
export { default as useToast } from './useToast';
export { default as useWebRTC } from './useWebRTC';
export { default as useTimeout } from './useTimeout';
export { default as useNewConvo } from './useNewConvo';
export { default as useLocalize } from './useLocalize';
export { default as useWebSocket } from './useWebSocket';
export { default as useMediaQuery } from './useMediaQuery';
export { default as useChatBadges } from './useChatBadges';
export { default as useScrollToRef } from './useScrollToRef';

View file

@@ -0,0 +1,67 @@
import { useState, useRef, useCallback } from 'react';
import useWebSocket from './useWebSocket';
import { WebRTCService } from '../services/WebRTC/WebRTCService';
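// An average level below SILENCE_THRESHOLD (dB) sustained for longer than
// SILENCE_DURATION (ms) is treated as the end of an utterance.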
const SILENCE_THRESHOLD = -50;
const SILENCE_DURATION = 1000;
const useCall = () => {
const { sendMessage } = useWebSocket();
const [isCalling, setIsCalling] = useState(false);
const audioContextRef = useRef<AudioContext | null>(null);
const analyserRef = useRef<AnalyserNode | null>(null);
const silenceStartRef = useRef<number | null>(null);
const intervalRef = useRef<number | null>(null);
const webrtcServiceRef = useRef<WebRTCService | null>(null);
const localStreamRef = useRef<MediaStream | null>(null);
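// Polled every 100 ms during a call: averages the analyser's frequency data
// and, once silence has lasted long enough, asks the server for a response.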
const checkSilence = useCallback(() => {
// The interval only runs while a call is active, so checking isCalling here
// would read a stale closure and always bail out; guard on the analyser only.
if (!analyserRef.current) {
return;
}
const data = new Float32Array(analyserRef.current.frequencyBinCount);
analyserRef.current.getFloatFrequencyData(data);
const avg = data.reduce((a, b) => a + b) / data.length;
if (avg < SILENCE_THRESHOLD) {
if (!silenceStartRef.current) {
silenceStartRef.current = Date.now();
} else if (Date.now() - silenceStartRef.current > SILENCE_DURATION) {
sendMessage({ type: 'request-response' });
silenceStartRef.current = null;
}
} else {
silenceStartRef.current = null;
}
}, [sendMessage]);
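// Starts the WebRTC capture pipeline, wires the microphone into an
// AnalyserNode, and begins the silence-detection poll.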
const startCall = useCallback(async () => {
webrtcServiceRef.current = new WebRTCService(sendMessage);
await webrtcServiceRef.current.initializeCall();
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
localStreamRef.current = stream;
audioContextRef.current = new AudioContext();
const source = audioContextRef.current.createMediaStreamSource(stream);
analyserRef.current = audioContextRef.current.createAnalyser();
source.connect(analyserRef.current);
intervalRef.current = window.setInterval(checkSilence, 100);
setIsCalling(true);
}, [checkSilence, sendMessage]);
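// Stops the poll, releases the microphone and AudioContext, tears down the
// WebRTC service, and notifies the server that the call has ended.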
const hangUp = useCallback(async () => {
if (intervalRef.current) {
clearInterval(intervalRef.current);
intervalRef.current = null;
}
analyserRef.current = null;
audioContextRef.current?.close();
audioContextRef.current = null;
localStreamRef.current?.getTracks().forEach((track) => track.stop());
localStreamRef.current = null;
await webrtcServiceRef.current?.endCall();
webrtcServiceRef.current = null;
setIsCalling(false);
sendMessage({ type: 'call-ended' });
}, [sendMessage]);
return { isCalling, startCall, hangUp };
};
export default useCall;

View file

@@ -0,0 +1,77 @@
import { useRef, useCallback } from 'react';
import useWebSocket from './useWebSocket';
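// Same silence heuristic as useCall: a sustained dip below SILENCE_THRESHOLD
// triggers a 'request-response' message to the server.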
const SILENCE_THRESHOLD = -50;
const SILENCE_DURATION = 1000;
const useWebRTC = () => {
const { sendMessage } = useWebSocket();
const localStreamRef = useRef<MediaStream | null>(null);
const audioContextRef = useRef<AudioContext | null>(null);
const analyserRef = useRef<AnalyserNode | null>(null);
const silenceStartTime = useRef<number | null>(null);
const isProcessingRef = useRef(false);
const log = (msg: string) => console.log(`[WebRTC ${new Date().toISOString()}] ${msg}`);
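// Self-scheduling rAF loop: samples the analyser every frame and tracks how
// long the input has stayed below the silence threshold.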
const processAudioLevel = () => {
if (!analyserRef.current || !isProcessingRef.current) {
return;
}
const dataArray = new Float32Array(analyserRef.current.frequencyBinCount);
analyserRef.current.getFloatFrequencyData(dataArray);
const average = dataArray.reduce((a, b) => a + b) / dataArray.length;
if (average < SILENCE_THRESHOLD) {
if (!silenceStartTime.current) {
silenceStartTime.current = Date.now();
log(`Silence started: ${average}dB`);
} else if (Date.now() - silenceStartTime.current > SILENCE_DURATION) {
log('Silence threshold reached - requesting response');
sendMessage({ type: 'request-response' });
silenceStartTime.current = null;
}
} else {
silenceStartTime.current = null;
}
requestAnimationFrame(processAudioLevel);
};
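// Captures the microphone, routes it through an AnalyserNode, and starts
// the level-monitoring loop.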
const startLocalStream = async () => {
try {
log('Starting audio capture');
localStreamRef.current = await navigator.mediaDevices.getUserMedia({ audio: true });
audioContextRef.current = new AudioContext();
const source = audioContextRef.current.createMediaStreamSource(localStreamRef.current);
analyserRef.current = audioContextRef.current.createAnalyser();
source.connect(analyserRef.current);
isProcessingRef.current = true;
processAudioLevel();
log('Audio capture started');
} catch (error) {
log(`Error: ${error instanceof Error ? error.message : 'Unknown error'}`);
throw error;
}
};
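// Stops monitoring and releases the microphone tracks and AudioContext.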
const stopLocalStream = useCallback(() => {
log('Stopping audio capture');
isProcessingRef.current = false;
audioContextRef.current?.close();
localStreamRef.current?.getTracks().forEach((track) => track.stop());
localStreamRef.current = null;
audioContextRef.current = null;
analyserRef.current = null;
silenceStartTime.current = null;
}, []);
return { startLocalStream, stopLocalStream };
};
export default useWebRTC;

View file

@@ -0,0 +1,45 @@
import { useEffect, useRef, useState, useCallback } from 'react';
import { useGetWebsocketUrlQuery } from 'librechat-data-provider/react-query';
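// Maintains a single WebSocket connection to the URL provided by the server
// and plays back 'audio-response' messages (base64-encoded MP3) as they arrive.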
const useWebSocket = () => {
const { data: url } = useGetWebsocketUrlQuery();
const [isConnected, setIsConnected] = useState(false);
const wsRef = useRef<WebSocket | null>(null);
const connect = useCallback(() => {
if (!url?.url) {
return;
}
wsRef.current = new WebSocket(url.url);
wsRef.current.onopen = () => setIsConnected(true);
wsRef.current.onclose = () => setIsConnected(false);
wsRef.current.onerror = (err) => console.error('WebSocket error:', err);
wsRef.current.onmessage = (event) => {
const msg = JSON.parse(event.data);
if (msg.type === 'audio-response') {
const audioData = msg.data;
const audio = new Audio(`data:audio/mp3;base64,${audioData}`);
audio.play().catch(console.error);
}
};
}, [url?.url]);
useEffect(() => {
connect();
return () => wsRef.current?.close();
}, [connect]);
const sendMessage = useCallback((message: any) => {
if (wsRef.current?.readyState === WebSocket.OPEN) {
wsRef.current.send(JSON.stringify(message));
}
}, []);
return { isConnected, sendMessage };
};
export default useWebSocket;

View file

@@ -0,0 +1,36 @@
export class WebRTCService {
private peerConnection: RTCPeerConnection | null = null;
private mediaRecorder: MediaRecorder | null = null;
private sendMessage: (msg: any) => void;
constructor(sendMessage: (msg: any) => void) {
this.sendMessage = sendMessage;
}
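// Captures the microphone, attaches it to a bare RTCPeerConnection (no
// offer/answer exchange yet), and streams MediaRecorder chunks to the
// server as base64 data URLs via the injected sendMessage callback.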
async initializeCall() {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
this.peerConnection = new RTCPeerConnection();
stream.getTracks().forEach((track) => this.peerConnection?.addTrack(track, stream));
this.mediaRecorder = new MediaRecorder(stream);
this.mediaRecorder.ondataavailable = (e) => {
if (e.data.size > 0) {
const reader = new FileReader();
reader.onload = () => {
this.sendMessage({
type: 'audio-chunk',
data: reader.result,
});
};
reader.readAsDataURL(e.data);
}
};
this.mediaRecorder.start();
}
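// Stops recording, releases the microphone, and closes the peer connection.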
async endCall() {
this.mediaRecorder?.stop();
this.mediaRecorder?.stream.getTracks().forEach((track) => track.stop());
this.mediaRecorder = null;
this.peerConnection?.close();
this.peerConnection = null;
}
}

View file

@@ -368,6 +368,11 @@ const updateConversationSelector = selectorFamily({
},
});
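/** Per-chat-index flag controlling visibility of the voice call dialog. */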
const callDialogOpen = atomFamily<boolean, string | number | null>({
key: 'callDialogOpen',
default: false,
});
export default {
conversationKeysAtom,
conversationByIndex,
@@ -399,4 +404,5 @@ export default {
useClearLatestMessages,
showPromptsPopoverFamily,
updateConversationSelector,
callDialogOpen,
};