diff --git a/api/server/services/WebSocket/WebSocketServer.js b/api/server/services/WebSocket/WebSocketServer.js index 19b1c4e9f8..b5339377c2 100644 --- a/api/server/services/WebSocket/WebSocketServer.js +++ b/api/server/services/WebSocket/WebSocketServer.js @@ -1,157 +1,205 @@ const { Server } = require('socket.io'); -const { RTCPeerConnection } = require('wrtc'); +const { RTCPeerConnection, RTCIceCandidate } = require('wrtc'); -module.exports.SocketIOService = class { - constructor(httpServer) { - this.io = new Server(httpServer, { path: '/socket.io' }); - this.log('Socket.IO Server initialized'); - this.activeClients = new Map(); - this.iceServers = [ - { urls: 'stun:stun.l.google.com:19302' }, - { urls: 'stun:stun1.l.google.com:19302' }, - ]; - this.setupHandlers(); +class WebRTCConnection { + constructor(socket, config) { + this.socket = socket; + this.config = config; + this.peerConnection = null; + this.audioTransceiver = null; + this.pendingCandidates = []; + this.state = 'idle'; + this.log = config.log || console.log; } - log(msg) { - console.log(`[Socket.IO ${new Date().toISOString()}] ${msg}`); - } + async handleOffer(offer) { + try { + // Create new peer connection if needed + if (!this.peerConnection) { + this.peerConnection = new RTCPeerConnection(this.config.rtcConfig); + this.setupPeerConnectionListeners(); + } - setupHandlers() { - this.io.on('connection', (socket) => { - const clientId = socket.id; - this.activeClients.set(clientId, { - socket, - state: 'idle', - audioBuffer: [], - currentTranscription: '', - isProcessing: false, + // Set the remote description (client's offer) + await this.peerConnection.setRemoteDescription(offer); + + // Set up audio transceiver for two-way audio + this.audioTransceiver = this.peerConnection.addTransceiver('audio', { + direction: 'sendrecv', }); - this.log(`Client connected: ${clientId}`); + // Create and set local description (answer) + const answer = await this.peerConnection.createAnswer(); + await this.peerConnection.setLocalDescription(answer); - socket.on('call-start', () => this.handleCallStart(clientId)); - socket.on('audio-chunk', (data) => this.handleAudioChunk(clientId, data)); - socket.on('processing-start', () => this.processAudioStream(clientId)); - socket.on('audio-received', () => this.confirmAudioReceived(clientId)); - socket.on('call-ended', () => this.handleCallEnd(clientId)); + // Send answer to client + this.socket.emit('webrtc-answer', answer); + + // Process any pending ICE candidates + while (this.pendingCandidates.length) { + const candidate = this.pendingCandidates.shift(); + await this.addIceCandidate(candidate); + } + + this.state = 'connecting'; + } catch (error) { + this.log(`Error handling offer: ${error}`, 'error'); + this.socket.emit('error', { message: 'Failed to process offer' }); + this.cleanup(); + } + } + + setupPeerConnectionListeners() { + if (!this.peerConnection) { + return; + } + + // Handle incoming audio tracks + this.peerConnection.ontrack = ({ track, streams }) => { + this.log(`Received ${track.kind} track from client`); + + // For testing: Echo the audio back after a delay + if (track.kind === 'audio') { + this.handleIncomingAudio(track, streams[0]); + } + + track.onended = () => { + this.log(`${track.kind} track ended`); + }; + }; + + this.peerConnection.onicecandidate = ({ candidate }) => { + if (candidate) { + this.socket.emit('icecandidate', candidate); + } + }; + + this.peerConnection.onconnectionstatechange = () => { + if (!this.peerConnection) { + return; + } + const state = this.peerConnection.connectionState; + this.log(`Connection state changed to ${state}`); + this.state = state; + if (state === 'failed' || state === 'closed') { + this.cleanup(); + } + }; + + this.peerConnection.oniceconnectionstatechange = () => { + if (this.peerConnection) { + this.log(`ICE connection state: ${this.peerConnection.iceConnectionState}`); + } + }; + } + + handleIncomingAudio(inputTrack) { + // For testing: Echo back the input track directly + this.peerConnection.addTrack(inputTrack); + + // Log the track info for debugging + this.log(`Audio track added: ${inputTrack.id}, enabled: ${inputTrack.enabled}`); + } + + async addIceCandidate(candidate) { + try { + if (this.peerConnection?.remoteDescription) { + if (candidate && candidate.candidate) { + await this.peerConnection.addIceCandidate(new RTCIceCandidate(candidate)); + } else { + this.log('Invalid ICE candidate', 'warn'); + } + } else { + this.pendingCandidates.push(candidate); + } + } catch (error) { + this.log(`Error adding ICE candidate: ${error}`, 'error'); + } + } + + cleanup() { + if (this.peerConnection) { + try { + this.peerConnection.close(); + } catch (error) { + this.log(`Error closing peer connection: ${error}`, 'error'); + } + this.peerConnection = null; + } + this.audioTransceiver = null; + this.pendingCandidates = []; + this.state = 'idle'; + } +} + +class SocketIOService { + constructor(httpServer, config = {}) { + this.config = { + rtcConfig: { + iceServers: [ + { + urls: ['stun:stun.l.google.com:19302', 'stun:stun1.l.google.com:19302'], + }, + ], + iceCandidatePoolSize: 10, + bundlePolicy: 'max-bundle', + rtcpMuxPolicy: 'require', + }, + ...config, + }; + + this.io = new Server(httpServer, { + path: '/socket.io', + cors: { + origin: '*', + methods: ['GET', 'POST'], + }, + }); + + this.connections = new Map(); + this.setupSocketHandlers(); + } + + log(message, level = 'info') { + const timestamp = new Date().toISOString(); + console.log(`[WebRTC ${timestamp}] [${level.toUpperCase()}] ${message}`); + } + + setupSocketHandlers() { + this.io.on('connection', (socket) => { + this.log(`Client connected: ${socket.id}`); + + // Create a new WebRTC connection for this socket + const rtcConnection = new WebRTCConnection(socket, { + ...this.config, + log: this.log.bind(this), + }); + this.connections.set(socket.id, rtcConnection); + + socket.on('webrtc-offer', (offer) => { + this.log(`Received WebRTC offer from ${socket.id}`); + rtcConnection.handleOffer(offer); + }); + + socket.on('icecandidate', (candidate) => { + rtcConnection.addIceCandidate(candidate); + }); socket.on('disconnect', () => { - this.handleCallEnd(clientId); - this.activeClients.delete(clientId); - this.log(`Client disconnected: ${clientId}`); - }); - - socket.on('error', (error) => { - this.log(`Error for client ${clientId}: ${error.message}`); - this.handleCallEnd(clientId); + this.log(`Client disconnected: ${socket.id}`); + rtcConnection.cleanup(); + this.connections.delete(socket.id); }); }); } - async handleCallStart(clientId) { - const client = this.activeClients.get(clientId); - if (!client) { - return; - } - - try { - client.state = 'active'; - client.audioBuffer = []; - client.currentTranscription = ''; - client.isProcessing = false; - - const peerConnection = new RTCPeerConnection({ - iceServers: this.iceServers, - sdpSemantics: 'unified-plan', - }); - - client.peerConnection = peerConnection; - client.dataChannel = peerConnection.createDataChannel('audio', { - ordered: true, - maxRetransmits: 3, - }); - - client.dataChannel.onopen = () => this.log(`Data channel opened for ${clientId}`); - client.dataChannel.onmessage = async (event) => { - await this.handleAudioChunk(clientId, event.data); - }; - - peerConnection.onicecandidate = (event) => { - if (event.candidate) { - client.socket.emit('ice-candidate', { candidate: event.candidate }); - } - }; - - peerConnection.onnegotiationneeded = async () => { - try { - const offer = await peerConnection.createOffer(); - await peerConnection.setLocalDescription(offer); - client.socket.emit('webrtc-offer', { sdp: peerConnection.localDescription }); - } catch (error) { - this.log(`Negotiation failed for ${clientId}: ${error}`); - } - }; - - this.log(`Call started for client ${clientId}`); - } catch (error) { - this.log(`Error starting call for ${clientId}: ${error.message}`); - this.handleCallEnd(clientId); + shutdown() { + for (const connection of this.connections.values()) { + connection.cleanup(); } + this.connections.clear(); + this.io.close(); } +} - async handleAudioChunk(clientId, data) { - const client = this.activeClients.get(clientId); - if (!client || client.state !== 'active') { - return; - } - - client.audioBuffer.push(data); - client.socket.emit('audio-received'); - } - - async processAudioStream(clientId) { - const client = this.activeClients.get(clientId); - if (!client || client.state !== 'active' || client.isProcessing) { - return; - } - - client.isProcessing = true; - - try { - // Process transcription - client.socket.emit('transcription', { data: 'Processing audio...' }); - - // Stream LLM response - client.socket.emit('llm-response', { data: 'Processing response...' }); - - // Stream TTS chunks - client.socket.emit('tts-chunk', { data: 'audio_data_here' }); - } catch (error) { - this.log(`Processing error for client ${clientId}: ${error.message}`); - } finally { - client.isProcessing = false; - client.audioBuffer = []; - } - } - - confirmAudioReceived(clientId) { - const client = this.activeClients.get(clientId); - if (!client) { - return; - } - - client.socket.emit('audio-received', { data: null }); - } - - handleCallEnd(clientId) { - const client = this.activeClients.get(clientId); - if (!client) { - return; - } - - client.state = 'idle'; - client.audioBuffer = []; - client.currentTranscription = ''; - } -}; +module.exports = { SocketIOService }; diff --git a/client/src/common/types.ts b/client/src/common/types.ts index 97bd105644..6474f9b352 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -55,8 +55,23 @@ export interface RTCMessage { | 'transcription' | 'llm-response' | 'tts-chunk' - | 'call-ended'; - data?: string | ArrayBuffer | null; + | 'call-ended' + | 'webrtc-answer' + | 'icecandidate'; + payload?: RTCSessionDescriptionInit | RTCIceCandidateInit; +} + +export type MessagePayload = + | RTCSessionDescriptionInit + | RTCIceCandidateInit + | Record; + +export enum CallState { + IDLE = 'idle', + CONNECTING = 'connecting', + ACTIVE = 'active', + ERROR = 'error', + ENDED = 'ended', } export type AssistantListItem = { diff --git a/client/src/components/Chat/Input/Call.tsx b/client/src/components/Chat/Input/Call.tsx index 85aedbff1f..15021a6801 100644 --- a/client/src/components/Chat/Input/Call.tsx +++ b/client/src/components/Chat/Input/Call.tsx @@ -1,24 +1,60 @@ -import React from 'react'; +import React, { useEffect, useRef } from 'react'; import { useRecoilState } from 'recoil'; -import { Phone, PhoneOff } from 'lucide-react'; +import { + Phone, + PhoneOff, + AlertCircle, + Mic, + MicOff, + Volume2, + VolumeX, + Activity, +} from 'lucide-react'; import { OGDialog, OGDialogContent, Button } from '~/components'; import { useWebSocket, useCall } from '~/hooks'; +import { CallState } from '~/common'; import store from '~/store'; export const Call: React.FC = () => { - const { isConnected, sendMessage } = useWebSocket(); - const { isCalling, isProcessing, startCall, hangUp } = useCall(); + const { isConnected } = useWebSocket(); + const { + callState, + error, + startCall, + hangUp, + isConnecting, + localStream, + remoteStream, + connectionQuality, + } = useCall(); const [open, setOpen] = useRecoilState(store.callDialogOpen(0)); - const [eventLog, setEventLog] = React.useState([]); + const [isMuted, setIsMuted] = React.useState(false); + const [isAudioEnabled, setIsAudioEnabled] = React.useState(true); + + const remoteAudioRef = useRef(null); const logEvent = (message: string) => { console.log(message); - setEventLog((prev) => [...prev, message]); + setEventLog((prev) => [...prev, `${new Date().toISOString()}: ${message}`]); }; - React.useEffect(() => { + useEffect(() => { + if (remoteAudioRef.current && remoteStream) { + remoteAudioRef.current.srcObject = remoteStream; + } + }, [remoteStream]); + + useEffect(() => { + if (localStream) { + localStream.getAudioTracks().forEach((track) => { + track.enabled = !isMuted; + }); + } + }, [localStream, isMuted]); + + useEffect(() => { if (isConnected) { logEvent('Connected to server.'); } else { @@ -26,15 +62,15 @@ export const Call: React.FC = () => { } }, [isConnected]); - React.useEffect(() => { - if (isCalling) { - logEvent('Call started.'); - } else if (isProcessing) { - logEvent('Processing audio...'); - } else { - logEvent('Call ended.'); + useEffect(() => { + if (error) { + logEvent(`Error: ${error.message} (${error.code})`); } - }, [isCalling, isProcessing]); + }, [error]); + + useEffect(() => { + logEvent(`Call state changed to: ${callState}`); + }, [callState]); const handleStartCall = () => { logEvent('Attempting to start call...'); @@ -46,51 +82,127 @@ export const Call: React.FC = () => { hangUp(); }; + const toggleMute = () => { + setIsMuted((prev) => !prev); + logEvent(`Microphone ${isMuted ? 'unmuted' : 'muted'}`); + }; + + const toggleAudio = () => { + setIsAudioEnabled((prev) => !prev); + if (remoteAudioRef.current) { + remoteAudioRef.current.muted = !isAudioEnabled; + } + logEvent(`Speaker ${isAudioEnabled ? 'disabled' : 'enabled'}`); + }; + + const isActive = callState === CallState.ACTIVE; + const isError = callState === CallState.ERROR; + return ( - +
-
+ {/* Connection Status */} +
- - {isConnected ? 'Connected' : 'Disconnected'} - + className={`flex items-center gap-2 rounded-full px-4 py-2 ${ + isConnected ? 'bg-green-100 text-green-700' : 'bg-red-100 text-red-700' + }`} + > +
+ + {isConnected ? 'Connected' : 'Disconnected'} + +
+ + {isActive && ( +
+ + {connectionQuality} Quality +
+ )}
- {isCalling ? ( - - ) : ( - + {/* Error Display */} + {error && ( +
+ + {error.message} +
)} - {/* Debugging Information */} + {/* Call Controls */} +
+ {isActive && ( + <> + + + + + )} + + {isActive ? ( + + ) : ( + + )} +
+ + {/* Event Log */}

Event Log

-
    - {eventLog.map((log, index) => ( -
  • {log}
  • - ))} -
+
+
    + {eventLog.map((log, index) => ( +
  • + {log} +
  • + ))} +
+
+ + {/* Hidden Audio Element */} +
diff --git a/client/src/hooks/index.ts b/client/src/hooks/index.ts index cdfbc92b8f..f7c98a1b04 100644 --- a/client/src/hooks/index.ts +++ b/client/src/hooks/index.ts @@ -20,7 +20,6 @@ export * from './ScreenshotContext'; export * from './ApiErrorBoundaryContext'; export { default as useCall } from './useCall'; export { default as useToast } from './useToast'; -export { default as useWebRTC } from './useWebRTC'; export { default as useTimeout } from './useTimeout'; export { default as useNewConvo } from './useNewConvo'; export { default as useLocalize } from './useLocalize'; diff --git a/client/src/hooks/useCall.ts b/client/src/hooks/useCall.ts index 7afcfc8263..713c79be2a 100644 --- a/client/src/hooks/useCall.ts +++ b/client/src/hooks/useCall.ts @@ -1,101 +1,220 @@ -import { useState, useRef, useCallback } from 'react'; -import { WebRTCService } from '../services/WebRTC/WebRTCService'; -import type { RTCMessage } from '~/common'; -import useWebSocket from './useWebSocket'; +import { useState, useRef, useCallback, useEffect } from 'react'; +import { WebRTCService, ConnectionState } from '../services/WebRTC/WebRTCService'; +import useWebSocket, { WebSocketEvents } from './useWebSocket'; -const SILENCE_THRESHOLD = -50; -const SILENCE_DURATION = 1000; +interface CallError { + code: string; + message: string; +} + +export enum CallState { + IDLE = 'idle', + CONNECTING = 'connecting', + ACTIVE = 'active', + ERROR = 'error', + ENDED = 'ended', +} + +interface CallStatus { + callState: CallState; + isConnecting: boolean; + error: CallError | null; + localStream: MediaStream | null; + remoteStream: MediaStream | null; + connectionQuality: 'good' | 'poor' | 'unknown'; +} + +const INITIAL_STATUS: CallStatus = { + callState: CallState.IDLE, + isConnecting: false, + error: null, + localStream: null, + remoteStream: null, + connectionQuality: 'unknown', +}; const useCall = () => { - const { sendMessage: wsMessage, isConnected } = useWebSocket(); - const [isCalling, setIsCalling] = useState(false); - const [isProcessing, setIsProcessing] = useState(false); - const audioContextRef = useRef(null); - const analyserRef = useRef(null); - const audioChunksRef = useRef([]); - const silenceStartRef = useRef(null); - const intervalRef = useRef(null); + const { isConnected, sendMessage, addEventListener } = useWebSocket(); + const [status, setStatus] = useState(INITIAL_STATUS); const webrtcServiceRef = useRef(null); + const statsIntervalRef = useRef(); - const sendAudioChunk = useCallback(() => { - if (audioChunksRef.current.length === 0) { - return; - } - - const audioBlob = new Blob(audioChunksRef.current, { type: 'audio/webm' }); - webrtcServiceRef.current?.sendAudioChunk(audioBlob); - wsMessage({ type: 'processing-start' }); - - audioChunksRef.current = []; - setIsProcessing(true); - }, [wsMessage]); - - const handleRTCMessage = useCallback((message: RTCMessage) => { - if (message.type === 'audio-received') { - setIsProcessing(true); - } + const updateStatus = useCallback((updates: Partial) => { + setStatus((prev) => ({ ...prev, ...updates })); }, []); - const startCall = useCallback(async () => { - if (!isConnected) { + useEffect(() => { + return () => { + if (statsIntervalRef.current) { + clearInterval(statsIntervalRef.current); + } + if (webrtcServiceRef.current) { + webrtcServiceRef.current.close(); + } + }; + }, []); + + const handleConnectionStateChange = useCallback( + (state: ConnectionState) => { + switch (state) { + case ConnectionState.CONNECTED: + updateStatus({ + callState: CallState.ACTIVE, + isConnecting: false, + }); + break; + case ConnectionState.CONNECTING: + case ConnectionState.RECONNECTING: + updateStatus({ + callState: CallState.CONNECTING, + isConnecting: true, + }); + break; + case ConnectionState.FAILED: + updateStatus({ + callState: CallState.ERROR, + isConnecting: false, + error: { + code: 'CONNECTION_FAILED', + message: 'Connection failed. Please try again.', + }, + }); + break; + case ConnectionState.CLOSED: + updateStatus({ + callState: CallState.ENDED, + isConnecting: false, + localStream: null, + remoteStream: null, + }); + break; + } + }, + [updateStatus], + ); + + const startConnectionMonitoring = useCallback(() => { + if (!webrtcServiceRef.current) { return; } - webrtcServiceRef.current = new WebRTCService(handleRTCMessage); - await webrtcServiceRef.current.initializeCall(); - - wsMessage({ type: 'call-start' }); - - const stream = await navigator.mediaDevices.getUserMedia({ audio: true }); - audioContextRef.current = new AudioContext(); - const source = audioContextRef.current.createMediaStreamSource(stream); - analyserRef.current = audioContextRef.current.createAnalyser(); - source.connect(analyserRef.current); - - intervalRef.current = window.setInterval(() => { - if (!analyserRef.current || !isCalling) { + statsIntervalRef.current = setInterval(async () => { + const stats = await webrtcServiceRef.current?.getStats(); + if (!stats) { return; } - const data = new Float32Array(analyserRef.current.frequencyBinCount); - analyserRef.current.getFloatFrequencyData(data); - const avg = data.reduce((a, b) => a + b) / data.length; + let totalRoundTripTime = 0; + let samplesCount = 0; - if (avg < SILENCE_THRESHOLD) { - if (silenceStartRef.current === null) { - silenceStartRef.current = Date.now(); - } else if (Date.now() - silenceStartRef.current > SILENCE_DURATION) { - sendAudioChunk(); - silenceStartRef.current = null; + stats.forEach((report) => { + if (report.type === 'candidate-pair' && report.currentRoundTripTime) { + totalRoundTripTime += report.currentRoundTripTime; + samplesCount++; } - } else { - silenceStartRef.current = null; - } - }, 100); + }); - setIsCalling(true); - }, [handleRTCMessage, isConnected, wsMessage, sendAudioChunk, isCalling]); + const averageRTT = samplesCount > 0 ? totalRoundTripTime / samplesCount : 0; + updateStatus({ + connectionQuality: averageRTT < 0.3 ? 'good' : 'poor', + }); + }, 2000); + }, [updateStatus]); - const hangUp = useCallback(async () => { - if (intervalRef.current) { - clearInterval(intervalRef.current); + const startCall = useCallback(async () => { + if (!isConnected) { + updateStatus({ + callState: CallState.ERROR, + error: { + code: 'NOT_CONNECTED', + message: 'Not connected to server', + }, + }); + return; } - analyserRef.current = null; - audioContextRef.current?.close(); - audioContextRef.current = null; + try { + if (webrtcServiceRef.current) { + webrtcServiceRef.current.close(); + } - await webrtcServiceRef.current?.endCall(); - webrtcServiceRef.current = null; + updateStatus({ + callState: CallState.CONNECTING, + isConnecting: true, + error: null, + }); - setIsCalling(false); - setIsProcessing(false); - wsMessage({ type: 'call-ended' }); - }, [wsMessage]); + // TODO: Remove debug or make it configurable + webrtcServiceRef.current = new WebRTCService((message) => sendMessage(message), { + debug: true, + }); + + webrtcServiceRef.current.on('connectionStateChange', handleConnectionStateChange); + + webrtcServiceRef.current.on('remoteStream', (stream: MediaStream) => { + updateStatus({ remoteStream: stream }); + }); + + webrtcServiceRef.current.on('error', (error: string) => { + updateStatus({ + callState: CallState.ERROR, + isConnecting: false, + error: { + code: 'WEBRTC_ERROR', + message: error, + }, + }); + }); + + await webrtcServiceRef.current.initialize(); + startConnectionMonitoring(); + } catch (error) { + updateStatus({ + callState: CallState.ERROR, + isConnecting: false, + error: { + code: 'INITIALIZATION_FAILED', + message: error instanceof Error ? error.message : 'Failed to start call', + }, + }); + } + }, [ + isConnected, + sendMessage, + handleConnectionStateChange, + startConnectionMonitoring, + updateStatus, + ]); + + const hangUp = useCallback(() => { + if (webrtcServiceRef.current) { + webrtcServiceRef.current.close(); + webrtcServiceRef.current = null; + } + if (statsIntervalRef.current) { + clearInterval(statsIntervalRef.current); + } + updateStatus({ + ...INITIAL_STATUS, + callState: CallState.ENDED, + }); + }, [updateStatus]); + + useEffect(() => { + const cleanupFns = [ + addEventListener(WebSocketEvents.WEBRTC_ANSWER, (answer: RTCSessionDescriptionInit) => { + webrtcServiceRef.current?.handleAnswer(answer); + }), + addEventListener(WebSocketEvents.ICE_CANDIDATE, (candidate: RTCIceCandidateInit) => { + webrtcServiceRef.current?.addIceCandidate(candidate); + }), + ]; + + return () => cleanupFns.forEach((fn) => fn()); + }, [addEventListener]); return { - isCalling, - isProcessing, + ...status, startCall, hangUp, }; diff --git a/client/src/hooks/useWebRTC.ts b/client/src/hooks/useWebRTC.ts deleted file mode 100644 index 0e951a531b..0000000000 --- a/client/src/hooks/useWebRTC.ts +++ /dev/null @@ -1,46 +0,0 @@ -import { useRef, useCallback } from 'react'; -import { WebRTCService } from '../services/WebRTC/WebRTCService'; -import type { RTCMessage } from '~/common'; -import useWebSocket from './useWebSocket'; - -const useWebRTC = () => { - const { sendMessage } = useWebSocket(); - const webrtcServiceRef = useRef(null); - - const handleRTCMessage = useCallback( - (message: RTCMessage) => { - switch (message.type) { - case 'audio-chunk': - sendMessage({ type: 'processing-start' }); - break; - case 'transcription': - case 'llm-response': - case 'tts-chunk': - // TODO: Handle streaming responses - break; - } - }, - [sendMessage], - ); - - const startLocalStream = async () => { - try { - webrtcServiceRef.current = new WebRTCService(handleRTCMessage); - await webrtcServiceRef.current.initializeCall(); - sendMessage({ type: 'call-start' }); - } catch (error) { - console.error(error); - throw error; - } - }; - - const stopLocalStream = useCallback(() => { - webrtcServiceRef.current?.endCall(); - webrtcServiceRef.current = null; - sendMessage({ type: 'call-ended' }); - }, [sendMessage]); - - return { startLocalStream, stopLocalStream }; -}; - -export default useWebRTC; diff --git a/client/src/hooks/useWebSocket.ts b/client/src/hooks/useWebSocket.ts index 804a8506c7..cda4d383ec 100644 --- a/client/src/hooks/useWebSocket.ts +++ b/client/src/hooks/useWebSocket.ts @@ -1,62 +1,140 @@ -import { useEffect, useRef, useState, useCallback } from 'react'; +import { useEffect, useRef, useState } from 'react'; import { useGetWebsocketUrlQuery } from 'librechat-data-provider/react-query'; +import type { MessagePayload } from '~/common'; import { io, Socket } from 'socket.io-client'; -import type { RTCMessage } from '~/common'; +import { EventEmitter } from 'events'; -const useWebSocket = () => { - const { data } = useGetWebsocketUrlQuery(); - const [isConnected, setIsConnected] = useState(false); - const socketRef = useRef(null); +export const WebSocketEvents = { + CALL_STARTED: 'call-started', + CALL_ERROR: 'call-error', + WEBRTC_ANSWER: 'webrtc-answer', + ICE_CANDIDATE: 'icecandidate', +} as const; - const connect = useCallback(() => { - if (!data || !data.url) { +type EventHandler = (...args: unknown[]) => void; + +class WebSocketManager extends EventEmitter { + private socket: Socket | null = null; + private reconnectAttempts = 0; + private readonly MAX_RECONNECT_ATTEMPTS = 5; + private isConnected = false; + + connect(url: string) { + if (this.socket && this.socket.connected) { + return; + } + this.socket = io(url, { + transports: ['websocket'], + reconnectionAttempts: this.MAX_RECONNECT_ATTEMPTS, + timeout: 10000, + }); + this.setupEventHandlers(); + } + + private setupEventHandlers() { + if (!this.socket) { return; } - socketRef.current = io(data.url, { transports: ['websocket'] }); - - socketRef.current.on('connect', () => { - setIsConnected(true); + this.socket.on('connect', () => { + this.isConnected = true; + this.reconnectAttempts = 0; + this.emit('connectionChange', true); }); - socketRef.current.on('disconnect', () => { - setIsConnected(false); + this.socket.on('disconnect', (reason) => { + this.isConnected = false; + this.emit('connectionChange', false); }); - socketRef.current.on('error', (err) => { - console.error('Socket.IO error:', err); - }); - - socketRef.current.on('transcription', (msg: RTCMessage) => { - // TODO: Handle transcription update - }); - - socketRef.current.on('llm-response', (msg: RTCMessage) => { - // TODO: Handle LLM streaming response - }); - - socketRef.current.on('tts-chunk', (msg: RTCMessage) => { - if (typeof msg.data === 'string') { - const audio = new Audio(`data:audio/mp3;base64,${msg.data}`); - audio.play().catch(console.error); + this.socket.on('connect_error', (error) => { + this.reconnectAttempts++; + this.emit('connectionChange', false); + if (this.reconnectAttempts >= this.MAX_RECONNECT_ATTEMPTS) { + this.emit('error', 'Failed to connect after maximum attempts'); + this.disconnect(); } }); - }, [data?.url]); + + // WebRTC signals + this.socket.on(WebSocketEvents.CALL_STARTED, () => { + this.emit(WebSocketEvents.CALL_STARTED); + }); + + this.socket.on(WebSocketEvents.WEBRTC_ANSWER, (answer) => { + this.emit(WebSocketEvents.WEBRTC_ANSWER, answer); + }); + + this.socket.on(WebSocketEvents.ICE_CANDIDATE, (candidate) => { + this.emit(WebSocketEvents.ICE_CANDIDATE, candidate); + }); + + this.socket.on('error', (error) => { + this.emit('error', error); + }); + } + + disconnect() { + if (this.socket) { + this.socket.disconnect(); + this.socket = null; + } + this.isConnected = false; + } + + sendMessage(type: string, payload?: MessagePayload) { + if (!this.socket || !this.socket.connected) { + return false; + } + this.socket.emit(type, payload); + return true; + } + + getConnectionState() { + return this.isConnected; + } +} + +export const webSocketManager = new WebSocketManager(); + +const useWebSocket = () => { + const { data: wsConfig } = useGetWebsocketUrlQuery(); + const [isConnected, setIsConnected] = useState(false); + const eventHandlersRef = useRef>({}); useEffect(() => { - connect(); - return () => { - socketRef.current?.disconnect(); - }; - }, [connect]); + if (wsConfig?.url && !webSocketManager.getConnectionState()) { + webSocketManager.connect(wsConfig.url); - const sendMessage = useCallback((message: Record) => { - if (socketRef.current?.connected) { - socketRef.current.emit('message', message); + const handleConnectionChange = (connected: boolean) => setIsConnected(connected); + webSocketManager.on('connectionChange', handleConnectionChange); + webSocketManager.on('error', console.error); + + return () => { + webSocketManager.off('connectionChange', handleConnectionChange); + webSocketManager.off('error', console.error); + }; } - }, []); + }, [wsConfig, wsConfig?.url]); - return { isConnected, sendMessage }; + const sendMessage = (message: { type: string; payload?: MessagePayload }) => { + return webSocketManager.sendMessage(message.type, message.payload); + }; + + const addEventListener = (event: string, handler: EventHandler) => { + eventHandlersRef.current[event] = handler; + webSocketManager.on(event, handler); + return () => { + webSocketManager.off(event, handler); + delete eventHandlersRef.current[event]; + }; + }; + + return { + isConnected, + sendMessage, + addEventListener, + }; }; export default useWebSocket; diff --git a/client/src/services/WebRTC/WebRTCService.ts b/client/src/services/WebRTC/WebRTCService.ts index 9f0564d078..42e6a294d9 100644 --- a/client/src/services/WebRTC/WebRTCService.ts +++ b/client/src/services/WebRTC/WebRTCService.ts @@ -1,55 +1,288 @@ -import type { RTCMessage } from '~/common'; -export class WebRTCService { - private peerConnection: RTCPeerConnection | null = null; - private dataChannel: RTCDataChannel | null = null; - private mediaRecorder: MediaRecorder | null = null; - private onMessage: (msg: RTCMessage) => void; +import { EventEmitter } from 'events'; +import type { MessagePayload } from '~/common'; - constructor(onMessage: (msg: RTCMessage) => void) { - this.onMessage = onMessage; +export enum ConnectionState { + IDLE = 'idle', + CONNECTING = 'connecting', + CONNECTED = 'connected', + RECONNECTING = 'reconnecting', + FAILED = 'failed', + CLOSED = 'closed', +} + +export enum MediaState { + INACTIVE = 'inactive', + PENDING = 'pending', + ACTIVE = 'active', + FAILED = 'failed', +} + +interface WebRTCConfig { + iceServers?: RTCIceServer[]; + maxReconnectAttempts?: number; + connectionTimeout?: number; + debug?: boolean; +} + +export class WebRTCService extends EventEmitter { + private peerConnection: RTCPeerConnection | null = null; + private localStream: MediaStream | null = null; + private remoteStream: MediaStream | null = null; + private reconnectAttempts = 0; + private connectionTimeoutId: NodeJS.Timeout | null = null; + private config: Required; + private connectionState: ConnectionState = ConnectionState.IDLE; + private mediaState: MediaState = MediaState.INACTIVE; + + private readonly DEFAULT_CONFIG: Required = { + iceServers: [ + { + urls: ['stun:stun.l.google.com:19302', 'stun:stun1.l.google.com:19302'], + }, + ], + maxReconnectAttempts: 3, + connectionTimeout: 15000, + debug: false, + }; + + constructor( + private readonly sendMessage: (message: { type: string; payload?: MessagePayload }) => boolean, + config: WebRTCConfig = {}, + ) { + super(); + this.config = { ...this.DEFAULT_CONFIG, ...config }; + this.log('WebRTCService initialized with config:', this.config); } - async initializeCall() { - this.peerConnection = new RTCPeerConnection(); - this.dataChannel = this.peerConnection.createDataChannel('audio'); + private log(...args: unknown[]) { + if (this.config.debug) { + console.log('[WebRTC]', ...args); + } + } - const stream = await navigator.mediaDevices.getUserMedia({ audio: true }); - this.mediaRecorder = new MediaRecorder(stream); + private setConnectionState(state: ConnectionState) { + this.connectionState = state; + this.emit('connectionStateChange', state); + this.log('Connection state changed to:', state); + } - this.mediaRecorder.ondataavailable = (e) => { - if (e.data.size > 0 && this.dataChannel?.readyState === 'open') { - e.data.arrayBuffer().then((buffer) => { - this.dataChannel?.send(buffer); + private setMediaState(state: MediaState) { + this.mediaState = state; + this.emit('mediaStateChange', state); + this.log('Media state changed to:', state); + } + + async initialize() { + try { + this.setConnectionState(ConnectionState.CONNECTING); + this.setMediaState(MediaState.PENDING); + + this.localStream = await navigator.mediaDevices.getUserMedia({ + audio: { + echoCancellation: true, + noiseSuppression: true, + autoGainControl: true, + }, + }); + + this.peerConnection = new RTCPeerConnection({ + iceServers: this.config.iceServers, + iceCandidatePoolSize: 10, + bundlePolicy: 'max-bundle', + rtcpMuxPolicy: 'require', + }); + + this.setupPeerConnectionListeners(); + + this.localStream.getTracks().forEach((track) => { + if (this.localStream && this.peerConnection) { + this.peerConnection.addTrack(track, this.localStream); + } + }); + + this.startConnectionTimeout(); + + await this.createAndSendOffer(); + + this.setMediaState(MediaState.ACTIVE); + } catch (error) { + this.log('Initialization error:', error); + this.handleError(error); + } + } + + private sendSignalingMessage(message: { type: string; payload?: MessagePayload }) { + const sent = this.sendMessage(message); + if (!sent) { + this.handleError(new Error('Failed to send signaling message - WebSocket not connected')); + } + } + + private setupPeerConnectionListeners() { + if (!this.peerConnection) { + return; + } + + this.peerConnection.ontrack = ({ track, streams }) => { + this.log('Received remote track:', track.kind); + this.remoteStream = streams[0]; + this.emit('remoteStream', this.remoteStream); + }; + + this.peerConnection.onicecandidate = ({ candidate }) => { + if (candidate) { + this.sendSignalingMessage({ + type: 'icecandidate', + payload: candidate.toJSON(), }); } }; - this.mediaRecorder.start(100); - this.setupDataChannel(); - } + this.peerConnection.onconnectionstatechange = () => { + const state = this.peerConnection?.connectionState; + this.log('Connection state changed:', state); - private setupDataChannel() { - if (!this.dataChannel) { - return; - } + switch (state) { + case 'connected': + this.setConnectionState(ConnectionState.CONNECTED); + this.clearConnectionTimeout(); + this.reconnectAttempts = 0; + break; + case 'failed': + if (this.reconnectAttempts < this.config.maxReconnectAttempts) { + this.attemptReconnection(); + } else { + this.handleError(new Error('Connection failed after max reconnection attempts')); + } + break; + case 'disconnected': + this.setConnectionState(ConnectionState.RECONNECTING); + this.attemptReconnection(); + break; + case 'closed': + this.setConnectionState(ConnectionState.CLOSED); + break; + } + }; - this.dataChannel.onmessage = (event) => { - this.onMessage({ - type: 'audio-chunk', - data: event.data, - }); + this.peerConnection.oniceconnectionstatechange = () => { + this.log('ICE connection state:', this.peerConnection?.iceConnectionState); }; } - public async sendAudioChunk(audioBlob: Blob) { - if (this.dataChannel && this.dataChannel.readyState === 'open') { - this.dataChannel.send(await audioBlob.arrayBuffer()); + private async createAndSendOffer() { + if (!this.peerConnection) { + return; + } + + try { + const offer = await this.peerConnection.createOffer({ + offerToReceiveAudio: true, + }); + + await this.peerConnection.setLocalDescription(offer); + + this.sendSignalingMessage({ + type: 'webrtc-offer', + payload: offer, + }); + } catch (error) { + this.handleError(error); } } - async endCall() { - this.mediaRecorder?.stop(); - this.dataChannel?.close(); - this.peerConnection?.close(); + public async handleAnswer(answer: RTCSessionDescriptionInit) { + if (!this.peerConnection) { + return; + } + + try { + await this.peerConnection.setRemoteDescription(new RTCSessionDescription(answer)); + this.log('Remote description set successfully'); + } catch (error) { + this.handleError(error); + } + } + + public async addIceCandidate(candidate: RTCIceCandidateInit) { + if (!this.peerConnection?.remoteDescription) { + this.log('Delaying ICE candidate addition - no remote description'); + return; + } + + try { + await this.peerConnection.addIceCandidate(new RTCIceCandidate(candidate)); + this.log('ICE candidate added successfully'); + } catch (error) { + this.handleError(error); + } + } + + private startConnectionTimeout() { + this.clearConnectionTimeout(); + this.connectionTimeoutId = setTimeout(() => { + if (this.connectionState !== ConnectionState.CONNECTED) { + this.handleError(new Error('Connection timeout')); + } + }, this.config.connectionTimeout); + } + + private clearConnectionTimeout() { + if (this.connectionTimeoutId) { + clearTimeout(this.connectionTimeoutId); + this.connectionTimeoutId = null; + } + } + + private async attemptReconnection() { + this.reconnectAttempts++; + this.log( + `Attempting reconnection (${this.reconnectAttempts}/${this.config.maxReconnectAttempts})`, + ); + + this.setConnectionState(ConnectionState.RECONNECTING); + this.emit('reconnecting', this.reconnectAttempts); + + try { + if (this.peerConnection) { + const offer = await this.peerConnection.createOffer({ iceRestart: true }); + await this.peerConnection.setLocalDescription(offer); + this.sendSignalingMessage({ + type: 'webrtc-offer', + payload: offer, + }); + } + } catch (error) { + this.handleError(error); + } + } + + private handleError(error: Error | unknown) { + const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'; + this.log('Error:', errorMessage); + this.setConnectionState(ConnectionState.FAILED); + this.emit('error', errorMessage); + this.close(); + } + + public close() { + this.clearConnectionTimeout(); + + if (this.localStream) { + this.localStream.getTracks().forEach((track) => track.stop()); + this.localStream = null; + } + + if (this.peerConnection) { + this.peerConnection.close(); + this.peerConnection = null; + } + + this.setConnectionState(ConnectionState.CLOSED); + this.setMediaState(MediaState.INACTIVE); + } + + public getStats(): Promise | null { + return this.peerConnection?.getStats() ?? null; } }