import torch
import torchaudio
import librosa
import os
import json
import glob

from utils.g2p import PhonemeBpeTokenizer

text_tokenizer = PhonemeBpeTokenizer()

# Language tags wrapped around the input text before phonemization.
lang2token = {
    'zh': "[ZH]",
    'ja': "[JA]",
    'en': "[EN]",
    'fr': "[FR]",
    'kr': "[KR]",
    'de': "[DE]",
}

# Language-id tokens prepended to the phone id sequence
# (they fit within the 656-entry phone vocabulary used below).
LANG2CODE = {
    'en': 655,
    'zh': 654,
}


def g2p(text, language):
    # Remove newlines and surrounding whitespace, wrap the text in its
    # language tag, then run the BPE phoneme tokenizer.
    text = text.replace("\n", "").strip()
    lang_token = lang2token[language]
    text = lang_token + text + lang_token
    return text_tokenizer.tokenize(text=f"{text}".strip(), language=language)
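
# Hedged usage sketch: assuming PhonemeBpeTokenizer.tokenize returns a
# (phonemes, token_ids) pair, a call such as
#   _, phone_ids = g2p("hello world", "en")
# yields the BPE phone ids for "[EN]hello world[EN]". The exact return
# structure depends on the utils.g2p implementation in this repo; the datasets
# below take index [1] of the result.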


class LibriSpeechTestDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir=None, use_vocos=False):
        self.data_dir = './ref_3s_new_diffprompt'
        self.wav_list = []
        self.transcripts = {}

        with open('./ref_dur_3_test_merge_1pspk_with_punc_refmeta_normwav_fix_refuid_new_diffprompt.json', 'r') as f:
            data = json.load(f)

        test_data = data["test_cases"]

        self.output_path = []
        for wav_info in test_data:
            # Reference (prompt) waveform, looked up by file name in data_dir.
            wav_path = os.path.join(self.data_dir, wav_info["wav_path"].split('/')[-1])
            self.wav_list.append(wav_path)

            # Transcripts are keyed by the wav file stem: prompt text followed
            # by the target text to synthesize.
            wav_stem = wav_info["wav_path"].split('/')[-1][:-4]
            self.transcripts[wav_stem] = wav_info["text"] + " " + wav_info["target_text"]

            output_file_name = wav_info["uid"] + '.wav'
            self.output_path.append(output_file_name)

    def __len__(self):
        return len(self.wav_list)

    def __getitem__(self, idx):
        wav_file = self.wav_list[idx]
        transcript = self.transcripts[os.path.basename(wav_file)[:-4]]

        # Keep only alphanumerics and spaces, lowercase, then phonemize and
        # prepend the English language-id token.
        transcript = ''.join(e for e in transcript if e.isalnum() or e.isspace()).lower()
        orig_transcript = transcript
        transcript = g2p(transcript, 'en')[1]
        transcript = [LANG2CODE['en']] + transcript
        transcript = torch.tensor(transcript, dtype=torch.long)

        # Load at 16 kHz, then upsample to EnCodec's 24 kHz sample rate.
        speech, _ = librosa.load(wav_file, sr=16000)
        speech = librosa.resample(speech, orig_sr=16000, target_sr=24000)
        speech = torch.tensor(speech, dtype=torch.float32)

        return {
            'speech': speech,
            'phone_ids': transcript,
            'transcript': orig_transcript,
            'output_path': self.output_path[idx],
        }
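
# Hedged note: 'speech' and 'phone_ids' are variable-length tensors, so the
# default DataLoader collate only works with batch_size=1 (as used below);
# larger batches would need a custom collate_fn that pads, e.g. with
# torch.nn.utils.rnn.pad_sequence.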


class ValleInference(torch.nn.Module):
    def __init__(self, ar_model=None, nar_model=None, use_vocos=False):
        super().__init__()

        self.device = "cuda"

        # Autoregressive model: predicts the first EnCodec codebook.
        from models.valle_ar import ValleAR
        self.ar_model = ValleAR(
            phone_vocab_size=656,
            target_vocab_size=1024,
            pad_token_id=1680,
            bos_target_id=1681,
            eos_target_id=1682,
            bos_phone_id=1683,
            eos_phone_id=1684,
        )
        self.ar_model.load_state_dict(torch.load('valle_ar.bin'))
        self.ar_model.eval().to(self.device)

        # Non-autoregressive model: predicts the remaining codebooks.
        from models.valle_nar import ValleNAR
        self.nar_model = ValleNAR(
            phone_vocab_size=656,
            target_vocab_size=1024,
            pad_token_id=1680,
            bos_target_id=1681,
            eos_target_id=1682,
            bos_phone_id=1683,
            eos_phone_id=1684,
        )
        self.nar_model.load_state_dict(torch.load('valle_nar.bin'))
        self.nar_model.eval().to(self.device)

        # EnCodec codec at 6 kbps (8 codebooks, 24 kHz).
        from encodec import EncodecModel
        self.codec_encoder = EncodecModel.encodec_model_24khz()
        self.codec_encoder.set_target_bandwidth(6.0)
        self.codec_encoder.to(self.device)

        # Optionally decode codes with Vocos instead of the EnCodec decoder.
        if use_vocos:
            from vocos import Vocos
            self.codec_decoder = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
            self.codec_decoder.to(self.device)

        self.use_vocos = use_vocos

    def decode(self, vq_ids):
        '''vq_ids shape: [8, B, T] (codebooks, batch, frames);
        returns a [B, 1, T*320] waveform at 24 kHz.'''
        if not self.use_vocos:
            # EnCodec's decoder takes a list of (codes, scale) frames with
            # codes shaped [B, n_q, T], hence the transpose.
            return self.codec_encoder.decode([(vq_ids.transpose(0, 1), None)])
        else:
            features = self.codec_decoder.codes_to_features(vq_ids.squeeze(1))
            bandwidth_id = torch.tensor([2], device=vq_ids.device)
            return self.codec_decoder.decode(features, bandwidth_id=bandwidth_id).unsqueeze(0)
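
    # Hedged note: for charactr/vocos-encodec-24khz, bandwidth_id is assumed to
    # index the supported EnCodec bandwidths (roughly [1.5, 3.0, 6.0, 12.0] kbps),
    # so bandwidth_id=2 would match the 6 kbps / 8-codebook setting chosen via
    # set_target_bandwidth(6.0) above; verify against the installed vocos version.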

    def forward(self, batch, temperature=1.1):  # note: 'temperature' is unused; sampling below uses 1.2
        # batch is a dict with:
        #   speech:      [B, T] 24 kHz prompt waveform
        #   phone_ids:   [B, T_phone] language-prefixed phone ids
        #   transcript:  cleaned reference + target text (unused here)
        #   output_path: output file name (unused here)
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(self.device)

        with torch.no_grad():
            # Encode the prompt waveform into EnCodec codes shaped [n_q, B, T].
            vq_id = self.codec_encoder.encode(batch['speech'].unsqueeze(1))
            vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(0, 1)

            # Stage 1: autoregressively sample the first-codebook codes,
            # conditioned on the phones and a 225-frame (3 s) acoustic prompt.
            ar_vq_ids = self.ar_model.sample_hf(
                batch['phone_ids'],
                vq_id[0, :, :225],
                top_p=0.9,
                top_k=1024,
                temperature=1.2,
            )

            # Debug output: decode the first-codebook codes on their own.
            recovered_audio_ar = self.decode(ar_vq_ids.unsqueeze(0))
            torchaudio.save('recovered_audio_ar.wav', recovered_audio_ar[0].cpu(), 24000)

            # Stage 2: non-autoregressively predict the remaining codebooks.
            nar_vq_ids = self.nar_model.sample_hf(
                phone_ids=batch['phone_ids'],
                prompt_ids=vq_id[:, :, :225],
                first_stage_ids=ar_vq_ids,
            )

            # Prepend the prompt codes so the decoded audio contains the prompt
            # followed by the generated continuation.
            nar_vq_ids = torch.cat([vq_id[..., :225], nar_vq_ids], dim=-1)

            recovered_audio = self.decode(nar_vq_ids)
            torchaudio.save('recovered_audio_nar.wav', recovered_audio[0].cpu(), 24000)

            return recovered_audio
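
# Hedged usage sketch (not executed here): given one batch from the dataloader
# below,
#   audio = ValleInference(use_vocos=False)(batch)
# is expected to return a [1, 1, num_samples] waveform at 24 kHz containing the
# 3 s prompt followed by the generated continuation.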


class LibriSpeechDevDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir=None, use_vocos=False):
        # A single LibriSpeech test-clean chapter directory.
        self.data_dir = '/mnt/petrelfs/hehaorui/jiaqi/LibriSpeech/test-clean/8224/274384'

        self.wav_list = glob.glob(self.data_dir + '/*.flac') + glob.glob(self.data_dir + '/*.wav')

        # The chapter's *.trans.txt file lists "<utterance-id> <transcript>" per line.
        self.transcript_file = glob.glob(self.data_dir + '/*.txt')[0]
        self.transcripts = {}
        with open(self.transcript_file, 'r') as f:
            for line in f:
                line = line.strip().split()
                self.transcripts[line[0]] = ' '.join(line[1:])

    def __len__(self):
        return len(self.wav_list)

    def __getitem__(self, idx):
        wav_file = self.wav_list[idx]
        # Drop the extension to recover the utterance id used in the transcript file.
        transcript = self.transcripts[os.path.splitext(os.path.basename(wav_file))[0]]
        orig_transcript = transcript
        transcript = g2p(transcript, 'en')[1]
        transcript = torch.tensor(transcript, dtype=torch.long)

        speech, _ = librosa.load(wav_file, sr=16000)
        speech = librosa.resample(speech, orig_sr=16000, target_sr=24000)
        speech = torch.tensor(speech, dtype=torch.float32)

        return {
            'speech': speech,
            'phone_ids': transcript,
            'transcript': orig_transcript,
        }
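
# Hedged note: LibriSpeechDevDataset is not used below; it could be swapped in
# via
#   dataset = LibriSpeechDevDataset()
# but its items carry no 'output_path', so the save logic in the loop below
# would need its own file-naming scheme.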


# Run inference over the test prompts and write one wav per test case.
dataset = LibriSpeechTestDataset()
inference = ValleInference(use_vocos=False)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

import tqdm

os.makedirs('./inference_retrain', exist_ok=True)

for batch in tqdm.tqdm(dataloader):
    print(batch['transcript'][0].lower())
    recovered_audio = inference(batch)

    uid = batch['output_path'][0]
    save_path = os.path.join('./inference_retrain', uid)

    torchaudio.save(save_path, recovered_audio[0].cpu(), 24000)
    print(f'saved to {save_path}')