From 2eda62cf6788302cad63b2b8216de868b5a424e9 Mon Sep 17 00:00:00 2001 From: Marco Beretta <81851188+berry-13@users.noreply.github.com> Date: Sat, 5 Apr 2025 10:37:53 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20implement=20AudioSocketModu?= =?UTF-8?q?le=20and=20WebRTCHandler=20for=20audio=20streaming;=20refactor?= =?UTF-8?q?=20SocketIOService=20to=20support=20module-based=20event=20hand?= =?UTF-8?q?ling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/server/index.js | 22 +- .../services/Files/Audio/AudioSocketModule.js | 40 ++++ .../services/Files/Audio/WebRTCHandler.js | 179 +++++++++++++++ .../services/WebSocket/WebSocketServer.js | 207 +++++------------- 4 files changed, 292 insertions(+), 156 deletions(-) create mode 100644 api/server/services/Files/Audio/AudioSocketModule.js create mode 100644 api/server/services/Files/Audio/WebRTCHandler.js diff --git a/api/server/index.js b/api/server/index.js index 1084b0b9dc..a087ba1ba5 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -15,6 +15,7 @@ const { connectDb, indexSync } = require('~/lib/db'); const { isEnabled } = require('~/server/utils'); const { ldapLogin } = require('~/strategies'); const { logger } = require('~/config'); +const { AudioSocketModule } = require('./services/Files/Audio/AudioSocketModule'); const { SocketIOService } = require('./services/WebSocket/WebSocketServer'); const validateImageRequest = require('./middleware/validateImageRequest'); const errorController = require('./controllers/ErrorController'); @@ -30,6 +31,9 @@ const port = Number(PORT) || 3080; const host = HOST || 'localhost'; const trusted_proxy = Number(TRUST_PROXY) || 1; /* trust first proxy by default */ +let socketIOService; +let audioModule; + const startServer = async () => { if (typeof Bun !== 'undefined') { axios.defaults.headers.common['Accept-Encoding'] = 'gzip'; @@ -49,7 +53,10 @@ const startServer = async () => { }), ); - new SocketIOService(server); + socketIOService = new SocketIOService(server); + audioModule = new AudioSocketModule(socketIOService); + + logger.info('WebSocket server and Audio module initialized'); await AppService(app); @@ -156,6 +163,19 @@ const startServer = async () => { startServer(); +process.on('SIGINT', () => { + logger.info('Shutting down server...'); + if (audioModule) { + audioModule.cleanup(); + logger.info('Audio module cleaned up'); + } + if (socketIOService) { + socketIOService.shutdown(); + logger.info('WebSocket server shut down'); + } + process.exit(0); +}); + let messageCount = 0; process.on('uncaughtException', (err) => { if (!err.message.includes('fetch failed')) { diff --git a/api/server/services/Files/Audio/AudioSocketModule.js b/api/server/services/Files/Audio/AudioSocketModule.js new file mode 100644 index 0000000000..527b22ed07 --- /dev/null +++ b/api/server/services/Files/Audio/AudioSocketModule.js @@ -0,0 +1,40 @@ +const { AudioHandler } = require('./WebRTCHandler'); +const { logger } = require('~/config'); + +class AudioSocketModule { + constructor(socketIOService) { + this.socketIOService = socketIOService; + this.audioHandler = new AudioHandler(); + + this.moduleId = 'audio-handler'; + this.registerHandlers(); + } + + registerHandlers() { + this.socketIOService.registerModule(this.moduleId, { + connection: (socket) => this.handleConnection(socket), + disconnect: (socket) => this.handleDisconnect(socket), + }); + } + + handleConnection(socket) { + // Register WebRTC-specific event handlers for this socket + this.audioHandler.registerSocketHandlers(socket, this.config); + + logger.debug(`Audio handler registered for client: ${socket.id}`); + } + + handleDisconnect(socket) { + // Cleanup audio resources for disconnected client + this.audioHandler.cleanup(socket.id); + logger.debug(`Audio handler cleaned up for client: ${socket.id}`); + } + + // Used for app shutdown + cleanup() { + this.audioHandler.cleanupAll(); + this.socketIOService.unregisterModule(this.moduleId); + } +} + +module.exports = { AudioSocketModule }; diff --git a/api/server/services/Files/Audio/WebRTCHandler.js b/api/server/services/Files/Audio/WebRTCHandler.js new file mode 100644 index 0000000000..cc46998e03 --- /dev/null +++ b/api/server/services/Files/Audio/WebRTCHandler.js @@ -0,0 +1,179 @@ +const { RTCPeerConnection, RTCIceCandidate, MediaStream } = require('wrtc'); +const { logger } = require('~/config'); + +class WebRTCConnection { + constructor(socket, config) { + this.socket = socket; + this.config = config; + this.peerConnection = null; + this.audioTransceiver = null; + this.pendingCandidates = []; + this.state = 'idle'; + } + + async handleOffer(offer) { + try { + if (!this.peerConnection) { + this.peerConnection = new RTCPeerConnection(this.config.rtcConfig); + this.setupPeerConnectionListeners(); + } + + await this.peerConnection.setRemoteDescription(offer); + + const mediaStream = new MediaStream(); + + this.audioTransceiver = this.peerConnection.addTransceiver('audio', { + direction: 'sendrecv', + streams: [mediaStream], + }); + + const answer = await this.peerConnection.createAnswer(); + await this.peerConnection.setLocalDescription(answer); + this.socket.emit('webrtc-answer', answer); + } catch (error) { + logger.error(`Error handling offer: ${error}`); + this.socket.emit('webrtc-error', { + message: error.message, + code: 'OFFER_ERROR', + }); + } + } + + setupPeerConnectionListeners() { + if (!this.peerConnection) { + return; + } + + this.peerConnection.ontrack = ({ track }) => { + logger.info(`Received ${track.kind} track from client`); + + if (track.kind === 'audio') { + this.handleIncomingAudio(track); + } + + track.onended = () => { + logger.info(`${track.kind} track ended`); + }; + }; + + this.peerConnection.onicecandidate = ({ candidate }) => { + if (candidate) { + this.socket.emit('icecandidate', candidate); + } + }; + + this.peerConnection.onconnectionstatechange = () => { + if (!this.peerConnection) { + return; + } + + const state = this.peerConnection.connectionState; + logger.info(`Connection state changed to ${state}`); + this.state = state; + + if (state === 'failed' || state === 'closed') { + this.cleanup(); + } + }; + } + + handleIncomingAudio(track) { + if (this.peerConnection) { + const stream = new MediaStream([track]); + this.peerConnection.addTrack(track, stream); + } + } + + async addIceCandidate(candidate) { + try { + if (this.peerConnection?.remoteDescription) { + if (candidate && candidate.candidate) { + await this.peerConnection.addIceCandidate(new RTCIceCandidate(candidate)); + } else { + logger.warn('Invalid ICE candidate'); + } + } else { + this.pendingCandidates.push(candidate); + } + } catch (error) { + logger.error(`Error adding ICE candidate: ${error}`); + } + } + + cleanup() { + if (this.peerConnection) { + try { + this.peerConnection.close(); + } catch (error) { + logger.error(`Error closing peer connection: ${error}`); + } + this.peerConnection = null; + } + + this.audioTransceiver = null; + this.pendingCandidates = []; + this.state = 'idle'; + } +} + +class AudioHandler { + constructor() { + this.connections = new Map(); + this.defaultRTCConfig = { + iceServers: [ + { + urls: ['stun:stun.l.google.com:19302', 'stun:stun1.l.google.com:19302'], + }, + ], + iceCandidatePoolSize: 10, + bundlePolicy: 'max-bundle', + rtcpMuxPolicy: 'require', + }; + } + + registerSocketHandlers(socket) { + const rtcConfig = { + rtcConfig: this.defaultRTCConfig, + }; + + const rtcConnection = new WebRTCConnection(socket, rtcConfig); + this.connections.set(socket.id, rtcConnection); + + socket.on('webrtc-offer', (offer) => { + logger.debug(`Received WebRTC offer from ${socket.id}`); + rtcConnection.handleOffer(offer); + }); + + socket.on('icecandidate', (candidate) => { + rtcConnection.addIceCandidate(candidate); + }); + + socket.on('vad-status', (status) => { + logger.debug(`VAD status from ${socket.id}: ${JSON.stringify(status)}`); + }); + + socket.on('disconnect', () => { + rtcConnection.cleanup(); + this.connections.delete(socket.id); + }); + + return rtcConnection; + } + + cleanup(socketId) { + const connection = this.connections.get(socketId); + if (connection) { + connection.cleanup(); + this.connections.delete(socketId); + } + } + + cleanupAll() { + for (const connection of this.connections.values()) { + connection.cleanup(); + } + this.connections.clear(); + } +} + +module.exports = { AudioHandler, WebRTCConnection }; diff --git a/api/server/services/WebSocket/WebSocketServer.js b/api/server/services/WebSocket/WebSocketServer.js index 816f77e355..7a0d348566 100644 --- a/api/server/services/WebSocket/WebSocketServer.js +++ b/api/server/services/WebSocket/WebSocketServer.js @@ -1,138 +1,8 @@ const { Server } = require('socket.io'); -const { RTCPeerConnection, RTCIceCandidate, MediaStream } = require('wrtc'); - -class WebRTCConnection { - constructor(socket, config) { - this.socket = socket; - this.config = config; - this.peerConnection = null; - this.audioTransceiver = null; - this.pendingCandidates = []; - this.state = 'idle'; - this.log = config.log || console.log; - } - - async handleOffer(offer) { - try { - if (!this.peerConnection) { - this.peerConnection = new RTCPeerConnection(this.config.rtcConfig); - this.setupPeerConnectionListeners(); - } - - await this.peerConnection.setRemoteDescription(offer); - - const mediaStream = new MediaStream(); - - this.audioTransceiver = this.peerConnection.addTransceiver('audio', { - direction: 'sendrecv', - streams: [mediaStream], - }); - - const answer = await this.peerConnection.createAnswer(); - await this.peerConnection.setLocalDescription(answer); - this.socket.emit('webrtc-answer', answer); - } catch (error) { - this.log(`Error handling offer: ${error}`, 'error'); - this.socket.emit('webrtc-error', { - message: error.message, - code: 'OFFER_ERROR', - }); - } - } - - setupPeerConnectionListeners() { - if (!this.peerConnection) { - return; - } - - this.peerConnection.ontrack = ({ track }) => { - this.log(`Received ${track.kind} track from client`); - - if (track.kind === 'audio') { - this.handleIncomingAudio(track); - } - - track.onended = () => { - this.log(`${track.kind} track ended`); - }; - }; - - this.peerConnection.onicecandidate = ({ candidate }) => { - if (candidate) { - this.socket.emit('icecandidate', candidate); - } - }; - - this.peerConnection.onconnectionstatechange = () => { - if (!this.peerConnection) { - return; - } - - const state = this.peerConnection.connectionState; - this.log(`Connection state changed to ${state}`); - this.state = state; - - if (state === 'failed' || state === 'closed') { - this.cleanup(); - } - }; - } - - handleIncomingAudio(track) { - if (this.peerConnection) { - const stream = new MediaStream([track]); - this.peerConnection.addTrack(track, stream); - } - } - - async addIceCandidate(candidate) { - try { - if (this.peerConnection?.remoteDescription) { - if (candidate && candidate.candidate) { - await this.peerConnection.addIceCandidate(new RTCIceCandidate(candidate)); - } else { - this.log('Invalid ICE candidate', 'warn'); - } - } else { - this.pendingCandidates.push(candidate); - } - } catch (error) { - this.log(`Error adding ICE candidate: ${error}`, 'error'); - } - } - - cleanup() { - if (this.peerConnection) { - try { - this.peerConnection.close(); - } catch (error) { - this.log(`Error closing peer connection: ${error}`, 'error'); - } - this.peerConnection = null; - } - - this.audioTransceiver = null; - this.pendingCandidates = []; - this.state = 'idle'; - } -} +const { logger } = require('~/config'); class SocketIOService { - constructor(httpServer, config = {}) { - this.config = { - rtcConfig: { - iceServers: [ - { - urls: ['stun:stun.l.google.com:19302', 'stun:stun1.l.google.com:19302'], - }, - ], - iceCandidatePoolSize: 10, - bundlePolicy: 'max-bundle', - rtcpMuxPolicy: 'require', - }, - ...config, - }; - + constructor(httpServer) { this.io = new Server(httpServer, { path: '/socket.io', cors: { @@ -142,50 +12,77 @@ class SocketIOService { }); this.connections = new Map(); + this.eventHandlers = new Map(); this.setupSocketHandlers(); } setupSocketHandlers() { this.io.on('connection', (socket) => { this.log(`Client connected: ${socket.id}`); + this.connections.set(socket.id, socket); - const rtcConnection = new WebRTCConnection(socket, { - ...this.config, - log: this.log.bind(this), - }); - this.connections.set(socket.id, rtcConnection); - - socket.on('webrtc-offer', (offer) => { - this.log(`Received WebRTC offer from ${socket.id}`); - rtcConnection.handleOffer(offer); - }); - - socket.on('icecandidate', (candidate) => { - rtcConnection.addIceCandidate(candidate); - }); - - socket.on('vad-status', (status) => { - this.log(`VAD status from ${socket.id}: ${JSON.stringify(status)}`); - }); + // Emit connection event for modules to handle + this.emitEvent('connection', socket); socket.on('disconnect', () => { this.log(`Client disconnected: ${socket.id}`); - rtcConnection.cleanup(); + this.emitEvent('disconnect', socket); this.connections.delete(socket.id); }); }); } + // Register a module to handle specific events + registerModule(moduleId, eventHandlers) { + for (const [eventName, handler] of Object.entries(eventHandlers)) { + if (!this.eventHandlers.has(eventName)) { + this.eventHandlers.set(eventName, new Map()); + } + + this.eventHandlers.get(eventName).set(moduleId, handler); + + // If this is a socket event, register it on all existing connections + if (eventName !== 'connection' && eventName !== 'disconnect') { + for (const socket of this.connections.values()) { + socket.on(eventName, (...args) => { + handler(socket, ...args); + }); + } + } + } + } + + // Unregister a module + unregisterModule(moduleId) { + for (const handlers of this.eventHandlers.values()) { + handlers.delete(moduleId); + } + } + + // Emit an event to all registered handlers + emitEvent(eventName, ...args) { + const handlers = this.eventHandlers.get(eventName); + if (handlers) { + for (const handler of handlers.values()) { + handler(...args); + } + } + } + log(message, level = 'info') { const timestamp = new Date().toISOString(); - console.log(`[WebRTC ${timestamp}] [${level.toUpperCase()}] ${message}`); + + try { + logger.debug(`[WebSocket] ${message}`, level); + } catch (error) { + console.log(`[WebSocket ${timestamp}] [${level.toUpperCase()}] ${message}`); + console.error(`[WebSocket ${timestamp}] [ERROR] Error while logging: ${error.message}`); + } } shutdown() { - for (const connection of this.connections.values()) { - connection.cleanup(); - } this.connections.clear(); + this.eventHandlers.clear(); this.io.close(); } }