# MARS5-TTS / handler.py
# Uploaded by arnavmehta7 ("Upload 2 files", commit f1a1bc0); original size 1.47 kB.
import tempfile
from pathlib import Path
from typing import Any, Dict

import librosa
import torch
import torchaudio
# Fixed sample rate of 16000 (presumably Hz).
# NOTE(review): nothing in this file reads it. EndpointHandler loads and saves
# audio at the model's own rate (`self.mars5.sr`). It is kept in case other code
# imports it; confirm before removing.
SAMPLE_RATE = 16_000
class EndpointHandler:
    """Inference-endpoint handler that voice-clones text with MARS5-TTS.

    The model is loaded once at construction. Each call does three things:
    synthesizes ``text`` in the voice of a reference recording, writes the
    result to a newly created ``.wav`` file, and returns that file's path.
    """

    # Keyword arguments forwarded to the MARS5 inference config class on every
    # call. They live at class level so a deployment can override them in a
    # subclass without editing ``__call__``.
    # NOTE(review): deep_clone=True presumably makes MARS5 condition on the
    # reference transcript. Confirm against the MARS5 documentation.
    TTS_CONFIG = {
        "deep_clone": True,
        "rep_penalty_window": 100,
        "top_k": 100,
        "temperature": 0.7,
        "freq_penalty": 3,
    }

    def __init__(self, path=""):
        """Load the MARS5 English model and its inference config class.

        Args:
            path: Model directory supplied by the endpoint runtime. It is unused
                here because the model is fetched through ``torch.hub``.
        """
        # trust_repo=True lets torch.hub run code from the Camb-ai/mars5-tts
        # repository without an interactive confirmation prompt.
        self.mars5, self.config_class = torch.hub.load(
            'Camb-ai/mars5-tts', 'mars5_english', trust_repo=True
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Synthesize ``text`` in the voice of a reference recording.

        Args:
            data: Request payload with these keys:
                ``"text"``: the text to synthesize.
                ``"audio_file"``: path to the reference audio (anything
                ``librosa.load`` accepts).
                ``"transcript"``: transcript of the reference audio.

        Returns:
            ``{"synthesized_audio": <path>}``, where ``<path>`` is a newly
            created ``.wav`` file. The handler never deletes it, so the caller
            owns the file and is responsible for removing it.

        Raises:
            KeyError: If a required key is missing from ``data``.
        """
        text = data["text"]
        audio_file = data["audio_file"]
        transcript = data["transcript"]

        # Load the reference clip at the model's sample rate, downmixed to mono.
        ref_wav, _ = librosa.load(audio_file, sr=self.mars5.sr, mono=True)
        ref_wav = torch.from_numpy(ref_wav)

        cfg = self.config_class(**self.TTS_CONFIG)
        _ar_codes, wav_out = self.mars5.tts(text, ref_wav, transcript, cfg=cfg)

        # torchaudio.save expects a CPU tensor shaped (channels, time).
        # Add the channel axis only when the model returned a 1-D waveform.
        wav_out = wav_out.detach().cpu()
        if wav_out.dim() == 1:
            wav_out = wav_out.unsqueeze(0)

        # Create the output file atomically. The deprecated tempfile.mktemp
        # only returned a name, which is open to races. Close the handle so
        # torchaudio can reopen the file by name on every platform.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            output_path = Path(tmp.name)
        try:
            torchaudio.save(str(output_path), wav_out, self.mars5.sr)
        except Exception:
            # Do not leave an empty placeholder file behind on failure.
            output_path.unlink(missing_ok=True)
            raise
        return {"synthesized_audio": str(output_path)}