diff --git a/api/package.json b/api/package.json
index 2a2c8be6de..d6694f5403 100644
--- a/api/package.json
+++ b/api/package.json
@@ -111,6 +111,7 @@
"winston": "^3.11.0",
"winston-daily-rotate-file": "^4.7.1",
"youtube-transcript": "^1.2.1",
+ "ws": "^8.18.0",
"zod": "^3.22.4"
},
"devDependencies": {
diff --git a/api/server/index.js b/api/server/index.js
index 4a428789dd..a2c3c674ec 100644
--- a/api/server/index.js
+++ b/api/server/index.js
@@ -4,6 +4,7 @@ require('module-alias')({ base: path.resolve(__dirname, '..') });
const cors = require('cors');
const axios = require('axios');
const express = require('express');
+const { createServer } = require('http');
const compression = require('compression');
const passport = require('passport');
const mongoSanitize = require('express-mongo-sanitize');
@@ -14,6 +15,7 @@ const { connectDb, indexSync } = require('~/lib/db');
const { isEnabled } = require('~/server/utils');
const { ldapLogin } = require('~/strategies');
const { logger } = require('~/config');
+const { WebSocketService } = require('./services/WebSocket/WebSocketServer');
const validateImageRequest = require('./middleware/validateImageRequest');
const errorController = require('./controllers/ErrorController');
const configureSocialLogins = require('./socialLogins');
@@ -37,7 +39,18 @@ const startServer = async () => {
await indexSync();
const app = express();
+ const server = createServer(app);
+
app.disable('x-powered-by');
+ app.use(
+ cors({
+ origin: true,
+ credentials: true,
+ }),
+ );
+
+ new WebSocketService(server);
+
await AppService(app);
const indexPath = path.join(app.locals.paths.dist, 'index.html');
@@ -110,6 +123,7 @@ const startServer = async () => {
app.use('/api/agents', routes.agents);
app.use('/api/banner', routes.banner);
app.use('/api/bedrock', routes.bedrock);
+ app.use('/api/websocket', routes.websocket);
app.use('/api/tags', routes.tags);
@@ -127,7 +141,7 @@ const startServer = async () => {
res.send(updatedIndexHtml);
});
- app.listen(port, host, () => {
+ server.listen(port, host, () => {
if (host == '0.0.0.0') {
logger.info(
`Server listening on all interfaces at port ${port}. Use http://localhost:${port} to access it`,
@@ -135,6 +149,8 @@ const startServer = async () => {
} else {
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
}
+
+ logger.info(`WebSocket endpoint: ws://${host}:${port}`);
});
};
diff --git a/api/server/routes/index.js b/api/server/routes/index.js
index 4b34029c7b..fa52f87723 100644
--- a/api/server/routes/index.js
+++ b/api/server/routes/index.js
@@ -2,6 +2,7 @@ const assistants = require('./assistants');
const categories = require('./categories');
const tokenizer = require('./tokenizer');
const endpoints = require('./endpoints');
+const websocket = require('./websocket');
const staticRoute = require('./static');
const messages = require('./messages');
const presets = require('./presets');
@@ -15,6 +16,7 @@ const models = require('./models');
const convos = require('./convos');
const config = require('./config');
const agents = require('./agents');
+const banner = require('./banner');
const roles = require('./roles');
const oauth = require('./oauth');
const files = require('./files');
@@ -25,7 +27,6 @@ const edit = require('./edit');
const keys = require('./keys');
const user = require('./user');
const ask = require('./ask');
-const banner = require('./banner');
module.exports = {
ask,
@@ -39,6 +40,7 @@ module.exports = {
files,
share,
agents,
+ banner,
bedrock,
convos,
search,
@@ -50,10 +52,10 @@ module.exports = {
presets,
balance,
messages,
+ websocket,
endpoints,
tokenizer,
assistants,
categories,
staticRoute,
- banner,
};
diff --git a/api/server/routes/websocket.js b/api/server/routes/websocket.js
new file mode 100644
index 0000000000..82d487f593
--- /dev/null
+++ b/api/server/routes/websocket.js
@@ -0,0 +1,18 @@
+const express = require('express');
+const optionalJwtAuth = require('~/server/middleware/optionalJwtAuth');
+const router = express.Router();
+
+router.get('/', optionalJwtAuth, async (req, res) => {
+ const isProduction = process.env.NODE_ENV === 'production';
+ const useSSL = isProduction && process.env.SERVER_DOMAIN?.startsWith('https');
+
+ const protocol = useSSL ? 'wss' : 'ws';
+ const serverDomain = process.env.SERVER_DOMAIN
+ ? process.env.SERVER_DOMAIN.replace(/^https?:\/\//, '')
+ : req.headers.host;
+ const wsUrl = `${protocol}://${serverDomain}/ws`;
+
+ res.json({ url: wsUrl });
+});
+
+module.exports = router;
diff --git a/api/server/services/WebSocket/WebSocketServer.js b/api/server/services/WebSocket/WebSocketServer.js
new file mode 100644
index 0000000000..602e20851e
--- /dev/null
+++ b/api/server/services/WebSocket/WebSocketServer.js
@@ -0,0 +1,70 @@
+const { WebSocketServer } = require('ws');
+const fs = require('fs');
+const path = require('path');
+
+module.exports.WebSocketService = class {
+ constructor(server) {
+ this.wss = new WebSocketServer({ server, path: '/ws' });
+ this.log('Server initialized');
+ this.clientAudioBuffers = new Map();
+ this.setupHandlers();
+ }
+
+ log(msg) {
+ console.log(`[WSS ${new Date().toISOString()}] ${msg}`);
+ }
+
+ setupHandlers() {
+ this.wss.on('connection', (ws) => {
+ const clientId = Date.now().toString();
+ this.clientAudioBuffers.set(clientId, []);
+
+ this.log(`Client connected: ${clientId}`);
+
+ ws.on('message', async (raw) => {
+ let message;
+ try {
+ message = JSON.parse(raw);
+ } catch {
+ return;
+ }
+
+ if (message.type === 'audio-chunk') {
+ if (!this.clientAudioBuffers.has(clientId)) {
+ this.clientAudioBuffers.set(clientId, []);
+ }
+ this.clientAudioBuffers.get(clientId).push(message.data);
+ }
+
+ if (message.type === 'request-response') {
+ const filePath = path.join(__dirname, './assets/response.mp3');
+ const audioFile = fs.readFileSync(filePath);
+ ws.send(JSON.stringify({ type: 'audio-response', data: audioFile.toString('base64') }));
+ }
+
+ if (message.type === 'call-ended') {
+ const allChunks = this.clientAudioBuffers.get(clientId);
+ this.writeAudioFile(clientId, allChunks);
+ this.clientAudioBuffers.delete(clientId);
+ }
+ });
+
+ ws.on('close', () => {
+ this.log(`Client disconnected: ${clientId}`);
+ this.clientAudioBuffers.delete(clientId);
+ });
+ });
+ }
+
+ writeAudioFile(clientId, base64Chunks) {
+ if (!base64Chunks || base64Chunks.length === 0) {
+ return;
+ }
+ const filePath = path.join(__dirname, `recorded_${clientId}.webm`);
+ const buffer = Buffer.concat(
+ base64Chunks.map((chunk) => Buffer.from(chunk.split(',')[1], 'base64')),
+ );
+ fs.writeFileSync(filePath, buffer);
+ this.log(`Saved audio to ${filePath}`);
+ }
+};
diff --git a/client/src/components/Chat/Input/Call.tsx b/client/src/components/Chat/Input/Call.tsx
new file mode 100644
index 0000000000..c79877374f
--- /dev/null
+++ b/client/src/components/Chat/Input/Call.tsx
@@ -0,0 +1,52 @@
+import { useRecoilState } from 'recoil';
+import { Mic, Phone, PhoneOff } from 'lucide-react';
+import { OGDialog, OGDialogContent, Button } from '~/components';
+import { useWebRTC, useWebSocket, useCall } from '~/hooks';
+import store from '~/store';
+
+export const Call: React.FC = () => {
+ const { isConnected } = useWebSocket();
+ const { isCalling, startCall, hangUp } = useCall();
+
+ const [open, setOpen] = useRecoilState(store.callDialogOpen(0));
+
+ return (
+