From 25d51eff318357dd03b7943ab7ede1a7d1c847cb Mon Sep 17 00:00:00 2001 From: Marco Beretta <81851188+berry-13@users.noreply.github.com> Date: Sat, 23 Nov 2024 12:17:53 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=8C=8A=20feat:=20add=20Deepgram=20support?= =?UTF-8?q?=20for=20STT=20providers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/server/services/Files/Audio/STTService.js | 99 +++++++++++++++---- api/server/services/Files/Audio/TTSService.js | 47 +++++++++ package-lock.json | 48 ++++++++- package.json | 3 + packages/data-provider/src/config.ts | 66 +++++++++++++ 5 files changed, 244 insertions(+), 19 deletions(-) diff --git a/api/server/services/Files/Audio/STTService.js b/api/server/services/Files/Audio/STTService.js index 84590cac11..bb32ef1b93 100644 --- a/api/server/services/Files/Audio/STTService.js +++ b/api/server/services/Files/Audio/STTService.js @@ -2,6 +2,7 @@ const axios = require('axios'); const fs = require('fs').promises; const FormData = require('form-data'); const { Readable } = require('stream'); +const { createClient } = require('@deepgram/sdk'); const { extractEnvVariable, STTProviders } = require('librechat-data-provider'); const { getCustomConfig } = require('~/server/services/Config'); const { genAzureEndpoint } = require('~/utils'); @@ -18,10 +19,14 @@ class STTService { */ constructor(customConfig) { this.customConfig = customConfig; - this.providerStrategies = { + this.apiStrategies = { [STTProviders.OPENAI]: this.openAIProvider, [STTProviders.AZURE_OPENAI]: this.azureOpenAIProvider, }; + + this.sdkStrategies = { + [STTProviders.DEEPGRAM]: this.deepgramSDKProvider, + }; } /** @@ -153,6 +158,61 @@ class STTService { return [url, formData, { ...headers, ...formData.getHeaders() }]; } + async deepgramSDKProvider(sttSchema, audioReadStream, audioFile) { + const apiKey = extractEnvVariable(sttSchema.apiKey) || ''; + const deepgram = createClient(apiKey); + + const configOptions = { + // Model parameters + model: sttSchema.model?.model, + language: sttSchema.model?.language, + detect_language: sttSchema.model?.detect_language, + version: sttSchema.model?.version, + + // Formatting parameters + smart_format: sttSchema.formatting?.smart_format, + diarize: sttSchema.formatting?.diarize, + filler_words: sttSchema.formatting?.filler_words, + numerals: sttSchema.formatting?.numerals, + punctuate: sttSchema.formatting?.punctuate, + paragraphs: sttSchema.formatting?.paragraphs, + profanity_filter: sttSchema.formatting?.profanity_filter, + redact: sttSchema.formatting?.redact, + utterances: sttSchema.formatting?.utterances, + utt_split: sttSchema.formatting?.utt_split, + + // Custom vocabulary parameters + replace: sttSchema.custom_vocabulary?.replace, + keywords: sttSchema.custom_vocabulary?.keywords, + + // Intelligence parameters + sentiment: sttSchema.intelligence?.sentiment, + intents: sttSchema.intelligence?.intents, + topics: sttSchema.intelligence?.topics, + }; + + [configOptions].forEach(this.removeUndefined); + + const { result, error } = await deepgram.listen.prerecorded.transcribeFile( + Buffer.isBuffer(audioFile) ? audioFile : audioReadStream, + configOptions, + ); + + if (error) { + throw error; + } + + return result.results?.channels[0]?.alternatives[0]?.transcript || ''; + } + + shouldUseSDK(provider, sttSchema) { + if (provider !== STTProviders.OPENAI && provider !== STTProviders.AZURE_OPENAI) { + return true; + } + + return typeof sttSchema.url === 'string' && sttSchema.url.trim().length > 0; + } + /** * Sends an STT request to the specified provider. * @async @@ -165,31 +225,34 @@ class STTService { * @throws {Error} If the provider is invalid, the response status is not 200, or the response data is missing. */ async sttRequest(provider, sttSchema, { audioBuffer, audioFile }) { - const strategy = this.providerStrategies[provider]; + const useSDK = this.shouldUseSDK(provider, sttSchema); + const strategy = useSDK ? this.sdkStrategies[provider] : this.apiStrategies[provider]; + if (!strategy) { - throw new Error('Invalid provider'); + throw new Error('Invalid provider or implementation'); } const audioReadStream = Readable.from(audioBuffer); audioReadStream.path = 'audio.wav'; - const [url, data, headers] = strategy.call(this, sttSchema, audioReadStream, audioFile); + if (useSDK) { + return strategy.call(this, sttSchema, audioReadStream, audioFile); + } else { + const [url, data, headers] = strategy.call(this, sttSchema, audioReadStream, audioFile); - try { - const response = await axios.post(url, data, { headers }); - - if (response.status !== 200) { - throw new Error('Invalid response from the STT API'); + try { + const response = await axios.post(url, data, { headers }); + if (response.status !== 200) { + throw new Error('Invalid response from the STT API'); + } + if (!response.data || !response.data.text) { + throw new Error('Missing data in response from the STT API'); + } + return response.data.text.trim(); + } catch (error) { + logger.error(`STT request failed for provider ${provider}:`, error); + throw error; } - - if (!response.data || !response.data.text) { - throw new Error('Missing data in response from the STT API'); - } - - return response.data.text.trim(); - } catch (error) { - logger.error(`STT request failed for provider ${provider}:`, error); - throw error; } } diff --git a/api/server/services/Files/Audio/TTSService.js b/api/server/services/Files/Audio/TTSService.js index d9b1e1d44f..8558a7eb9c 100644 --- a/api/server/services/Files/Audio/TTSService.js +++ b/api/server/services/Files/Audio/TTSService.js @@ -21,6 +21,7 @@ class TTSService { [TTSProviders.AZURE_OPENAI]: this.azureOpenAIProvider.bind(this), [TTSProviders.ELEVENLABS]: this.elevenLabsProvider.bind(this), [TTSProviders.LOCALAI]: this.localAIProvider.bind(this), + [TTSProviders.ELEVENLABS]: this.elevenLabsProvider.bind(this), }; } @@ -247,6 +248,52 @@ class TTSService { return [url, data, headers]; } + deepgramProvider(ttsSchema, input, voice) { + const baseUrl = ttsSchema?.url || 'https://api.deepgram.com/v1/speak'; + const params = { + model: ttsSchema.model, + voice: voice, + language: ttsSchema.language, + }; + + const queryParams = Object.entries(params) + .filter(([, value]) => value) + .map(([key, value]) => `${key}=${value}`) + .join('&'); + + const url = queryParams ? `${baseUrl}?${queryParams}` : baseUrl; + + if ( + ttsSchema?.voices && + ttsSchema.voices.length > 0 && + !ttsSchema.voices.includes(voice) && + !ttsSchema.voices.includes('ALL') + ) { + throw new Error(`Voice ${voice} is not available.`); + } + + const data = { + input, + model: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined, + language: ttsSchema?.language, + media_settings: { + bit_rate: ttsSchema?.media_settings?.bit_rate, + sample_rate: ttsSchema?.media_settings?.sample_rate, + }, + }; + + const headers = { + 'Content-Type': 'application/json', + Authorization: `Bearer ${extractEnvVariable(ttsSchema?.apiKey)}`, + }; + + if (extractEnvVariable(ttsSchema.apiKey) === '') { + delete headers.Authorization; + } + + return [url, data, headers]; + } + /** * Sends a TTS request to the specified provider. * @async diff --git a/package-lock.json b/package-lock.json index fbe1e0f4a7..f460bb2646 100644 --- a/package-lock.json +++ b/package-lock.json @@ -13,6 +13,9 @@ "client", "packages/*" ], + "dependencies": { + "@deepgram/sdk": "^3.9.0" + }, "devDependencies": { "@axe-core/playwright": "^4.9.1", "@playwright/test": "^1.38.1", @@ -6635,6 +6638,44 @@ "kuler": "^2.0.0" } }, + "node_modules/@deepgram/captions": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@deepgram/captions/-/captions-1.2.0.tgz", + "integrity": "sha512-8B1C/oTxTxyHlSFubAhNRgCbQ2SQ5wwvtlByn8sDYZvdDtdn/VE2yEPZ4BvUnrKWmsbTQY6/ooLV+9Ka2qmDSQ==", + "license": "MIT", + "dependencies": { + "dayjs": "^1.11.10" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@deepgram/sdk": { + "version": "3.9.0", + "resolved": "https://registry.npmjs.org/@deepgram/sdk/-/sdk-3.9.0.tgz", + "integrity": "sha512-X/7JzoYjCObyEaPb2Dgnkwk2LwRe4bw0FJJCLdkjpnFfJCFgA9IWgRD8FEUI6/hp8dW/CqqXkGPA2Q3DIsVG8A==", + "license": "MIT", + "dependencies": { + "@deepgram/captions": "^1.1.1", + "@types/node": "^18.19.39", + "cross-fetch": "^3.1.5", + "deepmerge": "^4.3.1", + "events": "^3.3.0", + "ws": "^8.17.0" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@deepgram/sdk/node_modules/@types/node": { + "version": "18.19.65", + "resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.65.tgz", + "integrity": "sha512-Ay5BZuO1UkTmVHzZJNvZKw/E+iB3GQABb6kijEz89w2JrfhNA+M/ebp18pfz9Gqe9ywhMC8AA8yC01lZq48J+Q==", + "license": "MIT", + "dependencies": { + "undici-types": "~5.26.4" + } + }, "node_modules/@dicebear/adventurer": { "version": "7.0.4", "resolved": "https://registry.npmjs.org/@dicebear/adventurer/-/adventurer-7.0.4.tgz", @@ -17942,6 +17983,12 @@ "url": "https://github.com/sponsors/kossnocorp" } }, + "node_modules/dayjs": { + "version": "1.11.13", + "resolved": "https://registry.npmjs.org/dayjs/-/dayjs-1.11.13.tgz", + "integrity": "sha512-oaMBel6gjolK862uaPQOVTA7q3TZhuSvuMQAAglQDOWYO9A91IrAOUJEyKVlqJlHE0vq5p5UXxzdPfMH/x6xNg==", + "license": "MIT" + }, "node_modules/debug": { "version": "4.3.7", "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", @@ -18067,7 +18114,6 @@ "version": "4.3.1", "resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-4.3.1.tgz", "integrity": "sha512-3sUqbMEc77XqpdNO7FRyRog+eW3ph+GYCbj+rK+uYyRMuwsVy0rMiVtPn+QJlKFvWP/1PYpapqYn0Me2knFn+A==", - "dev": true, "engines": { "node": ">=0.10.0" } diff --git a/package.json b/package.json index 989c04423a..b6233fdd0a 100644 --- a/package.json +++ b/package.json @@ -113,5 +113,8 @@ "admin/", "packages/" ] + }, + "dependencies": { + "@deepgram/sdk": "^3.9.0" } } diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index 04f3faf077..9ca16938f0 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -336,11 +336,28 @@ const ttsLocalaiSchema = z.object({ backend: z.string(), }); +const ttsDeepgramSchema = z + .object({ + url: z.string().optional(), + apiKey: z.string().optional(), + voices: z.array(z.string()), + model: z.string(), + language: z.string().optional(), + media_settings: z + .object({ + bit_rate: z.number().optional(), + sample_rate: z.number().optional(), + }) + .optional(), + }) + .optional(); + const ttsSchema = z.object({ openai: ttsOpenaiSchema.optional(), azureOpenAI: ttsAzureOpenAISchema.optional(), elevenlabs: ttsElevenLabsSchema.optional(), localai: ttsLocalaiSchema.optional(), + deepgram: ttsDeepgramSchema.optional(), }); const sttOpenaiSchema = z.object({ @@ -356,9 +373,50 @@ const sttAzureOpenAISchema = z.object({ apiVersion: z.string(), }); +const sttDeepgramSchema = z.object({ + url: z.string().optional(), + apiKey: z.string().optional(), + model: z + .object({ + model: z.string().optional(), + language: z.string().optional(), + detect_language: z.boolean().optional(), + version: z.string().optional(), + }) + .optional(), + formatting: z + .object({ + smart_format: z.boolean().optional(), + diarize: z.boolean().optional(), + filler_words: z.boolean().optional(), + numerals: z.boolean().optional(), + punctuate: z.boolean().optional(), + paragraphs: z.boolean().optional(), + profanity_filter: z.boolean().optional(), + redact: z.boolean().optional(), + utterances: z.boolean().optional(), + utt_split: z.number().optional(), + }) + .optional(), + custom_vocabulary: z + .object({ + replace: z.array(z.string()).optional(), + keywords: z.array(z.string()).optional(), + }) + .optional(), + intelligence: z + .object({ + sentiment: z.boolean().optional(), + intents: z.boolean().optional(), + topics: z.boolean().optional(), + }) + .optional(), +}); + const sttSchema = z.object({ openai: sttOpenaiSchema.optional(), azureOpenAI: sttAzureOpenAISchema.optional(), + deepgram: sttDeepgramSchema.optional(), }); const speechTab = z @@ -1054,6 +1112,10 @@ export enum STTProviders { * Provider for Microsoft Azure STT */ AZURE_OPENAI = 'azureOpenAI', + /** + * Provider for Deepgram STT + */ + DEEPGRAM = 'deepgram', } export enum TTSProviders { @@ -1073,6 +1135,10 @@ export enum TTSProviders { * Provider for LocalAI TTS */ LOCALAI = 'localai', + /** + * Provider for Deepgram TTS + */ + DEEPGRAM = 'deepgram', } /** Enum for app-wide constants */