🌊 feat: add Deepgram support for STT and TTS providers

This commit is contained in:
Marco Beretta 2024-11-23 12:17:53 +01:00
parent 56b60cf863
commit 25d51eff31
No known key found for this signature in database
GPG key ID: D918033D8E74CC11
5 changed files with 244 additions and 19 deletions

View file

@ -2,6 +2,7 @@ const axios = require('axios');
const fs = require('fs').promises;
const FormData = require('form-data');
const { Readable } = require('stream');
const { createClient } = require('@deepgram/sdk');
const { extractEnvVariable, STTProviders } = require('librechat-data-provider');
const { getCustomConfig } = require('~/server/services/Config');
const { genAzureEndpoint } = require('~/utils');
@ -18,10 +19,14 @@ class STTService {
*/
constructor(customConfig) {
this.customConfig = customConfig;
this.providerStrategies = {
this.apiStrategies = {
[STTProviders.OPENAI]: this.openAIProvider,
[STTProviders.AZURE_OPENAI]: this.azureOpenAIProvider,
};
this.sdkStrategies = {
[STTProviders.DEEPGRAM]: this.deepgramSDKProvider,
};
}
/**
@ -153,6 +158,61 @@ class STTService {
return [url, formData, { ...headers, ...formData.getHeaders() }];
}
async deepgramSDKProvider(sttSchema, audioReadStream, audioFile) {
const apiKey = extractEnvVariable(sttSchema.apiKey) || '';
const deepgram = createClient(apiKey);
const configOptions = {
// Model parameters
model: sttSchema.model?.model,
language: sttSchema.model?.language,
detect_language: sttSchema.model?.detect_language,
version: sttSchema.model?.version,
// Formatting parameters
smart_format: sttSchema.formatting?.smart_format,
diarize: sttSchema.formatting?.diarize,
filler_words: sttSchema.formatting?.filler_words,
numerals: sttSchema.formatting?.numerals,
punctuate: sttSchema.formatting?.punctuate,
paragraphs: sttSchema.formatting?.paragraphs,
profanity_filter: sttSchema.formatting?.profanity_filter,
redact: sttSchema.formatting?.redact,
utterances: sttSchema.formatting?.utterances,
utt_split: sttSchema.formatting?.utt_split,
// Custom vocabulary parameters
replace: sttSchema.custom_vocabulary?.replace,
keywords: sttSchema.custom_vocabulary?.keywords,
// Intelligence parameters
sentiment: sttSchema.intelligence?.sentiment,
intents: sttSchema.intelligence?.intents,
topics: sttSchema.intelligence?.topics,
};
[configOptions].forEach(this.removeUndefined);
const { result, error } = await deepgram.listen.prerecorded.transcribeFile(
Buffer.isBuffer(audioFile) ? audioFile : audioReadStream,
configOptions,
);
if (error) {
throw error;
}
return result.results?.channels[0]?.alternatives[0]?.transcript || '';
}
shouldUseSDK(provider, sttSchema) {
if (provider !== STTProviders.OPENAI && provider !== STTProviders.AZURE_OPENAI) {
return true;
}
return typeof sttSchema.url === 'string' && sttSchema.url.trim().length > 0;
}
/**
* Sends an STT request to the specified provider.
* @async
@ -165,31 +225,34 @@ class STTService {
* @throws {Error} If the provider is invalid, the response status is not 200, or the response data is missing.
*/
async sttRequest(provider, sttSchema, { audioBuffer, audioFile }) {
const strategy = this.providerStrategies[provider];
const useSDK = this.shouldUseSDK(provider, sttSchema);
const strategy = useSDK ? this.sdkStrategies[provider] : this.apiStrategies[provider];
if (!strategy) {
throw new Error('Invalid provider');
throw new Error('Invalid provider or implementation');
}
const audioReadStream = Readable.from(audioBuffer);
audioReadStream.path = 'audio.wav';
const [url, data, headers] = strategy.call(this, sttSchema, audioReadStream, audioFile);
if (useSDK) {
return strategy.call(this, sttSchema, audioReadStream, audioFile);
} else {
const [url, data, headers] = strategy.call(this, sttSchema, audioReadStream, audioFile);
try {
const response = await axios.post(url, data, { headers });
if (response.status !== 200) {
throw new Error('Invalid response from the STT API');
try {
const response = await axios.post(url, data, { headers });
if (response.status !== 200) {
throw new Error('Invalid response from the STT API');
}
if (!response.data || !response.data.text) {
throw new Error('Missing data in response from the STT API');
}
return response.data.text.trim();
} catch (error) {
logger.error(`STT request failed for provider ${provider}:`, error);
throw error;
}
if (!response.data || !response.data.text) {
throw new Error('Missing data in response from the STT API');
}
return response.data.text.trim();
} catch (error) {
logger.error(`STT request failed for provider ${provider}:`, error);
throw error;
}
}

View file

@ -21,6 +21,7 @@ class TTSService {
[TTSProviders.AZURE_OPENAI]: this.azureOpenAIProvider.bind(this),
[TTSProviders.ELEVENLABS]: this.elevenLabsProvider.bind(this),
[TTSProviders.LOCALAI]: this.localAIProvider.bind(this),
[TTSProviders.ELEVENLABS]: this.elevenLabsProvider.bind(this),
};
}
@ -247,6 +248,52 @@ class TTSService {
return [url, data, headers];
}
deepgramProvider(ttsSchema, input, voice) {
const baseUrl = ttsSchema?.url || 'https://api.deepgram.com/v1/speak';
const params = {
model: ttsSchema.model,
voice: voice,
language: ttsSchema.language,
};
const queryParams = Object.entries(params)
.filter(([, value]) => value)
.map(([key, value]) => `${key}=${value}`)
.join('&');
const url = queryParams ? `${baseUrl}?${queryParams}` : baseUrl;
if (
ttsSchema?.voices &&
ttsSchema.voices.length > 0 &&
!ttsSchema.voices.includes(voice) &&
!ttsSchema.voices.includes('ALL')
) {
throw new Error(`Voice ${voice} is not available.`);
}
const data = {
input,
model: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
language: ttsSchema?.language,
media_settings: {
bit_rate: ttsSchema?.media_settings?.bit_rate,
sample_rate: ttsSchema?.media_settings?.sample_rate,
},
};
const headers = {
'Content-Type': 'application/json',
Authorization: `Bearer ${extractEnvVariable(ttsSchema?.apiKey)}`,
};
if (extractEnvVariable(ttsSchema.apiKey) === '') {
delete headers.Authorization;
}
return [url, data, headers];
}
/**
* Sends a TTS request to the specified provider.
* @async