# valle_demo/valle_inference.py
import os
import json
import glob

import torch
import torchaudio
import librosa
import tqdm

from utils.g2p import PhonemeBpeTokenizer
# shared grapheme-to-phoneme BPE tokenizer
text_tokenizer = PhonemeBpeTokenizer()

# language tags wrapped around the text before phoneme tokenization
lang2token = {
    'zh': "[ZH]",
    'ja': "[JA]",
    'en': "[EN]",
    'fr': "[FR]",
    'kr': "[KR]",
    'de': "[DE]",
}

# language id prepended to the phone id sequence
LANG2CODE = {
    'en': 655,
    'zh': 654,
}
def g2p(text, language):
    """Wrap `text` in language tags and convert it to phoneme BPE token ids."""
    text = text.replace("\n", "").strip()
    lang_token = lang2token[language]
    text = lang_token + text + lang_token
    return text_tokenizer.tokenize(text=text, language=language)
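# Illustrative usage (a sketch: the exact structure of the tokenizer output
# depends on PhonemeBpeTokenizer; the datasets below index [1] for the phone ids):
#   phone_ids = g2p("hello world", 'en')[1]
#   phone_ids = [LANG2CODE['en']] + phone_ids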
class LibriSpeechTestDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir=None, use_vocos=False):
        # default to the bundled 3-second reference prompts
        self.data_dir = data_dir if data_dir is not None else './ref_3s_new_diffprompt'
        self.wav_list = []
        self.transcripts = {}
        # load the test-case metadata json
        with open('./ref_dur_3_test_merge_1pspk_with_punc_refmeta_normwav_fix_refuid_new_diffprompt.json', 'r') as f:
            data = json.load(f)
        test_data = data["test_cases"]
        self.output_path = []
        for wav_info in test_data:
            # the prompt wav lives in data_dir; only the basename is taken from the metadata
            wav_path = os.path.join(self.data_dir, wav_info["wav_path"].split('/')[-1])
            self.wav_list.append(wav_path)
            # transcript = prompt text followed by the text to synthesize
            wav_name = wav_info["wav_path"].split('/')[-1][:-4]
            self.transcripts[wav_name] = wav_info["text"] + " " + wav_info["target_text"]
            # the output file is named after the test-case uid
            output_file_name = wav_info["uid"] + '.wav'
            self.output_path.append(output_file_name)
def __len__(self):
return len(self.wav_list)
def __getitem__(self, idx):
wav_file = self.wav_list[idx]
transcript = self.transcripts[os.path.basename(wav_file)[:-4]]
# remove punctuation
transcript = ''.join(e for e in transcript if e.isalnum() or e.isspace()).lower()
orig_transcript = transcript
        # element [1] of the tokenizer output is the phone id sequence;
        # prepend the language code expected by the models
        transcript = g2p(transcript, 'en')[1]
        transcript = [LANG2CODE['en']] + transcript
transcript = torch.tensor(transcript, dtype=torch.long)
speech, _ = librosa.load(wav_file, sr=16000)
# resample to 24k
speech = librosa.resample(speech, orig_sr=16000, target_sr=24000)
speech = torch.tensor(speech, dtype=torch.float32)
return {
'speech': speech,
'phone_ids': transcript,
'transcript': orig_transcript,
'output_path': self.output_path[idx],
}
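# A minimal sketch of what one dataset item looks like (assumes the metadata
# json and the ./ref_3s_new_diffprompt audio files are present locally):
#   ds = LibriSpeechTestDataset()
#   item = ds[0]
#   item['speech']       # float32 waveform at 24 kHz, shape [T]
#   item['phone_ids']    # long tensor of phone ids, prefixed with LANG2CODE['en']
#   item['output_path']  # '<uid>.wav', used as the output file name below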
class ValleInference(torch.nn.Module):
    def __init__(self, ar_model=None, nar_model=None, use_vocos=False):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # AR model: predicts the first Encodec codebook autoregressively
        from models.valle_ar import ValleAR
        self.ar_model = ValleAR(
            phone_vocab_size=656,
            target_vocab_size=1024,
            pad_token_id=1680,
            bos_target_id=1681,
            eos_target_id=1682,
            bos_phone_id=1683,
            eos_phone_id=1684,
        )
        self.ar_model.load_state_dict(torch.load('valle_ar.bin'))
        self.ar_model.eval().to(self.device)

        # NAR model: predicts the remaining codebooks non-autoregressively
        from models.valle_nar import ValleNAR
        self.nar_model = ValleNAR(
            phone_vocab_size=656,
            target_vocab_size=1024,
            pad_token_id=1680,
            bos_target_id=1681,
            eos_target_id=1682,
            bos_phone_id=1683,
            eos_phone_id=1684,
        )
        self.nar_model.load_state_dict(torch.load('valle_nar.bin'))
        self.nar_model.eval().to(self.device)

        # Encodec at 6 kbps: 8 codebooks of 1024 entries at 75 frames/s
        from encodec import EncodecModel
        self.codec_encoder = EncodecModel.encodec_model_24khz()
        self.codec_encoder.set_target_bandwidth(6.0)
        self.codec_encoder.to(self.device)

        # optionally decode codes with Vocos instead of the Encodec decoder
        if use_vocos:
            from vocos import Vocos
            self.codec_decoder = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
            self.codec_decoder.to(self.device)
        self.use_vocos = use_vocos
    def decode(self, vq_ids):
        '''vq_ids: [8, B, T] discrete codec codes.
        Returns a waveform tensor of shape [B, 1, T*320] (320 samples per frame at 24 kHz).'''
        if not self.use_vocos:
            return self.codec_encoder.decode([(vq_ids.transpose(0, 1), None)])
        else:
            features = self.codec_decoder.codes_to_features(vq_ids.squeeze(1))
            # bandwidth_id 2 selects the 6 kbps Encodec setting in Vocos
            bandwidth_id = torch.tensor([2], device=vq_ids.device)
            return self.codec_decoder.decode(features, bandwidth_id=bandwidth_id).unsqueeze(0)
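    # Shape sketch (assuming 8 codebooks at 6 kbps, 75 frames/s): vq_ids of
    # shape [8, 1, 750] covers ~10 s of audio and decodes to a waveform of
    # shape [1, 1, 240000] at 24 kHz (750 frames * 320 samples per frame).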
    def forward(self, batch, temperature=1.2):
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(self.device)
        with torch.no_grad():
            # batch contents:
            #   speech:    [B, T] 24 kHz waveform; its first 3 s serve as the acoustic prompt
            #   phone_ids: [B, T_phone] phone ids covering prompt text + target text
            # encode the speech into discrete codec codes
            vq_id = self.codec_encoder.encode(batch['speech'].unsqueeze(1))
            # concatenate encoded chunks -> [num_codebooks, B, T_frames]
            vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(0, 1)
            # grid search over sampling hyperparameters (kept for reference)
            # for top_p in [0.8, 0.85, 0.9, 0.95, 1.0]:
            #     for temperature in [1.1, 1.2, 1.3, 1.4]:
            #         ar_vq_ids = self.ar_model.sample_hf(
            #             batch['phone_ids'],
            #             vq_id[0, :, :225],
            #             top_p=top_p,
            #             top_k=1024,
            #             temperature=temperature,
            #         )
            #         recovered_audio_ar = self.decode(ar_vq_ids.unsqueeze(0))
            #         torchaudio.save(f'recovered_audio_ar_{top_p}_{temperature}.wav', recovered_audio_ar[0].cpu(), 24000)
            #         print(f'recovered_audio_ar_{top_p}_{temperature}.wav')

            # AR stage: sample first-codebook codes conditioned on the phone ids
            # and the first 225 prompt frames (3 s at 75 frames/s)
            ar_vq_ids = self.ar_model.sample_hf(
                batch['phone_ids'],
                vq_id[0, :, :225],
                top_p=0.9,
                top_k=1024,
                temperature=temperature,
            )  # [B, T]
            # decode the AR output alone for quick listening/debugging
            recovered_audio_ar = self.decode(ar_vq_ids.unsqueeze(0))
            torchaudio.save('recovered_audio_ar.wav', recovered_audio_ar[0].cpu(), 24000)

            # NAR stage: predict the remaining codebooks given the AR codes
            nar_vq_ids = self.nar_model.sample_hf(
                phone_ids=batch['phone_ids'],
                prompt_ids=vq_id[:, :, :225],
                # first_stage_ids=vq_id[:, :, 225:],  # alternative: ground-truth first-stage codes
                first_stage_ids=ar_vq_ids,
            )
            # prepend the prompt codes so the decoded audio contains the prompt
            nar_vq_ids = torch.cat([vq_id[..., :225], nar_vq_ids], dim=-1)
            recovered_audio = self.decode(nar_vq_ids)
            torchaudio.save('recovered_audio_nar.wav', recovered_audio[0].cpu(), 24000)
            return recovered_audio
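# Illustrative single-utterance usage (a sketch, not part of the original script):
# given a 24 kHz prompt waveform whose transcript is the prompt text followed by
# the target text, one could run:
#   phone_ids = torch.tensor([LANG2CODE['en']] + g2p(text, 'en')[1]).unsqueeze(0)
#   batch = {'speech': torch.tensor(prompt_wav).unsqueeze(0), 'phone_ids': phone_ids}
#   audio = ValleInference(use_vocos=False)(batch)  # [1, 1, T] waveform at 24 kHz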
class LibriSpeechDevDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir=None, use_vocos=False):
        # default to one LibriSpeech test-clean speaker/chapter directory
        # self.data_dir = '/mnt/petrelfs/hehaorui/jiaqi/LibriSpeech/test-clean/908/31957'
        # self.data_dir = '/mnt/petrelfs/hehaorui/jiaqi/LibriSpeech/test-clean/5683/32866'
        self.data_dir = data_dir if data_dir is not None else '/mnt/petrelfs/hehaorui/jiaqi/LibriSpeech/test-clean/8224/274384'
        self.wav_list = glob.glob(self.data_dir + '/*.flac') + glob.glob(self.data_dir + '/*.wav')
        # LibriSpeech transcript file: one "<utt-id> <text>" entry per line
        self.transcript_file = glob.glob(self.data_dir + '/*.txt')[0]
        self.transcripts = {}
        with open(self.transcript_file, 'r') as f:
            for line in f:
                line = line.strip().split()
                self.transcripts[line[0]] = ' '.join(line[1:])
def __len__(self):
return len(self.wav_list)
def __getitem__(self, idx):
wav_file = self.wav_list[idx]
        transcript = self.transcripts[os.path.splitext(os.path.basename(wav_file))[0]]
orig_transcript = transcript
transcript = g2p(transcript, 'en')[1]
transcript = torch.tensor(transcript, dtype=torch.long)
speech, _ = librosa.load(wav_file, sr=16000)
# resample to 24k
speech = librosa.resample(speech, orig_sr=16000, target_sr=24000)
speech = torch.tensor(speech, dtype=torch.float32)
return {
'speech': speech,
'phone_ids': transcript,
'transcript': orig_transcript,
}
# dataset = LibriSpeechDevDataset()
dataset = LibriSpeechTestDataset()
inference = ValleInference(use_vocos=False)
# batch_size must stay 1: utterances are variable-length and not padded/collated
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
os.makedirs('./inference_retrain', exist_ok=True)
for batch in tqdm.tqdm(dataloader):
# if batch['speech'].shape[1] <= 24000 * 4:
# continue
print(batch['transcript'][0].lower())
recovered_audio = inference(batch)
# regenerate if audio is all silence
# if torch.sum(recovered_audio[0].abs()) <= 0.1:
# print('regenerating audio')
# recovered_audio = inference(batch, temperature=1.2)[..., 24000*3:]
    uid = batch['output_path'][0]
    save_path = os.path.join('./inference_retrain', uid)
    torchaudio.save(save_path, recovered_audio[0].cpu(), 24000)
    print(f'saved to {save_path}')
    # save the ground-truth prompt for comparison
    # torchaudio.save('./gt.wav', batch['speech'].cpu(), 24000)
    # print('saved to ./gt.wav')