Spaces:
Build error
Build error
| import torch | |
| from inference.tts.base_tts_infer import BaseTTSInfer | |
| from utils.ckpt_utils import load_ckpt | |
| from modules.portaspeech.portaspeech import PortaSpeech | |
class TTSInference(BaseTTSInfer):
    """PortaSpeech text-to-speech inference wrapper built on ``BaseTTSInfer``.

    The class turns a raw-text request into model inputs
    (``preprocess_input`` / ``input_to_batch``), runs the PortaSpeech model
    and the vocoder (``forward_model``) and returns the waveform.
    """

    def __init__(self, hparams, device=None):
        """Set up the base inferencer and load the speaker map.

        :param hparams: hyper-parameter mapping forwarded to ``BaseTTSInfer``.
        :param device: device forwarded to ``BaseTTSInfer``; may be ``None``.
        """
        super().__init__(hparams, device)
        print("Initializing TTS model to %s" % device)
        # NOTE(review): self.preprocessor and self.data_dir are not set in this
        # class; presumably they are provided by BaseTTSInfer.__init__.
        self.spk_map = self.preprocessor.load_spk_map(self.data_dir)
        print("TTS loaded!")

    def build_model(self):
        """Create the PortaSpeech model and load its checkpoint.

        :return: the model with weights loaded from ``hparams['work_dir']``.
        """
        model = PortaSpeech(self.ph_encoder, self.word_encoder)
        load_ckpt(model, self.hparams['work_dir'], 'model')
        # NOTE(review): store_inverse_all presumably pre-computes inverse
        # weights used at inference time -- confirm in PortaSpeech.
        with torch.no_grad():
            model.store_inverse_all()
        return model

    def forward_model(self, inp):
        """Synthesize a waveform for one preprocessed item.

        :param inp: item dict as produced by ``preprocess_input``.
        :return: the first waveform of the vocoder output, as a numpy array.
        """
        sample = self.input_to_batch(inp)
        with torch.no_grad():
            output = self.model(
                sample['txt_tokens'],
                sample['word_tokens'],
                ph2word=sample['ph2word'],
                word_len=sample['word_lengths'].max(),
                infer=True,
                forward_post_glow=True,
                spk_id=sample.get('spk_ids'),
            )
            mel_out = output['mel_out']
            wav_out = self.run_vocoder(mel_out)
        wav_out = wav_out.cpu().numpy()
        # The batch built by input_to_batch holds a single item.
        return wav_out[0]

    def preprocess_input(self, inp):
        """Convert a raw request into an item dict of tokens and metadata.

        :param inp: {'text': str, 'item_name': (str, optional),
                     'spk_name': (str, optional)}
        :return: dict with keys 'item_name', 'text', 'ph', 'spk_id',
                 'ph_token', 'word_token', 'ph2word', 'ph_words', 'words'
                 and 'ph_len'.
        :raises KeyError: if 'text' is missing from ``inp`` or the speaker
                 name is not present in the speaker map.
        """
        preprocessor, preprocess_args = self.preprocessor, self.preprocess_args
        text_raw = inp['text']
        item_name = inp.get('item_name', '<ITEM_NAME>')
        spk_name = inp.get('spk_name', '<SINGLE_SPK>')
        ph, txt, word, ph2word, ph_gb_word = preprocessor.txt_to_ph(
            preprocessor.txt_processor, text_raw, preprocess_args)
        word_token = self.word_encoder.encode(word)
        ph_token = self.ph_encoder.encode(ph)
        spk_id = self.spk_map[spk_name]
        item = {
            'item_name': item_name,
            'text': txt,
            'ph': ph,
            'spk_id': spk_id,
            'ph_token': ph_token,
            'word_token': word_token,
            'ph2word': ph2word,
            'ph_words': ph_gb_word,
            'words': word,
        }
        item['ph_len'] = len(item['ph_token'])
        return item

    def input_to_batch(self, item):
        """Pack one preprocessed item into a batch of size one on ``self.device``.

        :param item: item dict as produced by ``preprocess_input``.
        :return: dict with list-valued 'item_name', 'text', 'ph' and
                 tensor-valued 'txt_tokens', 'txt_lengths', 'word_tokens',
                 'word_lengths', 'ph2word', 'spk_ids'.
        """
        item_names = [item['item_name']]
        text = [item['text']]
        ph = [item['ph']]
        txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
        txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
        word_tokens = torch.LongTensor(item['word_token'])[None, :].to(self.device)
        # Word length must come from the word tokens; it previously reused
        # txt_tokens.shape[1], i.e. the phoneme count.
        word_lengths = torch.LongTensor([word_tokens.shape[1]]).to(self.device)
        ph2word = torch.LongTensor(item['ph2word'])[None, :].to(self.device)
        # torch.LongTensor(n) with an integer n allocates an uninitialized
        # tensor of n elements instead of a tensor holding n, so the id is
        # wrapped in a list. A scalar id yields shape (1,); a sequence still
        # yields shape (1, len), as before.
        # NOTE(review): assumes spk_map values are integer ids -- confirm.
        spk_ids = torch.LongTensor([item['spk_id']]).to(self.device)
        batch = {
            'item_name': item_names,
            'text': text,
            'ph': ph,
            'txt_tokens': txt_tokens,
            'txt_lengths': txt_lengths,
            'word_tokens': word_tokens,
            'word_lengths': word_lengths,
            'ph2word': ph2word,
            'spk_ids': spk_ids,
        }
        return batch

    def postprocess_output(self, output):
        """Return the model output unchanged."""
        return output