fix: both webrtc-client and webrtc-server

This commit is contained in:
Marco Beretta 2025-01-03 19:35:20 +01:00 committed by Danny Avila
parent 9c0c341dee
commit 964d47cfa3
No known key found for this signature in database
GPG key ID: BF31EEB2C5CA0956
8 changed files with 948 additions and 390 deletions

View file

@ -1,157 +1,205 @@
const { Server } = require('socket.io'); const { Server } = require('socket.io');
const { RTCPeerConnection } = require('wrtc'); const { RTCPeerConnection, RTCIceCandidate } = require('wrtc');
module.exports.SocketIOService = class { class WebRTCConnection {
constructor(httpServer) { constructor(socket, config) {
this.io = new Server(httpServer, { path: '/socket.io' }); this.socket = socket;
this.log('Socket.IO Server initialized'); this.config = config;
this.activeClients = new Map(); this.peerConnection = null;
this.iceServers = [ this.audioTransceiver = null;
{ urls: 'stun:stun.l.google.com:19302' }, this.pendingCandidates = [];
{ urls: 'stun:stun1.l.google.com:19302' }, this.state = 'idle';
]; this.log = config.log || console.log;
this.setupHandlers();
} }
log(msg) { async handleOffer(offer) {
console.log(`[Socket.IO ${new Date().toISOString()}] ${msg}`); try {
// Create new peer connection if needed
if (!this.peerConnection) {
this.peerConnection = new RTCPeerConnection(this.config.rtcConfig);
this.setupPeerConnectionListeners();
} }
setupHandlers() { // Set the remote description (client's offer)
this.io.on('connection', (socket) => { await this.peerConnection.setRemoteDescription(offer);
const clientId = socket.id;
this.activeClients.set(clientId, { // Set up audio transceiver for two-way audio
socket, this.audioTransceiver = this.peerConnection.addTransceiver('audio', {
state: 'idle', direction: 'sendrecv',
audioBuffer: [],
currentTranscription: '',
isProcessing: false,
}); });
this.log(`Client connected: ${clientId}`); // Create and set local description (answer)
const answer = await this.peerConnection.createAnswer();
await this.peerConnection.setLocalDescription(answer);
socket.on('call-start', () => this.handleCallStart(clientId)); // Send answer to client
socket.on('audio-chunk', (data) => this.handleAudioChunk(clientId, data)); this.socket.emit('webrtc-answer', answer);
socket.on('processing-start', () => this.processAudioStream(clientId));
socket.on('audio-received', () => this.confirmAudioReceived(clientId)); // Process any pending ICE candidates
socket.on('call-ended', () => this.handleCallEnd(clientId)); while (this.pendingCandidates.length) {
const candidate = this.pendingCandidates.shift();
await this.addIceCandidate(candidate);
}
this.state = 'connecting';
} catch (error) {
this.log(`Error handling offer: ${error}`, 'error');
this.socket.emit('error', { message: 'Failed to process offer' });
this.cleanup();
}
}
setupPeerConnectionListeners() {
if (!this.peerConnection) {
return;
}
// Handle incoming audio tracks
this.peerConnection.ontrack = ({ track, streams }) => {
this.log(`Received ${track.kind} track from client`);
// For testing: Echo the audio back after a delay
if (track.kind === 'audio') {
this.handleIncomingAudio(track, streams[0]);
}
track.onended = () => {
this.log(`${track.kind} track ended`);
};
};
this.peerConnection.onicecandidate = ({ candidate }) => {
if (candidate) {
this.socket.emit('icecandidate', candidate);
}
};
this.peerConnection.onconnectionstatechange = () => {
if (!this.peerConnection) {
return;
}
const state = this.peerConnection.connectionState;
this.log(`Connection state changed to ${state}`);
this.state = state;
if (state === 'failed' || state === 'closed') {
this.cleanup();
}
};
this.peerConnection.oniceconnectionstatechange = () => {
if (this.peerConnection) {
this.log(`ICE connection state: ${this.peerConnection.iceConnectionState}`);
}
};
}
handleIncomingAudio(inputTrack) {
// For testing: Echo back the input track directly
this.peerConnection.addTrack(inputTrack);
// Log the track info for debugging
this.log(`Audio track added: ${inputTrack.id}, enabled: ${inputTrack.enabled}`);
}
async addIceCandidate(candidate) {
try {
if (this.peerConnection?.remoteDescription) {
if (candidate && candidate.candidate) {
await this.peerConnection.addIceCandidate(new RTCIceCandidate(candidate));
} else {
this.log('Invalid ICE candidate', 'warn');
}
} else {
this.pendingCandidates.push(candidate);
}
} catch (error) {
this.log(`Error adding ICE candidate: ${error}`, 'error');
}
}
cleanup() {
if (this.peerConnection) {
try {
this.peerConnection.close();
} catch (error) {
this.log(`Error closing peer connection: ${error}`, 'error');
}
this.peerConnection = null;
}
this.audioTransceiver = null;
this.pendingCandidates = [];
this.state = 'idle';
}
}
class SocketIOService {
constructor(httpServer, config = {}) {
this.config = {
rtcConfig: {
iceServers: [
{
urls: ['stun:stun.l.google.com:19302', 'stun:stun1.l.google.com:19302'],
},
],
iceCandidatePoolSize: 10,
bundlePolicy: 'max-bundle',
rtcpMuxPolicy: 'require',
},
...config,
};
this.io = new Server(httpServer, {
path: '/socket.io',
cors: {
origin: '*',
methods: ['GET', 'POST'],
},
});
this.connections = new Map();
this.setupSocketHandlers();
}
log(message, level = 'info') {
const timestamp = new Date().toISOString();
console.log(`[WebRTC ${timestamp}] [${level.toUpperCase()}] ${message}`);
}
setupSocketHandlers() {
this.io.on('connection', (socket) => {
this.log(`Client connected: ${socket.id}`);
// Create a new WebRTC connection for this socket
const rtcConnection = new WebRTCConnection(socket, {
...this.config,
log: this.log.bind(this),
});
this.connections.set(socket.id, rtcConnection);
socket.on('webrtc-offer', (offer) => {
this.log(`Received WebRTC offer from ${socket.id}`);
rtcConnection.handleOffer(offer);
});
socket.on('icecandidate', (candidate) => {
rtcConnection.addIceCandidate(candidate);
});
socket.on('disconnect', () => { socket.on('disconnect', () => {
this.handleCallEnd(clientId); this.log(`Client disconnected: ${socket.id}`);
this.activeClients.delete(clientId); rtcConnection.cleanup();
this.log(`Client disconnected: ${clientId}`); this.connections.delete(socket.id);
});
socket.on('error', (error) => {
this.log(`Error for client ${clientId}: ${error.message}`);
this.handleCallEnd(clientId);
}); });
}); });
} }
async handleCallStart(clientId) { shutdown() {
const client = this.activeClients.get(clientId); for (const connection of this.connections.values()) {
if (!client) { connection.cleanup();
return;
} }
this.connections.clear();
try { this.io.close();
client.state = 'active';
client.audioBuffer = [];
client.currentTranscription = '';
client.isProcessing = false;
const peerConnection = new RTCPeerConnection({
iceServers: this.iceServers,
sdpSemantics: 'unified-plan',
});
client.peerConnection = peerConnection;
client.dataChannel = peerConnection.createDataChannel('audio', {
ordered: true,
maxRetransmits: 3,
});
client.dataChannel.onopen = () => this.log(`Data channel opened for ${clientId}`);
client.dataChannel.onmessage = async (event) => {
await this.handleAudioChunk(clientId, event.data);
};
peerConnection.onicecandidate = (event) => {
if (event.candidate) {
client.socket.emit('ice-candidate', { candidate: event.candidate });
} }
}; }
peerConnection.onnegotiationneeded = async () => { module.exports = { SocketIOService };
try {
const offer = await peerConnection.createOffer();
await peerConnection.setLocalDescription(offer);
client.socket.emit('webrtc-offer', { sdp: peerConnection.localDescription });
} catch (error) {
this.log(`Negotiation failed for ${clientId}: ${error}`);
}
};
this.log(`Call started for client ${clientId}`);
} catch (error) {
this.log(`Error starting call for ${clientId}: ${error.message}`);
this.handleCallEnd(clientId);
}
}
async handleAudioChunk(clientId, data) {
const client = this.activeClients.get(clientId);
if (!client || client.state !== 'active') {
return;
}
client.audioBuffer.push(data);
client.socket.emit('audio-received');
}
async processAudioStream(clientId) {
const client = this.activeClients.get(clientId);
if (!client || client.state !== 'active' || client.isProcessing) {
return;
}
client.isProcessing = true;
try {
// Process transcription
client.socket.emit('transcription', { data: 'Processing audio...' });
// Stream LLM response
client.socket.emit('llm-response', { data: 'Processing response...' });
// Stream TTS chunks
client.socket.emit('tts-chunk', { data: 'audio_data_here' });
} catch (error) {
this.log(`Processing error for client ${clientId}: ${error.message}`);
} finally {
client.isProcessing = false;
client.audioBuffer = [];
}
}
confirmAudioReceived(clientId) {
const client = this.activeClients.get(clientId);
if (!client) {
return;
}
client.socket.emit('audio-received', { data: null });
}
handleCallEnd(clientId) {
const client = this.activeClients.get(clientId);
if (!client) {
return;
}
client.state = 'idle';
client.audioBuffer = [];
client.currentTranscription = '';
}
};

