feat: Implement WebRTC messaging and audio handling in the WebRTC service

Marco Beretta 2024-12-21 16:18:23 +01:00
parent cf4b73b5e3
commit 9a33292f88
8 changed files with 674 additions and 137 deletions


@@ -56,6 +56,17 @@ export type BadgeItem = {
   isAvailable: boolean;
 };
+export interface RTCMessage {
+  type:
+    | 'audio-chunk'
+    | 'audio-received'
+    | 'transcription'
+    | 'llm-response'
+    | 'tts-chunk'
+    | 'call-ended';
+  data?: string | ArrayBuffer | null;
+}
 export type AssistantListItem = {
   id: string;
   name: string;
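
The RTCMessage union above is consumed by both the WebRTC data channel and the WebSocket signaling path in the hooks below. A minimal sketch of narrowing it before audio playback (the helper is illustrative, not part of this commit):

import type { RTCMessage } from '~/common';

// Illustrative only: extract a base64 audio payload from a TTS chunk.
function getTTSPayload(message: RTCMessage): string | null {
  if (message.type === 'tts-chunk' && typeof message.data === 'string') {
    return message.data;
  }
  return null;
}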


@@ -1,67 +1,106 @@
 import { useState, useRef, useCallback } from 'react';
-import useWebSocket from './useWebSocket';
+import { WebRTCService } from '../services/WebRTC/WebRTCService';
+import type { RTCMessage } from '~/common';
+import useWebSocket from './useWebSocket';
 const SILENCE_THRESHOLD = -50;
 const SILENCE_DURATION = 1000;
 const useCall = () => {
-  const { sendMessage } = useWebSocket();
+  const { sendMessage: wsMessage } = useWebSocket();
   const [isCalling, setIsCalling] = useState(false);
+  const [isProcessing, setIsProcessing] = useState(false);
   const audioContextRef = useRef<AudioContext | null>(null);
   const analyserRef = useRef<AnalyserNode | null>(null);
   const audioChunksRef = useRef<Blob[]>([]);
   const silenceStartRef = useRef<number | null>(null);
   const intervalRef = useRef<number | null>(null);
   const webrtcServiceRef = useRef<WebRTCService | null>(null);
-  const checkSilence = useCallback(() => {
-    if (!analyserRef.current || !isCalling) {
+  const sendAudioChunk = useCallback(() => {
+    if (audioChunksRef.current.length === 0) {
       return;
     }
-    const data = new Float32Array(analyserRef.current.frequencyBinCount);
-    analyserRef.current.getFloatFrequencyData(data);
-    const avg = data.reduce((a, b) => a + b) / data.length;
-    if (avg < SILENCE_THRESHOLD) {
-      if (!silenceStartRef.current) {
-        silenceStartRef.current = Date.now();
-      } else if (Date.now() - silenceStartRef.current > SILENCE_DURATION) {
-        sendMessage({ type: 'request-response' });
-        silenceStartRef.current = null;
-      }
-    } else {
-      silenceStartRef.current = null;
+    const audioBlob = new Blob(audioChunksRef.current, { type: 'audio/webm' });
+    // Send audio through WebRTC data channel
+    webrtcServiceRef.current?.sendAudioChunk(audioBlob);
+    // Signal processing start via WebSocket
+    wsMessage({ type: 'processing-start' });
+    audioChunksRef.current = [];
+    setIsProcessing(true);
+  }, [wsMessage]);
+  const handleRTCMessage = useCallback((message: RTCMessage) => {
+    if (message.type === 'audio-received') {
+      // Backend confirmed audio receipt
+      setIsProcessing(true);
     }
-  }, [isCalling, sendMessage]);
+  }, []);
   const startCall = useCallback(async () => {
-    webrtcServiceRef.current = new WebRTCService(sendMessage);
+    // Initialize WebRTC with message handler
+    webrtcServiceRef.current = new WebRTCService(handleRTCMessage);
     await webrtcServiceRef.current.initializeCall();
+    // Signal call start via WebSocket
+    wsMessage({ type: 'call-start' });
     const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
     audioContextRef.current = new AudioContext();
     const source = audioContextRef.current.createMediaStreamSource(stream);
     analyserRef.current = audioContextRef.current.createAnalyser();
     source.connect(analyserRef.current);
-    intervalRef.current = window.setInterval(checkSilence, 100);
+    // Start VAD monitoring
+    intervalRef.current = window.setInterval(() => {
+      if (!analyserRef.current || !isCalling) {
+        return;
+      }
+      const data = new Float32Array(analyserRef.current.frequencyBinCount);
+      analyserRef.current.getFloatFrequencyData(data);
+      const avg = data.reduce((a, b) => a + b) / data.length;
+      if (avg < SILENCE_THRESHOLD) {
+        if (silenceStartRef.current === null) {
+          silenceStartRef.current = Date.now();
+        } else if (Date.now() - silenceStartRef.current > SILENCE_DURATION) {
+          sendAudioChunk();
+          silenceStartRef.current = null;
+        }
+      } else {
+        silenceStartRef.current = null;
+      }
+    }, 100);
     setIsCalling(true);
-  }, [checkSilence, sendMessage]);
+  }, [handleRTCMessage, wsMessage, sendAudioChunk]);
   const hangUp = useCallback(async () => {
     if (intervalRef.current) {
       clearInterval(intervalRef.current);
     }
     analyserRef.current = null;
     audioContextRef.current?.close();
     audioContextRef.current = null;
     await webrtcServiceRef.current?.endCall();
     webrtcServiceRef.current = null;
-    setIsCalling(false);
-    sendMessage({ type: 'call-ended' });
-  }, [sendMessage]);
-  return { isCalling, startCall, hangUp };
+    setIsCalling(false);
+    setIsProcessing(false);
+    wsMessage({ type: 'call-ended' });
+  }, [wsMessage]);
+  return {
+    isCalling,
+    isProcessing,
+    startCall,
+    hangUp,
+  };
 };
 export default useCall;
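
For reference, a minimal sketch of consuming the reworked hook from a component (the component and import path are assumptions, not part of this commit):

import { useCallback } from 'react';
import useCall from '~/hooks/useCall'; // assumed path; the hook imports useWebSocket from the same directory

export default function CallButton() {
  const { isCalling, isProcessing, startCall, hangUp } = useCall();
  // Toggle between starting and ending the call; disable input while the backend is processing.
  const onClick = useCallback(() => (isCalling ? hangUp() : startCall()), [isCalling, hangUp, startCall]);

  return (
    <button onClick={onClick} disabled={isProcessing}>
      {isCalling ? 'Hang up' : 'Start call'}
    </button>
  );
}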


@@ -1,75 +1,44 @@
 import { useRef, useCallback } from 'react';
+import { WebRTCService } from '../services/WebRTC/WebRTCService';
+import type { RTCMessage } from '~/common';
 import useWebSocket from './useWebSocket';
-const SILENCE_THRESHOLD = -50;
-const SILENCE_DURATION = 1000;
 const useWebRTC = () => {
   const { sendMessage } = useWebSocket();
-  const localStreamRef = useRef<MediaStream | null>(null);
-  const audioContextRef = useRef<AudioContext | null>(null);
-  const analyserRef = useRef<AnalyserNode | null>(null);
-  const silenceStartTime = useRef<number | null>(null);
-  const isProcessingRef = useRef(false);
+  const webrtcServiceRef = useRef<WebRTCService | null>(null);
   const log = (msg: string) => console.log(`[WebRTC ${new Date().toISOString()}] ${msg}`);
-  const processAudioLevel = () => {
-    if (!analyserRef.current || !isProcessingRef.current) {
-      return;
-    }
-    const dataArray = new Float32Array(analyserRef.current.frequencyBinCount);
-    analyserRef.current.getFloatFrequencyData(dataArray);
-    const average = dataArray.reduce((a, b) => a + b) / dataArray.length;
-    if (average < SILENCE_THRESHOLD) {
-      if (!silenceStartTime.current) {
-        silenceStartTime.current = Date.now();
-        log(`Silence started: ${average}dB`);
-      } else if (Date.now() - silenceStartTime.current > SILENCE_DURATION) {
-        log('Silence threshold reached - requesting response');
-        sendMessage({ type: 'request-response' });
-        silenceStartTime.current = null;
+  const handleRTCMessage = useCallback(
+    (message: RTCMessage) => {
+      switch (message.type) {
+        case 'audio-chunk':
+          sendMessage({ type: 'processing-start' });
+          break;
+        case 'transcription':
+        case 'llm-response':
+        case 'tts-chunk':
+          // TODO: Handle streaming responses
+          break;
       }
-    } else {
-      silenceStartTime.current = null;
-    }
-    requestAnimationFrame(processAudioLevel);
-  };
+    },
+    [sendMessage],
+  );
   const startLocalStream = async () => {
     try {
       log('Starting audio capture');
-      localStreamRef.current = await navigator.mediaDevices.getUserMedia({ audio: true });
-      audioContextRef.current = new AudioContext();
-      const source = audioContextRef.current.createMediaStreamSource(localStreamRef.current);
-      analyserRef.current = audioContextRef.current.createAnalyser();
-      source.connect(analyserRef.current);
-      isProcessingRef.current = true;
-      processAudioLevel();
-      log('Audio capture started');
+      webrtcServiceRef.current = new WebRTCService(handleRTCMessage);
+      await webrtcServiceRef.current.initializeCall();
+      sendMessage({ type: 'call-start' });
     } catch (error) {
       log(`Error: ${error instanceof Error ? error.message : 'Unknown error'}`);
       console.error(error);
       throw error;
     }
   };
   const stopLocalStream = useCallback(() => {
     log('Stopping audio capture');
-    isProcessingRef.current = false;
-    audioContextRef.current?.close();
-    localStreamRef.current?.getTracks().forEach((track) => track.stop());
-    localStreamRef.current = null;
-    audioContextRef.current = null;
-    analyserRef.current = null;
-    silenceStartTime.current = null;
-  }, []);
+    webrtcServiceRef.current?.endCall();
+    webrtcServiceRef.current = null;
+    sendMessage({ type: 'call-ended' });
+  }, [sendMessage]);
   return { startLocalStream, stopLocalStream };
 };
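
A rough sketch of driving these two functions from a component lifecycle (the effect wiring, import path, and default export are assumptions, not part of this commit):

import { useEffect } from 'react';
import useWebRTC from '~/hooks/useWebRTC'; // assumed path and default export

export default function VoiceSession() {
  const { startLocalStream, stopLocalStream } = useWebRTC();

  useEffect(() => {
    // Start capture on mount; tear everything down on unmount.
    startLocalStream().catch(console.error);
    return () => stopLocalStream();
    // Intentionally run once per mount.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, []);

  return null;
}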


@@ -1,39 +1,47 @@
 import { useEffect, useRef, useState, useCallback } from 'react';
 import { useGetWebsocketUrlQuery } from 'librechat-data-provider/react-query';
+import type { RTCMessage } from '~/common';
 const useWebSocket = () => {
-  const { data: url } = useGetWebsocketUrlQuery();
+  const { data: data } = useGetWebsocketUrlQuery();
   const [isConnected, setIsConnected] = useState(false);
   const wsRef = useRef<WebSocket | null>(null);
-  console.log('wsConfig:', url?.url);
   const connect = useCallback(() => {
-    if (!url?.url) {
+    if (!data || !data.url) {
       return;
     }
-    wsRef.current = new WebSocket(url?.url);
+    wsRef.current = new WebSocket(data.url);
     wsRef.current.onopen = () => setIsConnected(true);
     wsRef.current.onclose = () => setIsConnected(false);
     wsRef.current.onerror = (err) => console.error('WebSocket error:', err);
     wsRef.current.onmessage = (event) => {
-      const msg = JSON.parse(event.data);
-      if (msg.type === 'audio-response') {
-        const audioData = msg.data;
-        const audio = new Audio(`data:audio/mp3;base64,${audioData}`);
-        audio.play().catch(console.error);
+      const msg: RTCMessage = JSON.parse(event.data);
+      switch (msg.type) {
+        case 'transcription':
+          // TODO: Handle transcription update
+          break;
+        case 'llm-response':
+          // TODO: Handle LLM streaming response
+          break;
+        case 'tts-chunk':
+          if (typeof msg.data === 'string') {
+            const audio = new Audio(`data:audio/mp3;base64,${msg.data}`);
+            audio.play().catch(console.error);
+          }
+          break;
       }
     };
-  }, [url?.url]);
+  }, [data?.url]);
   useEffect(() => {
     connect();
     return () => wsRef.current?.close();
   }, [connect]);
-  const sendMessage = useCallback((message: any) => {
+  const sendMessage = useCallback((message: Record<string, unknown>) => {
     if (wsRef.current?.readyState === WebSocket.OPEN) {
       wsRef.current.send(JSON.stringify(message));
     }


@@ -1,36 +1,55 @@
+import type { RTCMessage } from '~/common';
 export class WebRTCService {
   private peerConnection: RTCPeerConnection | null = null;
   private dataChannel: RTCDataChannel | null = null;
   private mediaRecorder: MediaRecorder | null = null;
-  private sendMessage: (msg: any) => void;
+  private onMessage: (msg: RTCMessage) => void;
-  constructor(sendMessage: (msg: any) => void) {
-    this.sendMessage = sendMessage;
+  constructor(onMessage: (msg: RTCMessage) => void) {
+    this.onMessage = onMessage;
   }
   async initializeCall() {
-    const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
     this.peerConnection = new RTCPeerConnection();
-    stream.getTracks().forEach((track) => this.peerConnection?.addTrack(track, stream));
     this.dataChannel = this.peerConnection.createDataChannel('audio');
+    const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
     this.mediaRecorder = new MediaRecorder(stream);
     this.mediaRecorder.ondataavailable = (e) => {
-      if (e.data.size > 0) {
-        const reader = new FileReader();
-        reader.onload = () => {
-          this.sendMessage({
-            type: 'audio-chunk',
-            data: reader.result,
-          });
-        };
-        reader.readAsDataURL(e.data);
+      if (e.data.size > 0 && this.dataChannel?.readyState === 'open') {
+        e.data.arrayBuffer().then((buffer) => {
+          this.dataChannel?.send(buffer);
+        });
       }
     };
-    this.mediaRecorder.start();
+    this.mediaRecorder.start(100);
+    this.setupDataChannel();
   }
+  private setupDataChannel() {
+    if (!this.dataChannel) {
+      return;
+    }
+    this.dataChannel.onmessage = (event) => {
+      this.onMessage({
+        type: 'audio-chunk',
+        data: event.data,
+      });
+    };
+  }
+  public async sendAudioChunk(audioBlob: Blob) {
+    if (this.dataChannel && this.dataChannel.readyState === 'open') {
+      this.dataChannel.send(await audioBlob.arrayBuffer());
+    }
+  }
   async endCall() {
     this.mediaRecorder?.stop();
     this.dataChannel?.close();
     this.peerConnection?.close();
     this.peerConnection = null;
   }
 }
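
Outside of the hooks, the service can be exercised directly; a minimal sketch under assumptions (the handler body, function name, and import path are illustrative):

import { WebRTCService } from '../services/WebRTC/WebRTCService'; // path as used by the hooks above
import type { RTCMessage } from '~/common';

async function runCall(recordedBlob: Blob) {
  const service = new WebRTCService((message: RTCMessage) => {
    // Illustrative handler: everything arriving on the data channel lands here.
    console.log('RTC message received:', message.type);
  });

  await service.initializeCall(); // opens the peer connection, data channel, and MediaRecorder
  await service.sendAudioChunk(recordedBlob); // forwarded only while the data channel is open
  await service.endCall(); // stops recording and closes the peer connection
}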