Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import json | |
| import torch | |
| import librosa | |
| import soundfile | |
| import torchaudio | |
| import numpy as np | |
| import torch.nn as nn | |
| from tqdm import tqdm | |
| from . import utils | |
| from . import commons | |
| from .models import SynthesizerTrn | |
| from .split_utils import split_sentence | |
| from .mel_processing import spectrogram_torch, spectrogram_torch_conv | |
| from .download_utils import load_or_download_config, load_or_download_model | |
class TTS(nn.Module):
    """Text-to-speech wrapper around a pretrained ``SynthesizerTrn`` model.

    On construction the config and checkpoint for ``language`` are obtained via
    ``load_or_download_config`` / ``load_or_download_model``; the synthesizer is
    built on ``device``, loaded with the checkpoint weights and set to eval mode.
    """

    def __init__(self, language, device='cuda:0'):
        """Build the synthesizer and load its weights.

        Args:
            language: Language key passed to the config/model loaders. Only the
                part before the first ``'_'`` is kept as ``self.language``; a
                bare ``'ZH'`` is mapped to ``'ZH_MIX_EN'``.
            device: Torch device string. Any value containing ``'cuda'``
                requires ``torch.cuda.is_available()`` to be true.

        Raises:
            AssertionError: if a CUDA device is requested but CUDA is unavailable
                (unless Python runs with ``-O``, which strips asserts).
        """
        super().__init__()
        if 'cuda' in device:
            assert torch.cuda.is_available(), (
                f"device={device!r} requested but torch.cuda.is_available() is False"
            )
        hps = load_or_download_config(language)
        num_languages = hps.num_languages
        num_tones = hps.num_tones
        symbols = hps.symbols
        model = SynthesizerTrn(
            len(symbols),
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            num_tones=num_tones,
            num_languages=num_languages,
            **hps.model,
        ).to(device)
        model.eval()
        self.model = model
        # Maps each model symbol to its integer index in ``hps.symbols``.
        self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
        self.hps = hps
        self.device = device

        checkpoint_dict = load_or_download_model(language, device)
        self.model.load_state_dict(checkpoint_dict['model'], strict=True)

        language = language.split('_')[0]
        # A bare 'ZH' selects the mixed Chinese/English model.
        self.language = 'ZH_MIX_EN' if language == 'ZH' else language

    @staticmethod
    def audio_numpy_concat(segment_data_list, sr, speed=1.):
        """Flatten and join audio segments, appending silence after each one.

        Every segment is flattened and followed by ``int(sr * 0.05 / speed)``
        zero samples (0.05 s of silence at ``speed=1``), including the last one.

        Args:
            segment_data_list: Iterable of array-like audio segments.
            sr: Sample rate, used to size the silence gap.
            speed: Speed factor; a larger value gives a shorter gap.

        Returns:
            A 1-D ``np.float32`` array (empty when there are no segments).
        """
        # Clamp at 0 so a negative gap adds nothing, as list-repetition did before.
        gap = max(int((sr * 0.05) / speed), 0)
        pieces = []
        for segment in segment_data_list:
            pieces.append(np.asarray(segment).reshape(-1).astype(np.float32))
            pieces.append(np.zeros(gap, dtype=np.float32))
        if not pieces:
            # np.concatenate rejects an empty list.
            return np.zeros(0, dtype=np.float32)
        return np.concatenate(pieces)

    @staticmethod
    def split_sentences_into_pieces(text, language):
        """Split ``text`` into sentence pieces with ``split_sentence``."""
        return split_sentence(text, language_str=language)

    def _infer_sentence(self, t, speaker_id, sdp_ratio, noise_scale, noise_scale_w, speed):
        """Synthesize one sentence and return its waveform as a numpy array."""
        language = self.language
        device = self.device
        if language in ['EN', 'ZH_MIX_EN']:
            # Insert a space at lowercase->uppercase boundaries, e.g. "helloWorld" -> "hello World".
            t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
        bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
            t, language, self.hps, device, self.symbol_to_id)
        with torch.no_grad():
            # Add a leading batch dimension of size 1 to every model input.
            x_tst = phones.to(device).unsqueeze(0)
            tones = tones.to(device).unsqueeze(0)
            lang_ids = lang_ids.to(device).unsqueeze(0)
            bert = bert.to(device).unsqueeze(0)
            ja_bert = ja_bert.to(device).unsqueeze(0)
            x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
            speakers = torch.LongTensor([speaker_id]).to(device)
            # NOTE(review): [0][0, 0] presumably selects batch 0 / channel 0 of
            # the waveform returned by infer -- confirm against SynthesizerTrn.
            # The tensors above are locals, so they are released on return.
            return self.model.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones,
                lang_ids,
                bert,
                ja_bert,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=1. / speed,
            )[0][0, 0].data.cpu().float().numpy()

    def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6,
                    noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None):
        """Synthesize ``text`` sentence by sentence.

        Args:
            text: Input text; it is split into sentences before synthesis.
            speaker_id: Integer speaker index passed to the model.
            output_path: If ``None`` the audio array is returned; otherwise it
                is written there with ``soundfile.write``.
            sdp_ratio, noise_scale, noise_scale_w: Forwarded to ``model.infer``.
            speed: Speaking-rate factor (``length_scale = 1 / speed``); it also
                shortens the silence inserted between sentences.
            pbar: Optional callable wrapping the sentence list for progress
                display; when falsy, ``tqdm`` is used.
            format: Optional file format forwarded to ``soundfile.write``.
            position: Optional ``tqdm`` bar position (``0`` is honoured).

        Returns:
            The concatenated float32 audio when ``output_path`` is ``None``,
            otherwise ``None``.
        """
        language = self.language
        texts = self.split_sentences_into_pieces(text, language)
        if pbar:
            progress = pbar(texts)
        elif position is not None:
            progress = tqdm(texts, position=position)
        else:
            progress = tqdm(texts)

        audio_list = [
            self._infer_sentence(t, speaker_id, sdp_ratio, noise_scale, noise_scale_w, speed)
            for t in progress
        ]
        torch.cuda.empty_cache()

        sampling_rate = self.hps.data.sampling_rate
        audio = self.audio_numpy_concat(audio_list, sr=sampling_rate, speed=speed)
        if output_path is None:
            return audio
        soundfile.write(output_path, audio, sampling_rate, format=format)