View file

@ -55,8 +55,23 @@ export interface RTCMessage {
| 'transcription' | 'transcription'
| 'llm-response' | 'llm-response'
| 'tts-chunk' | 'tts-chunk'
| 'call-ended'; | 'call-ended'
data?: string | ArrayBuffer | null; | 'webrtc-answer'
| 'icecandidate';
payload?: RTCSessionDescriptionInit | RTCIceCandidateInit;
}
export type MessagePayload =
| RTCSessionDescriptionInit
| RTCIceCandidateInit
| Record<string, never>;
export enum CallState {
IDLE = 'idle',
CONNECTING = 'connecting',
ACTIVE = 'active',
ERROR = 'error',
ENDED = 'ended',
} }
export type AssistantListItem = { export type AssistantListItem = {

View file

@ -1,24 +1,60 @@
import React from 'react'; import React, { useEffect, useRef } from 'react';
import { useRecoilState } from 'recoil'; import { useRecoilState } from 'recoil';
import { Phone, PhoneOff } from 'lucide-react'; import {
Phone,
PhoneOff,
AlertCircle,
Mic,
MicOff,
Volume2,
VolumeX,
Activity,
} from 'lucide-react';
import { OGDialog, OGDialogContent, Button } from '~/components'; import { OGDialog, OGDialogContent, Button } from '~/components';
import { useWebSocket, useCall } from '~/hooks'; import { useWebSocket, useCall } from '~/hooks';
import { CallState } from '~/common';
import store from '~/store'; import store from '~/store';
export const Call: React.FC = () => { export const Call: React.FC = () => {
const { isConnected, sendMessage } = useWebSocket(); const { isConnected } = useWebSocket();
const { isCalling, isProcessing, startCall, hangUp } = useCall(); const {
callState,
error,
startCall,
hangUp,
isConnecting,
localStream,
remoteStream,
connectionQuality,
} = useCall();
const [open, setOpen] = useRecoilState(store.callDialogOpen(0)); const [open, setOpen] = useRecoilState(store.callDialogOpen(0));
const [eventLog, setEventLog] = React.useState<string[]>([]); const [eventLog, setEventLog] = React.useState<string[]>([]);
const [isMuted, setIsMuted] = React.useState(false);
const [isAudioEnabled, setIsAudioEnabled] = React.useState(true);
const remoteAudioRef = useRef<HTMLAudioElement>(null);
const logEvent = (message: string) => { const logEvent = (message: string) => {
console.log(message); console.log(message);
setEventLog((prev) => [...prev, message]); setEventLog((prev) => [...prev, `${new Date().toISOString()}: ${message}`]);
}; };
React.useEffect(() => { useEffect(() => {
if (remoteAudioRef.current && remoteStream) {
remoteAudioRef.current.srcObject = remoteStream;
}
}, [remoteStream]);
useEffect(() => {
if (localStream) {
localStream.getAudioTracks().forEach((track) => {
track.enabled = !isMuted;
});
}
}, [localStream, isMuted]);
useEffect(() => {
if (isConnected) { if (isConnected) {
logEvent('Connected to server.'); logEvent('Connected to server.');
} else { } else {
@ -26,15 +62,15 @@ export const Call: React.FC = () => {
} }
}, [isConnected]); }, [isConnected]);
React.useEffect(() => { useEffect(() => {
if (isCalling) { if (error) {
logEvent('Call started.'); logEvent(`Error: ${error.message} (${error.code})`);
} else if (isProcessing) {
logEvent('Processing audio...');
} else {
logEvent('Call ended.');
} }
}, [isCalling, isProcessing]); }, [error]);
useEffect(() => {
logEvent(`Call state changed to: ${callState}`);
}, [callState]);
const handleStartCall = () => { const handleStartCall = () => {
logEvent('Attempting to start call...'); logEvent('Attempting to start call...');
@ -46,10 +82,28 @@ export const Call: React.FC = () => {
hangUp(); hangUp();
}; };
const toggleMute = () => {
setIsMuted((prev) => !prev);
logEvent(`Microphone ${isMuted ? 'unmuted' : 'muted'}`);
};
const toggleAudio = () => {
setIsAudioEnabled((prev) => !prev);
if (remoteAudioRef.current) {
remoteAudioRef.current.muted = !isAudioEnabled;
}
logEvent(`Speaker ${isAudioEnabled ? 'disabled' : 'enabled'}`);
};
const isActive = callState === CallState.ACTIVE;
const isError = callState === CallState.ERROR;
return ( return (
<OGDialog open={open} onOpenChange={setOpen}> <OGDialog open={open} onOpenChange={setOpen}>
<OGDialogContent className="w-96 p-8"> <OGDialogContent className="w-[28rem] p-8">
<div className="flex flex-col items-center gap-6"> <div className="flex flex-col items-center gap-6">
{/* Connection Status */}
<div className="flex w-full items-center justify-between">
<div <div
className={`flex items-center gap-2 rounded-full px-4 py-2 ${ className={`flex items-center gap-2 rounded-full px-4 py-2 ${
isConnected ? 'bg-green-100 text-green-700' : 'bg-red-100 text-red-700' isConnected ? 'bg-green-100 text-green-700' : 'bg-red-100 text-red-700'
@ -63,7 +117,55 @@ export const Call: React.FC = () => {
</span> </span>
</div> </div>
{isCalling ? ( {isActive && (
<div
className={`flex items-center gap-2 rounded-full px-4 py-2 ${
(connectionQuality === 'good' && 'bg-green-100 text-green-700') ||
(connectionQuality === 'poor' && 'bg-yellow-100 text-yellow-700') ||
'bg-gray-100 text-gray-700'
}`}
>
<Activity size={16} />
<span className="text-sm font-medium capitalize">{connectionQuality} Quality</span>
</div>
)}
</div>
{/* Error Display */}
{error && (
<div className="flex w-full items-center gap-2 rounded-md bg-red-100 p-3 text-red-700">
<AlertCircle size={16} />
<span className="text-sm">{error.message}</span>
</div>
)}
{/* Call Controls */}
<div className="flex items-center gap-4">
{isActive && (
<>
<Button
onClick={toggleMute}
className={`rounded-full p-3 ${
isMuted ? 'bg-red-100 text-red-700' : 'bg-gray-100 text-gray-700'
}`}
title={isMuted ? 'Unmute microphone' : 'Mute microphone'}
>
{isMuted ? <MicOff size={20} /> : <Mic size={20} />}
</Button>
<Button
onClick={toggleAudio}
className={`rounded-full p-3 ${
!isAudioEnabled ? 'bg-red-100 text-red-700' : 'bg-gray-100 text-gray-700'
}`}
title={isAudioEnabled ? 'Disable speaker' : 'Enable speaker'}
>
{isAudioEnabled ? <Volume2 size={20} /> : <VolumeX size={20} />}
</Button>
</>
)}
{isActive ? (
<Button <Button
onClick={handleHangUp} onClick={handleHangUp}
className="flex items-center gap-2 rounded-full bg-red-500 px-6 py-3 text-white hover:bg-red-600" className="flex items-center gap-2 rounded-full bg-red-500 px-6 py-3 text-white hover:bg-red-600"
@ -74,24 +176,34 @@ export const Call: React.FC = () => {
) : ( ) : (
<Button <Button
onClick={handleStartCall} onClick={handleStartCall}
disabled={!isConnected} disabled={!isConnected || isError || isConnecting}
className="flex items-center gap-2 rounded-full bg-green-500 px-6 py-3 text-white hover:bg-green-600 disabled:opacity-50" className="flex items-center gap-2 rounded-full bg-green-500 px-6 py-3 text-white hover:bg-green-600 disabled:opacity-50"
> >
<Phone size={20} /> <Phone size={20} />
<span>Start Call</span> <span>{isConnecting ? 'Connecting...' : 'Start Call'}</span>
</Button> </Button>
)} )}
</div>
{/* Debugging Information */} {/* Event Log */}
<div className="mt-4 w-full rounded-md bg-gray-100 p-4 shadow-sm"> <div className="mt-4 w-full rounded-md bg-gray-100 p-4 shadow-sm">
<h3 className="mb-2 text-lg font-medium">Event Log</h3> <h3 className="mb-2 text-lg font-medium">Event Log</h3>
<ul className="h-32 overflow-y-auto text-xs text-gray-600"> <div className="h-32 overflow-y-auto rounded-md bg-white p-2 shadow-inner">
<ul className="space-y-1 text-xs text-gray-600">
{eventLog.map((log, index) => ( {eventLog.map((log, index) => (
<li key={index}>{log}</li> <li key={index} className="font-mono">
{log}
</li>
))} ))}
</ul> </ul>
</div> </div>
</div> </div>
{/* Hidden Audio Element */}
<audio ref={remoteAudioRef} autoPlay>
<track kind="captions" />
</audio>
</div>
</OGDialogContent> </OGDialogContent>
</OGDialog> </OGDialog>
); );

View file

@ -20,7 +20,6 @@ export * from './ScreenshotContext';
export * from './ApiErrorBoundaryContext'; export * from './ApiErrorBoundaryContext';
export { default as useCall } from './useCall'; export { default as useCall } from './useCall';
export { default as useToast } from './useToast'; export { default as useToast } from './useToast';
export { default as useWebRTC } from './useWebRTC';
export { default as useTimeout } from './useTimeout'; export { default as useTimeout } from './useTimeout';
export { default as useNewConvo } from './useNewConvo'; export { default as useNewConvo } from './useNewConvo';
export { default as useLocalize } from './useLocalize'; export { default as useLocalize } from './useLocalize';

View file

@ -1,101 +1,220 @@
import { useState, useRef, useCallback } from 'react'; import { useState, useRef, useCallback, useEffect } from 'react';
import { WebRTCService } from '../services/WebRTC/WebRTCService'; import { WebRTCService, ConnectionState } from '../services/WebRTC/WebRTCService';
import type { RTCMessage } from '~/common'; import useWebSocket, { WebSocketEvents } from './useWebSocket';
import useWebSocket from './useWebSocket';
const SILENCE_THRESHOLD = -50; interface CallError {
const SILENCE_DURATION = 1000; code: string;
message: string;
}
/**
 * High-level call lifecycle state exposed by the `useCall` hook.
 *
 * NOTE(review): an enum with the same name and members is also added to the shared
 * types in this commit, and the Call component imports `CallState` from '~/common'
 * rather than from this module. Two separate enum declarations may not compare
 * cleanly under type checking — confirm and keep a single definition.
 */
export enum CallState {
// No call has been started yet (the value used by INITIAL_STATUS).
IDLE = 'idle',
// Set when startCall begins and while the connection reports connecting/reconnecting.
CONNECTING = 'connecting',
// Set once the WebRTC connection reports CONNECTED.
ACTIVE = 'active',
// Set when the socket is not connected, initialization throws, or the service reports a failure.
ERROR = 'error',
// Set by hangUp and when the WebRTC connection reports CLOSED.
ENDED = 'ended',
}
/** Snapshot of the call that `useCall` keeps in React state and spreads into its return value. */
interface CallStatus {
// Current lifecycle state of the call.
callState: CallState;
// True while a connection attempt (or reconnection) is in progress.
isConnecting: boolean;
// Last error reported for the call, or null when there is none.
error: CallError | null;
// NOTE(review): no visible code in this hook assigns a non-null local stream, yet the
// Call component uses it to enable/disable microphone tracks — confirm where it is set.
localStream: MediaStream | null;
// Stream delivered by the service's 'remoteStream' event; cleared when the connection closes.
remoteStream: MediaStream | null;
// Derived from the averaged candidate-pair round-trip time; 'unknown' until first measured.
connectionQuality: 'good' | 'poor' | 'unknown';
}
/**
 * Status used before any call has started. `hangUp` also spreads it to reset
 * every field, overriding only `callState` with ENDED.
 */
const INITIAL_STATUS: CallStatus = {
callState: CallState.IDLE,
isConnecting: false,
error: null,
localStream: null,
remoteStream: null,
connectionQuality: 'unknown',
};
const useCall = () => { const useCall = () => {
const { sendMessage: wsMessage, isConnected } = useWebSocket(); const { isConnected, sendMessage, addEventListener } = useWebSocket();
const [isCalling, setIsCalling] = useState(false); const [status, setStatus] = useState<CallStatus>(INITIAL_STATUS);
const [isProcessing, setIsProcessing] = useState(false);
const audioContextRef = useRef<AudioContext | null>(null);
const analyserRef = useRef<AnalyserNode | null>(null);
const audioChunksRef = useRef<Blob[]>([]);
const silenceStartRef = useRef<number | null>(null);
const intervalRef = useRef<number | null>(null);
const webrtcServiceRef = useRef<WebRTCService | null>(null); const webrtcServiceRef = useRef<WebRTCService | null>(null);
const statsIntervalRef = useRef<NodeJS.Timeout>();
const sendAudioChunk = useCallback(() => { const updateStatus = useCallback((updates: Partial<CallStatus>) => {
if (audioChunksRef.current.length === 0) { setStatus((prev) => ({ ...prev, ...updates }));
}, []);
useEffect(() => {
return () => {
if (statsIntervalRef.current) {
clearInterval(statsIntervalRef.current);
}
if (webrtcServiceRef.current) {
webrtcServiceRef.current.close();
}
};
}, []);
const handleConnectionStateChange = useCallback(
(state: ConnectionState) => {
switch (state) {
case ConnectionState.CONNECTED:
updateStatus({
callState: CallState.ACTIVE,
isConnecting: false,
});
break;
case ConnectionState.CONNECTING:
case ConnectionState.RECONNECTING:
updateStatus({
callState: CallState.CONNECTING,
isConnecting: true,
});
break;
case ConnectionState.FAILED:
updateStatus({
callState: CallState.ERROR,
isConnecting: false,
error: {
code: 'CONNECTION_FAILED',
message: 'Connection failed. Please try again.',
},
});
break;
case ConnectionState.CLOSED:
updateStatus({
callState: CallState.ENDED,
isConnecting: false,
localStream: null,
remoteStream: null,
});
break;
}
},
[updateStatus],
);
const startConnectionMonitoring = useCallback(() => {
if (!webrtcServiceRef.current) {
return; return;
} }
const audioBlob = new Blob(audioChunksRef.current, { type: 'audio/webm' }); statsIntervalRef.current = setInterval(async () => {
webrtcServiceRef.current?.sendAudioChunk(audioBlob); const stats = await webrtcServiceRef.current?.getStats();
wsMessage({ type: 'processing-start' }); if (!stats) {
return;
audioChunksRef.current = [];
setIsProcessing(true);
}, [wsMessage]);
const handleRTCMessage = useCallback((message: RTCMessage) => {
if (message.type === 'audio-received') {
setIsProcessing(true);
} }
}, []);
let totalRoundTripTime = 0;
let samplesCount = 0;
stats.forEach((report) => {
if (report.type === 'candidate-pair' && report.currentRoundTripTime) {
totalRoundTripTime += report.currentRoundTripTime;
samplesCount++;
}
});
const averageRTT = samplesCount > 0 ? totalRoundTripTime / samplesCount : 0;
updateStatus({
connectionQuality: averageRTT < 0.3 ? 'good' : 'poor',
});
}, 2000);
}, [updateStatus]);
const startCall = useCallback(async () => { const startCall = useCallback(async () => {
if (!isConnected) { if (!isConnected) {
updateStatus({
callState: CallState.ERROR,
error: {
code: 'NOT_CONNECTED',
message: 'Not connected to server',
},
});
return; return;
} }
webrtcServiceRef.current = new WebRTCService(handleRTCMessage); try {
await webrtcServiceRef.current.initializeCall(); if (webrtcServiceRef.current) {
webrtcServiceRef.current.close();
wsMessage({ type: 'call-start' });
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
audioContextRef.current = new AudioContext();
const source = audioContextRef.current.createMediaStreamSource(stream);
analyserRef.current = audioContextRef.current.createAnalyser();
source.connect(analyserRef.current);
intervalRef.current = window.setInterval(() => {
if (!analyserRef.current || !isCalling) {
return;
} }
const data = new Float32Array(analyserRef.current.frequencyBinCount); updateStatus({
analyserRef.current.getFloatFrequencyData(data); callState: CallState.CONNECTING,
const avg = data.reduce((a, b) => a + b) / data.length; isConnecting: true,
error: null,
});
if (avg < SILENCE_THRESHOLD) { // TODO: Remove debug or make it configurable
if (silenceStartRef.current === null) { webrtcServiceRef.current = new WebRTCService((message) => sendMessage(message), {
silenceStartRef.current = Date.now(); debug: true,
} else if (Date.now() - silenceStartRef.current > SILENCE_DURATION) { });
sendAudioChunk();
silenceStartRef.current = null; webrtcServiceRef.current.on('connectionStateChange', handleConnectionStateChange);
webrtcServiceRef.current.on('remoteStream', (stream: MediaStream) => {
updateStatus({ remoteStream: stream });
});
webrtcServiceRef.current.on('error', (error: string) => {
updateStatus({
callState: CallState.ERROR,
isConnecting: false,
error: {
code: 'WEBRTC_ERROR',
message: error,
},
});
});
await webrtcServiceRef.current.initialize();
startConnectionMonitoring();
} catch (error) {
updateStatus({
callState: CallState.ERROR,
isConnecting: false,
error: {
code: 'INITIALIZATION_FAILED',
message: error instanceof Error ? error.message : 'Failed to start call',
},
});
} }
} else { }, [
silenceStartRef.current = null; isConnected,
} sendMessage,
}, 100); handleConnectionStateChange,
startConnectionMonitoring,
updateStatus,
]);
setIsCalling(true); const hangUp = useCallback(() => {
}, [handleRTCMessage, isConnected, wsMessage, sendAudioChunk, isCalling]); if (webrtcServiceRef.current) {
webrtcServiceRef.current.close();
const hangUp = useCallback(async () => {
if (intervalRef.current) {
clearInterval(intervalRef.current);
}
analyserRef.current = null;
audioContextRef.current?.close();
audioContextRef.current = null;
await webrtcServiceRef.current?.endCall();
webrtcServiceRef.current = null; webrtcServiceRef.current = null;
}
if (statsIntervalRef.current) {
clearInterval(statsIntervalRef.current);
}
updateStatus({
...INITIAL_STATUS,
callState: CallState.ENDED,
});
}, [updateStatus]);
setIsCalling(false); useEffect(() => {
setIsProcessing(false); const cleanupFns = [
wsMessage({ type: 'call-ended' }); addEventListener(WebSocketEvents.WEBRTC_ANSWER, (answer: RTCSessionDescriptionInit) => {
}, [wsMessage]); webrtcServiceRef.current?.handleAnswer(answer);
}),
addEventListener(WebSocketEvents.ICE_CANDIDATE, (candidate: RTCIceCandidateInit) => {
webrtcServiceRef.current?.addIceCandidate(candidate);
}),
];
return () => cleanupFns.forEach((fn) => fn());
}, [addEventListener]);
return { return {
isCalling, ...status,
isProcessing,
startCall, startCall,
hangUp, hangUp,
}; };

View file

@ -1,46 +0,0 @@
import { useRef, useCallback } from 'react';
import { WebRTCService } from '../services/WebRTC/WebRTCService';
import type { RTCMessage } from '~/common';
import useWebSocket from './useWebSocket';
const useWebRTC = () => {
const { sendMessage } = useWebSocket();
const webrtcServiceRef = useRef<WebRTCService | null>(null);
const handleRTCMessage = useCallback(
(message: RTCMessage) => {
switch (message.type) {
case 'audio-chunk':
sendMessage({ type: 'processing-start' });
break;
case 'transcription':
case 'llm-response':
case 'tts-chunk':
// TODO: Handle streaming responses
break;
}
},
[sendMessage],
);
const startLocalStream = async () => {
try {
webrtcServiceRef.current = new WebRTCService(handleRTCMessage);
await webrtcServiceRef.current.initializeCall();
sendMessage({ type: 'call-start' });
} catch (error) {
console.error(error);
throw error;
}
};
const stopLocalStream = useCallback(() => {
webrtcServiceRef.current?.endCall();
webrtcServiceRef.current = null;
sendMessage({ type: 'call-ended' });
}, [sendMessage]);
return { startLocalStream, stopLocalStream };
};
export default useWebRTC;

View file

@ -1,62 +1,140 @@
import { useEffect, useRef, useState, useCallback } from 'react'; import { useEffect, useRef, useState } from 'react';
import { useGetWebsocketUrlQuery } from 'librechat-data-provider/react-query'; import { useGetWebsocketUrlQuery } from 'librechat-data-provider/react-query';
import type { MessagePayload } from '~/common';
import { io, Socket } from 'socket.io-client'; import { io, Socket } from 'socket.io-client';
import type { RTCMessage } from '~/common'; import { EventEmitter } from 'events';
const useWebSocket = () => { export const WebSocketEvents = {
const { data } = useGetWebsocketUrlQuery(); CALL_STARTED: 'call-started',
const [isConnected, setIsConnected] = useState(false); CALL_ERROR: 'call-error',
const socketRef = useRef<Socket | null>(null); WEBRTC_ANSWER: 'webrtc-answer',
ICE_CANDIDATE: 'icecandidate',
} as const;
const connect = useCallback(() => { type EventHandler = (...args: unknown[]) => void;
if (!data || !data.url) {
class WebSocketManager extends EventEmitter {
private socket: Socket | null = null;
private reconnectAttempts = 0;
private readonly MAX_RECONNECT_ATTEMPTS = 5;
private isConnected = false;
connect(url: string) {
if (this.socket && this.socket.connected) {
return;
}
this.socket = io(url, {
transports: ['websocket'],
reconnectionAttempts: this.MAX_RECONNECT_ATTEMPTS,
timeout: 10000,
});
this.setupEventHandlers();
}
private setupEventHandlers() {
if (!this.socket) {
return; return;
} }
socketRef.current = io(data.url, { transports: ['websocket'] }); this.socket.on('connect', () => {
this.isConnected = true;
socketRef.current.on('connect', () => { this.reconnectAttempts = 0;
setIsConnected(true); this.emit('connectionChange', true);
}); });
socketRef.current.on('disconnect', () => { this.socket.on('disconnect', (reason) => {
setIsConnected(false); this.isConnected = false;
this.emit('connectionChange', false);
}); });
socketRef.current.on('error', (err) => { this.socket.on('connect_error', (error) => {
console.error('Socket.IO error:', err); this.reconnectAttempts++;
}); this.emit('connectionChange', false);
if (this.reconnectAttempts >= this.MAX_RECONNECT_ATTEMPTS) {
socketRef.current.on('transcription', (msg: RTCMessage) => { this.emit('error', 'Failed to connect after maximum attempts');
// TODO: Handle transcription update this.disconnect();
});
socketRef.current.on('llm-response', (msg: RTCMessage) => {
// TODO: Handle LLM streaming response
});
socketRef.current.on('tts-chunk', (msg: RTCMessage) => {
if (typeof msg.data === 'string') {
const audio = new Audio(`data:audio/mp3;base64,${msg.data}`);
audio.play().catch(console.error);
} }
}); });
}, [data?.url]);
// WebRTC signals
this.socket.on(WebSocketEvents.CALL_STARTED, () => {
this.emit(WebSocketEvents.CALL_STARTED);
});
this.socket.on(WebSocketEvents.WEBRTC_ANSWER, (answer) => {
this.emit(WebSocketEvents.WEBRTC_ANSWER, answer);
});
this.socket.on(WebSocketEvents.ICE_CANDIDATE, (candidate) => {
this.emit(WebSocketEvents.ICE_CANDIDATE, candidate);
});
this.socket.on('error', (error) => {
this.emit('error', error);
});
}
disconnect() {
if (this.socket) {
this.socket.disconnect();
this.socket = null;
}
this.isConnected = false;
}
sendMessage(type: string, payload?: MessagePayload) {
if (!this.socket || !this.socket.connected) {
return false;
}
this.socket.emit(type, payload);
return true;
}
getConnectionState() {
return this.isConnected;
}
}
export const webSocketManager = new WebSocketManager();
const useWebSocket = () => {
const { data: wsConfig } = useGetWebsocketUrlQuery();
const [isConnected, setIsConnected] = useState(false);
const eventHandlersRef = useRef<Record<string, EventHandler>>({});
useEffect(() => { useEffect(() => {
connect(); if (wsConfig?.url && !webSocketManager.getConnectionState()) {
webSocketManager.connect(wsConfig.url);
const handleConnectionChange = (connected: boolean) => setIsConnected(connected);
webSocketManager.on('connectionChange', handleConnectionChange);
webSocketManager.on('error', console.error);
return () => { return () => {
socketRef.current?.disconnect(); webSocketManager.off('connectionChange', handleConnectionChange);
webSocketManager.off('error', console.error);
}; };
}, [connect]);
const sendMessage = useCallback((message: Record<string, unknown>) => {
if (socketRef.current?.connected) {
socketRef.current.emit('message', message);
} }
}, []); }, [wsConfig, wsConfig?.url]);
return { isConnected, sendMessage }; const sendMessage = (message: { type: string; payload?: MessagePayload }) => {
return webSocketManager.sendMessage(message.type, message.payload);
};
const addEventListener = (event: string, handler: EventHandler) => {
eventHandlersRef.current[event] = handler;
webSocketManager.on(event, handler);
return () => {
webSocketManager.off(event, handler);
delete eventHandlersRef.current[event];
};
};
return {
isConnected,
sendMessage,
addEventListener,
};
}; };
export default useWebSocket; export default useWebSocket;

View file

@ -1,55 +1,288 @@
import type { RTCMessage } from '~/common'; import { EventEmitter } from 'events';
export class WebRTCService { import type { MessagePayload } from '~/common';
/**
 * Connection lifecycle states of `WebRTCService`. Each change is broadcast through
 * the service's 'connectionStateChange' event.
 */
export enum ConnectionState {
// Initial state of a freshly constructed service.
IDLE = 'idle',
// Entered at the start of initialize().
CONNECTING = 'connecting',
CONNECTED = 'connected',
RECONNECTING = 'reconnecting',
FAILED = 'failed',
CLOSED = 'closed',
}
/**
 * State of local media acquisition in `WebRTCService`. Each change is broadcast
 * through the service's 'mediaStateChange' event.
 */
export enum MediaState {
// Initial state of a freshly constructed service.
INACTIVE = 'inactive',
// Entered at the start of initialize(), before the microphone is requested.
PENDING = 'pending',
ACTIVE = 'active',
FAILED = 'failed',
}
/**
 * Optional overrides for `WebRTCService`; the constructor spreads them over the
 * service's built-in defaults.
 */
interface WebRTCConfig {
// ICE servers for the peer connection; defaults to two public Google STUN servers.
iceServers?: RTCIceServer[];
// Default 3.
maxReconnectAttempts?: number;
// Default 15000 — presumably milliseconds; confirm against the timer that consumes it.
connectionTimeout?: number;
// When true, the service's internal log() writes to the console; default false.
debug?: boolean;
}
export class WebRTCService extends EventEmitter {
private peerConnection: RTCPeerConnection | null = null; private peerConnection: RTCPeerConnection | null = null;
private dataChannel: RTCDataChannel | null = null; private localStream: MediaStream | null = null;
private mediaRecorder: MediaRecorder | null = null; private remoteStream: MediaStream | null = null;
private onMessage: (msg: RTCMessage) => void; private reconnectAttempts = 0;
private connectionTimeoutId: NodeJS.Timeout | null = null;
private config: Required<WebRTCConfig>;
private connectionState: ConnectionState = ConnectionState.IDLE;
private mediaState: MediaState = MediaState.INACTIVE;
constructor(onMessage: (msg: RTCMessage) => void) { private readonly DEFAULT_CONFIG: Required<WebRTCConfig> = {
this.onMessage = onMessage; iceServers: [
} {
urls: ['stun:stun.l.google.com:19302', 'stun:stun1.l.google.com:19302'],
async initializeCall() { },
this.peerConnection = new RTCPeerConnection(); ],
this.dataChannel = this.peerConnection.createDataChannel('audio'); maxReconnectAttempts: 3,
connectionTimeout: 15000,
const stream = await navigator.mediaDevices.getUserMedia({ audio: true }); debug: false,
this.mediaRecorder = new MediaRecorder(stream);
this.mediaRecorder.ondataavailable = (e) => {
if (e.data.size > 0 && this.dataChannel?.readyState === 'open') {
e.data.arrayBuffer().then((buffer) => {
this.dataChannel?.send(buffer);
});
}
}; };
this.mediaRecorder.start(100); constructor(
this.setupDataChannel(); private readonly sendMessage: (message: { type: string; payload?: MessagePayload }) => boolean,
config: WebRTCConfig = {},
) {
super();
this.config = { ...this.DEFAULT_CONFIG, ...config };
this.log('WebRTCService initialized with config:', this.config);
} }
private setupDataChannel() { private log(...args: unknown[]) {
if (!this.dataChannel) { if (this.config.debug) {
console.log('[WebRTC]', ...args);
}
}
/**
 * Stores the new connection state and notifies listeners.
 *
 * Emits `connectionStateChange` with the state on every call (there is no
 * check for an unchanged value), then writes a debug log entry.
 *
 * @param state - The connection state to store and broadcast.
 */
private setConnectionState(state: ConnectionState) {
this.connectionState = state;
this.emit('connectionStateChange', state);
this.log('Connection state changed to:', state);
}
/**
 * Stores the new media state and notifies listeners.
 *
 * Emits `mediaStateChange` with the state on every call (there is no check
 * for an unchanged value), then writes a debug log entry.
 *
 * @param state - The media state to store and broadcast.
 */
private setMediaState(state: MediaState) {
this.mediaState = state;
this.emit('mediaStateChange', state);
this.log('Media state changed to:', state);
}
/**
 * Starts a session: requests an audio stream from the user's device, creates
 * the peer connection, attaches the local tracks, and sends the initial offer
 * over the signaling channel.
 *
 * Failures are not thrown to the caller. They are routed to `handleError`,
 * which emits `error` and closes the service.
 */
async initialize() {
  try {
    this.setConnectionState(ConnectionState.CONNECTING);
    this.setMediaState(MediaState.PENDING);

    const stream = await navigator.mediaDevices.getUserMedia({
      audio: {
        echoCancellation: true,
        noiseSuppression: true,
        autoGainControl: true,
      },
    });
    this.localStream = stream;

    const connection = new RTCPeerConnection({
      iceServers: this.config.iceServers,
      iceCandidatePoolSize: 10,
      bundlePolicy: 'max-bundle',
      rtcpMuxPolicy: 'require',
    });
    this.peerConnection = connection;
    this.setupPeerConnectionListeners();

    stream.getTracks().forEach((track) => {
      connection.addTrack(track, stream);
    });

    this.startConnectionTimeout();
    await this.createAndSendOffer();

    // createAndSendOffer reports its own failures through handleError, which
    // closes the service and clears peerConnection. Marking media ACTIVE after
    // that would advertise a live session on a torn-down service, so only do
    // it while this connection is still the current one.
    if (this.peerConnection !== connection) {
      return;
    }
    this.setMediaState(MediaState.ACTIVE);
  } catch (error) {
    this.log('Initialization error:', error);
    this.handleError(error);
  }
}
/**
 * Forwards a signaling message through the injected `sendMessage` callback.
 *
 * When the callback reports that nothing was sent, the failure is passed to
 * `handleError`.
 *
 * @param message - Signaling message with a `type` and an optional `payload`.
 */
private sendSignalingMessage(message: { type: string; payload?: MessagePayload }) {
  if (this.sendMessage(message)) {
    return;
  }
  this.handleError(new Error('Failed to send signaling message - WebSocket not connected'));
}
private setupPeerConnectionListeners() {
if (!this.peerConnection) {
return; return;
} }
this.dataChannel.onmessage = (event) => { this.peerConnection.ontrack = ({ track, streams }) => {
this.onMessage({ this.log('Received remote track:', track.kind);
type: 'audio-chunk', this.remoteStream = streams[0];
data: event.data, this.emit('remoteStream', this.remoteStream);
};
this.peerConnection.onicecandidate = ({ candidate }) => {
if (candidate) {
this.sendSignalingMessage({
type: 'icecandidate',
payload: candidate.toJSON(),
}); });
}
};
this.peerConnection.onconnectionstatechange = () => {
const state = this.peerConnection?.connectionState;
this.log('Connection state changed:', state);
switch (state) {
case 'connected':
this.setConnectionState(ConnectionState.CONNECTED);
this.clearConnectionTimeout();
this.reconnectAttempts = 0;
break;
case 'failed':
if (this.reconnectAttempts < this.config.maxReconnectAttempts) {
this.attemptReconnection();
} else {
this.handleError(new Error('Connection failed after max reconnection attempts'));
}
break;
case 'disconnected':
this.setConnectionState(ConnectionState.RECONNECTING);
this.attemptReconnection();
break;
case 'closed':
this.setConnectionState(ConnectionState.CLOSED);
break;
}
};
this.peerConnection.oniceconnectionstatechange = () => {
this.log('ICE connection state:', this.peerConnection?.iceConnectionState);
}; };
} }
public async sendAudioChunk(audioBlob: Blob) { private async createAndSendOffer() {
if (this.dataChannel && this.dataChannel.readyState === 'open') { if (!this.peerConnection) {
this.dataChannel.send(await audioBlob.arrayBuffer()); return;
}
try {
const offer = await this.peerConnection.createOffer({
offerToReceiveAudio: true,
});
await this.peerConnection.setLocalDescription(offer);
this.sendSignalingMessage({
type: 'webrtc-offer',
payload: offer,
});
} catch (error) {
this.handleError(error);
} }
} }
async endCall() { public async handleAnswer(answer: RTCSessionDescriptionInit) {
this.mediaRecorder?.stop(); if (!this.peerConnection) {
this.dataChannel?.close(); return;
this.peerConnection?.close(); }
try {
await this.peerConnection.setRemoteDescription(new RTCSessionDescription(answer));
this.log('Remote description set successfully');
} catch (error) {
this.handleError(error);
}
}
/**
 * Applies a remote ICE candidate to the peer connection.
 *
 * A candidate that arrives before the remote description is set is held back
 * and applied once the description is in place, rather than being discarded.
 * Failures while adding a candidate are passed to `handleError`.
 *
 * @param candidate - Candidate received from the remote peer via signaling.
 */
public async addIceCandidate(candidate: RTCIceCandidateInit) {
  const connection = this.peerConnection;
  if (!connection) {
    this.log('Ignoring ICE candidate - no peer connection');
    return;
  }

  if (!connection.remoteDescription) {
    this.log('Delaying ICE candidate addition - no remote description');
    // Setting the remote description changes the signaling state, so retry on
    // that event. The listener removes itself once the candidate is applied
    // or once this connection is no longer the active one.
    const retryWhenReady = () => {
      const superseded = this.peerConnection !== connection;
      if (!superseded && !connection.remoteDescription) {
        return;
      }
      connection.removeEventListener('signalingstatechange', retryWhenReady);
      if (!superseded) {
        void this.addIceCandidate(candidate);
      }
    };
    connection.addEventListener('signalingstatechange', retryWhenReady);
    return;
  }

  try {
    await connection.addIceCandidate(new RTCIceCandidate(candidate));
    this.log('ICE candidate added successfully');
  } catch (error) {
    this.handleError(error);
  }
}
/**
 * Arms the connection watchdog, replacing any timer that is already pending.
 *
 * If the service has not reached the CONNECTED state when the configured
 * `connectionTimeout` elapses, a "Connection timeout" error is passed to
 * `handleError`.
 */
private startConnectionTimeout() {
  this.clearConnectionTimeout();

  const onTimeout = () => {
    if (this.connectionState === ConnectionState.CONNECTED) {
      return;
    }
    this.handleError(new Error('Connection timeout'));
  };

  this.connectionTimeoutId = setTimeout(onTimeout, this.config.connectionTimeout);
}
/**
 * Cancels the pending connection watchdog timer, if there is one, and forgets
 * its id.
 */
private clearConnectionTimeout() {
  const pendingTimer = this.connectionTimeoutId;
  if (!pendingTimer) {
    return;
  }
  clearTimeout(pendingTimer);
  this.connectionTimeoutId = null;
}
/**
 * Tries to recover the connection by sending a new offer with an ICE restart.
 *
 * Each attempt increments the retry counter, moves the service to
 * RECONNECTING, and emits `reconnecting` with the attempt number. Once
 * `maxReconnectAttempts` has been used up, the failure is passed to
 * `handleError` instead of retrying again. Errors while creating or applying
 * the offer are also passed to `handleError`.
 */
private async attemptReconnection() {
  const connection = this.peerConnection;
  if (!connection) {
    // Nothing to restart; avoid advertising RECONNECTING for a closed service.
    this.log('Skipping reconnection - no peer connection');
    return;
  }

  // Enforce the limit here so that every caller is bounded, not only the
  // ones that check the counter themselves.
  if (this.reconnectAttempts >= this.config.maxReconnectAttempts) {
    this.handleError(new Error('Connection failed after max reconnection attempts'));
    return;
  }

  this.reconnectAttempts++;
  this.log(
    `Attempting reconnection (${this.reconnectAttempts}/${this.config.maxReconnectAttempts})`,
  );
  this.setConnectionState(ConnectionState.RECONNECTING);
  this.emit('reconnecting', this.reconnectAttempts);

  try {
    const offer = await connection.createOffer({ iceRestart: true });
    // The service may have been closed while the offer was being created;
    // in that case there is nothing left to negotiate.
    if (this.peerConnection !== connection) {
      return;
    }
    await connection.setLocalDescription(offer);
    this.sendSignalingMessage({
      type: 'webrtc-offer',
      payload: offer,
    });
  } catch (error) {
    this.handleError(error);
  }
}
/**
 * Central failure path: logs the error, moves the service to FAILED, notifies
 * `error` listeners with the error message, and closes the service.
 *
 * @param error - The failure to report. Anything that is not an `Error`
 *   instance is reported as "Unknown error occurred".
 */
private handleError(error: Error | unknown) {
  const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred';
  this.log('Error:', errorMessage);

  try {
    this.setConnectionState(ConnectionState.FAILED);
    // An EventEmitter throws when 'error' is emitted with no listener
    // attached. Skip the emit in that case so a missing subscriber does not
    // turn a handled failure into an uncaught exception.
    if (this.listenerCount('error') > 0) {
      this.emit('error', errorMessage);
    }
  } finally {
    // Always release the media stream and peer connection, even if a
    // listener threw while being notified.
    this.close();
  }
}
/**
 * Tears the session down: cancels the connection watchdog, stops the local
 * tracks, drops the remote stream reference, closes the peer connection, and
 * resets the retry counter.
 *
 * Always finishes by moving the service to CLOSED / INACTIVE and emitting the
 * corresponding state-change events, even if nothing was open.
 */
public close() {
  this.clearConnectionTimeout();

  if (this.localStream) {
    this.localStream.getTracks().forEach((track) => track.stop());
    this.localStream = null;
  }

  // Release the remote stream so a closed service does not keep it alive.
  this.remoteStream = null;

  if (this.peerConnection) {
    const connection = this.peerConnection;
    this.peerConnection = null;
    // Detach the handlers first so late callbacks from the old connection
    // cannot reach a service that has already been torn down.
    connection.ontrack = null;
    connection.onicecandidate = null;
    connection.onconnectionstatechange = null;
    connection.oniceconnectionstatechange = null;
    connection.close();
  }

  // Start the next session with a fresh retry budget.
  this.reconnectAttempts = 0;

  this.setConnectionState(ConnectionState.CLOSED);
  this.setMediaState(MediaState.INACTIVE);
}
public getStats(): Promise<RTCStatsReport> | null {
return this.peerConnection?.getStats() ?? null;
} }
} }