import torch
import torchaudio
import librosa
import os
import json
import glob
import tqdm
from utils.g2p import PhonemeBpeTokenizer

text_tokenizer = PhonemeBpeTokenizer()

lang2token = {
    'zh': "[ZH]",
    'ja': "[JA]",
    'en': "[EN]",
    'fr': "[FR]",
    'kr': "[KR]",
    'de': "[DE]",
}

LANG2CODE = {
    'en': 655,
    'zh': 654,
}


def g2p(text, language):
    """Wrap `text` in language tags and convert it to phoneme token ids."""
    text = text.replace("\n", "").strip()
    lang_token = lang2token[language]
    text = lang_token + text + lang_token
    return text_tokenizer.tokenize(text=text, language=language)


class LibriSpeechTestDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir=None, use_vocos=False):
        self.data_dir = data_dir or './ref_3s_new_diffprompt'
        self.wav_list = []
        self.transcripts = {}
        self.output_path = []
        # Load the test-case metadata (reference prompts plus target texts).
        with open('./ref_dur_3_test_merge_1pspk_with_punc_refmeta_normwav_fix_refuid_new_diffprompt.json', 'r') as f:
            test_data = json.load(f)["test_cases"]
        for wav_info in test_data:
            wav_path = os.path.join(self.data_dir, wav_info["wav_path"].split('/')[-1])
            self.wav_list.append(wav_path)
            # Key transcripts by file stem; the transcript is the prompt text
            # followed by the target text to be synthesized.
            stem = os.path.splitext(wav_info["wav_path"].split('/')[-1])[0]
            self.transcripts[stem] = wav_info["text"] + " " + wav_info["target_text"]
            self.output_path.append(wav_info["uid"] + '.wav')

    def __len__(self):
        return len(self.wav_list)

    def __getitem__(self, idx):
        wav_file = self.wav_list[idx]
        transcript = self.transcripts[os.path.splitext(os.path.basename(wav_file))[0]]
        # Remove punctuation and lowercase before G2P.
        transcript = ''.join(e for e in transcript if e.isalnum() or e.isspace()).lower()
        orig_transcript = transcript
        transcript = g2p(transcript, 'en')[1]
        transcript = [LANG2CODE['en']] + transcript
        transcript = torch.tensor(transcript, dtype=torch.long)

        speech, _ = librosa.load(wav_file, sr=16000)
        # Resample to Encodec's 24 kHz sample rate.
        speech = librosa.resample(speech, orig_sr=16000, target_sr=24000)
        speech = torch.tensor(speech, dtype=torch.float32)
        return {
            'speech': speech,
            'phone_ids': transcript,
            'transcript': orig_transcript,
            'output_path': self.output_path[idx],
        }


class ValleInference(torch.nn.Module):
    def __init__(self, ar_model=None, nar_model=None, use_vocos=False):
        super().__init__()
        self.device = "cuda"

        from models.valle_ar import ValleAR
        self.ar_model = ValleAR(
            phone_vocab_size=656,
            target_vocab_size=1024,
            pad_token_id=1680,
            bos_target_id=1681,
            eos_target_id=1682,
            bos_phone_id=1683,
            eos_phone_id=1684,
        )
        self.ar_model.load_state_dict(torch.load('valle_ar.bin'))
        self.ar_model.eval().to(self.device)

        from models.valle_nar import ValleNAR
        self.nar_model = ValleNAR(
            phone_vocab_size=656,
            target_vocab_size=1024,
            pad_token_id=1680,
            bos_target_id=1681,
            eos_target_id=1682,
            bos_phone_id=1683,
            eos_phone_id=1684,
        )
        self.nar_model.load_state_dict(torch.load('valle_nar.bin'))
        self.nar_model.eval().to(self.device)

        from encodec import EncodecModel
        self.codec_encoder = EncodecModel.encodec_model_24khz()
        self.codec_encoder.set_target_bandwidth(6.0)  # 8 codebooks at 75 Hz
        self.codec_encoder.to(self.device)

        if use_vocos:
            from vocos import Vocos
            self.codec_decoder = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
            self.codec_decoder.to(self.device)
        self.use_vocos = use_vocos

    def decode(self, vq_ids):
        '''vq_ids.shape: [8, B, T]; returns a waveform of shape [B, 1, T*320].'''
        if not self.use_vocos:
            return self.codec_encoder.decode([(vq_ids.transpose(0, 1), None)])
        features = self.codec_decoder.codes_to_features(vq_ids.squeeze(1))
        bandwidth_id = torch.tensor([2], device=vq_ids.device)  # index of the 6 kbps setting
        return self.codec_decoder.decode(features, bandwidth_id=bandwidth_id).unsqueeze(0)

    def forward(self, batch, temperature=1.2):
        '''batch: dict with
            speech: [B, T] waveform at 24 kHz
            phone_ids: [B, T] phoneme token ids
        Returns the decoded waveform, [B, 1, T'].
        '''
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(self.device)
        with torch.no_grad():
            # Encode the prompt+target waveform into Encodec tokens: [8, B, T].
            vq_id = self.codec_encoder.encode(batch['speech'].unsqueeze(1))
            vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(0, 1)

            # Grid search over sampling hyperparameters, kept for reference:
            # for top_p in [0.8, 0.85, 0.9, 0.95, 1.0]:
            #     for temperature in [1.1, 1.2, 1.3, 1.4]:
            #         ar_vq_ids = self.ar_model.sample_hf(
            #             batch['phone_ids'], vq_id[0, :, :225],
            #             top_p=top_p, top_k=1024, temperature=temperature,
            #         )
            #         recovered_audio_ar = self.decode(ar_vq_ids.unsqueeze(0))
            #         torchaudio.save(f'recovered_audio_ar_{top_p}_{temperature}.wav', recovered_audio_ar[0].cpu(), 24000)

            # AR stage: generate the first codebook layer, conditioned on the
            # phone ids and the first 225 prompt frames (~3 s at 75 Hz).
            ar_vq_ids = self.ar_model.sample_hf(
                batch['phone_ids'],
                vq_id[0, :, :225],
                top_p=0.9,
                top_k=1024,
                temperature=temperature,
            )  # [B, T]
            recovered_audio_ar = self.decode(ar_vq_ids.unsqueeze(0))
            torchaudio.save('recovered_audio_ar.wav', recovered_audio_ar[0].cpu(), 24000)

            # NAR stage: fill in the remaining codebook layers.
            nar_vq_ids = self.nar_model.sample_hf(
                phone_ids=batch['phone_ids'],
                prompt_ids=vq_id[:, :, :225],
                first_stage_ids=ar_vq_ids,
                # first_stage_ids=vq_id[:, :, 225:],  # ground-truth first stage, for debugging
            )
            # Prepend the prompt tokens so the decoded audio contains the
            # reference prompt followed by the generated continuation.
            nar_vq_ids = torch.cat([vq_id[..., :225], nar_vq_ids], dim=-1)
            recovered_audio = self.decode(nar_vq_ids)
            torchaudio.save('recovered_audio_nar.wav', recovered_audio[0].cpu(), 24000)
            return recovered_audio
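# Illustrative sketch (our helper, not called anywhere in this script): how the
# Encodec token tensor used above is laid out. `EncodecModel.encode` returns a
# list of (codes, scale) frames with codes shaped [B, n_q, T_frames]; the
# forward pass concatenates the frames and transposes to [n_q, B, T_frames],
# so that vq_id[0] is the coarsest quantizer layer fed to the AR model. At the
# 6 kbps setting, n_q == 8 and frames run at 75 Hz (320 samples per frame).
def _encodec_code_layout_demo(codec_encoder, waveform_24k):
    """waveform_24k: [B, 1, T] float tensor at 24 kHz."""
    frames = codec_encoder.encode(waveform_24k)
    codes = torch.cat([f[0] for f in frames], dim=-1)  # [B, n_q, T_frames]
    return codes.transpose(0, 1)  # [n_q, B, T_frames]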
class LibriSpeechDevDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir=None, use_vocos=False):
        # Other test-clean speakers used during development:
        # '/mnt/petrelfs/hehaorui/jiaqi/LibriSpeech/test-clean/908/31957'
        # '/mnt/petrelfs/hehaorui/jiaqi/LibriSpeech/test-clean/5683/32866'
        self.data_dir = data_dir or '/mnt/petrelfs/hehaorui/jiaqi/LibriSpeech/test-clean/8224/274384'
        self.wav_list = glob.glob(self.data_dir + '/*.flac') + glob.glob(self.data_dir + '/*.wav')
        self.transcript_file = glob.glob(self.data_dir + '/*.txt')[0]
        self.transcripts = {}
        # LibriSpeech transcript files hold one "<utt-id> <words...>" per line.
        with open(self.transcript_file, 'r') as f:
            for line in f:
                parts = line.strip().split()
                self.transcripts[parts[0]] = ' '.join(parts[1:])

    def __len__(self):
        return len(self.wav_list)

    def __getitem__(self, idx):
        wav_file = self.wav_list[idx]
        # splitext handles both the .flac and .wav entries in wav_list.
        transcript = self.transcripts[os.path.splitext(os.path.basename(wav_file))[0]]
        orig_transcript = transcript
        transcript = g2p(transcript, 'en')[1]
        transcript = torch.tensor(transcript, dtype=torch.long)

        speech, _ = librosa.load(wav_file, sr=16000)
        # Resample to Encodec's 24 kHz sample rate.
        speech = librosa.resample(speech, orig_sr=16000, target_sr=24000)
        speech = torch.tensor(speech, dtype=torch.float32)
        return {
            'speech': speech,
            'phone_ids': transcript,
            'transcript': orig_transcript,
        }
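# The driver below uses batch_size=1 because each item carries unpadded,
# variable-length `speech` and `phone_ids` tensors. A hypothetical collate_fn
# for larger batches might pad as sketched here (the helper name and the
# zero-padding choice are ours; a real implementation would also need to track
# lengths and use the models' pad token):
def _pad_collate(batch):
    return {
        'speech': torch.nn.utils.rnn.pad_sequence(
            [b['speech'] for b in batch], batch_first=True),
        'phone_ids': torch.nn.utils.rnn.pad_sequence(
            [b['phone_ids'] for b in batch], batch_first=True),
        'transcript': [b['transcript'] for b in batch],
        'output_path': [b.get('output_path') for b in batch],
    }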
# dataset = LibriSpeechDevDataset()
dataset = LibriSpeechTestDataset()
inference = ValleInference(use_vocos=False)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

os.makedirs('./inference_retrain', exist_ok=True)
for batch in tqdm.tqdm(dataloader):
    # Optionally skip clips shorter than 4 s:
    # if batch['speech'].shape[1] <= 24000 * 4:
    #     continue
    print(batch['transcript'][0].lower())
    recovered_audio = inference(batch)
    # Regenerate with a higher temperature if the output is all silence:
    # if torch.sum(recovered_audio[0].abs()) <= 0.1:
    #     print('regenerating audio')
    #     recovered_audio = inference(batch, temperature=1.2)[..., 24000*3:]
    uid = batch['output_path'][0]
    save_path = os.path.join('./inference_retrain', uid)
    torchaudio.save(save_path, recovered_audio[0].cpu(), 24000)
    print(f'saved to {save_path}')
    # To also dump the ground-truth prompt audio:
    # torchaudio.save('./gt.wav', batch['speech'].cpu(), 24000)
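# Note: constructing ValleInference(use_vocos=True) routes `decode` through
# Vocos instead of Encodec's built-in decoder; the same [8, B, T] token tensor
# is mapped to features via codes_to_features and decoded with bandwidth_id=2,
# the index of the 6 kbps setting used by the encoder above.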