|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import json |
|
import random |
|
import re |
|
import tarfile |
|
from subprocess import PIPE, Popen |
|
from urllib.parse import urlparse |
|
|
|
import torch |
|
import torchaudio |
|
import torchaudio.compliance.kaldi as kaldi |
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
# File extensions (without the leading dot) that are decoded as audio when
# found inside a tar shard; any other extension is kept as raw bytes.
AUDIO_FORMAT_SETS = {"flac", "mp3", "m4a", "ogg", "opus", "wav", "wma"}
|
|
|
|
|
def url_opener(data):
    """Give url or local file, return file descriptor.

    Inplace operation: each input dict gets a ``stream`` key (and, for remote
    urls, a ``process`` key holding the spawned ``wget`` process) and is
    yielded again. Sources that cannot be opened are logged and skipped.

    Args:
        data(Iterable[dict]): samples, each with a ``src`` key holding a url
            or a local file path

    Returns:
        Iterable[{src, stream}] (plus ``process`` for remote urls)
    """
    for sample in data:
        assert "src" in sample
        url = sample["src"]
        try:
            pr = urlparse(url)
            if pr.scheme == "":
                stream = open(url, "rb")
            elif pr.scheme == "file":
                # open() does not understand the "file://" prefix; open the
                # path component instead of the raw url.
                stream = open(pr.path, "rb")
            else:
                # argv list, no shell: a crafted url cannot inject commands.
                process = Popen(["wget", "-q", "-O", "-", url], stdout=PIPE)
                sample.update(process=process)
                stream = process.stdout
        except Exception as ex:
            logging.warning("Failed to open {}: {}".format(url, ex))
            continue
        sample.update(stream=stream)
        yield sample
|
|
|
|
|
def tar_file_and_group(data):
    """Expand a stream of open tar files into a stream of tar file contents.

    Consecutive members that share the same prefix (name up to the last dot)
    are grouped into one example keyed by that prefix. ``.txt`` members are
    decoded as UTF-8 text, members whose extension is in ``AUDIO_FORMAT_SETS``
    are loaded with ``torchaudio.load``, and anything else is stored as raw
    bytes under its extension. If any member of a group fails to parse, the
    whole group is dropped (including the last group in the archive).
    Non-regular entries such as directories are skipped.

    After an archive is exhausted, the tar stream and the underlying stream
    are closed, and a ``process`` stored on the sample (see ``url_opener``) is
    waited on via ``communicate()``.

    Args:
        data: Iterable[{src, stream}]

    Returns:
        Iterable[{key, wav, txt, sample_rate}] (plus one key per other
        extension found in the group)
    """
    for sample in data:
        assert "stream" in sample
        stream = tarfile.open(fileobj=sample["stream"], mode="r|*")
        prev_prefix = None
        example = {}
        valid = True
        for tarinfo in stream:
            # Directories/links have no payload (extractfile() returns None)
            # and usually no extension; they would crash the code below.
            if not tarinfo.isfile():
                continue
            name = tarinfo.name
            pos = name.rfind(".")
            assert pos > 0
            prefix, postfix = name[:pos], name[pos + 1 :]
            if prev_prefix is not None and prefix != prev_prefix:
                example["key"] = prev_prefix
                if valid:
                    yield example
                example = {}
                valid = True
            with stream.extractfile(tarinfo) as file_obj:
                try:
                    if postfix == "txt":
                        example["txt"] = file_obj.read().decode("utf8").strip()
                    elif postfix in AUDIO_FORMAT_SETS:
                        waveform, sample_rate = torchaudio.load(file_obj)
                        example["wav"] = waveform
                        example["sample_rate"] = sample_rate
                    else:
                        example[postfix] = file_obj.read()
                except Exception as ex:
                    valid = False
                    logging.warning("error to parse {}: {}".format(name, ex))
            prev_prefix = prefix
        # The trailing group must honour the same validity check as the rest.
        if prev_prefix is not None and valid:
            example["key"] = prev_prefix
            yield example
        stream.close()
        if "process" in sample:
            sample["process"].communicate()
        sample["stream"].close()
|
|
|
|
|
def parse_raw(data):
    """Parse key/wav/txt from json line.

    Each ``src`` is a JSON object with ``key``, ``wav`` (audio path) and
    ``txt``. If it also has ``start``/``end`` (in seconds, since they are
    multiplied by the sample rate), only that segment of the audio is loaded.
    Audio that cannot be read is logged and skipped; malformed JSON or
    missing required keys still raise.

    Args:
        data: Iterable[{src}], where src is a json line with key/wav/txt

    Returns:
        Iterable[{key, wav, txt, sample_rate}]
    """
    for sample in data:
        assert "src" in sample
        json_line = sample["src"]
        obj = json.loads(json_line)
        assert "key" in obj
        assert "wav" in obj
        assert "txt" in obj
        key = obj["key"]
        wav_file = obj["wav"]
        txt = obj["txt"]
        try:
            if "start" in obj:
                assert "end" in obj
                # Top-level torchaudio API; the torchaudio.backend.* submodules
                # are deprecated.
                sample_rate = torchaudio.info(wav_file).sample_rate
                start_frame = int(obj["start"] * sample_rate)
                end_frame = int(obj["end"] * sample_rate)
                waveform, _ = torchaudio.load(
                    wav_file,
                    frame_offset=start_frame,
                    num_frames=end_frame - start_frame,
                )
            else:
                waveform, sample_rate = torchaudio.load(wav_file)
        except Exception as ex:
            logging.warning("Failed to read {}: {}".format(wav_file, ex))
            continue
        # Yield outside the try so only audio-reading errors are swallowed.
        yield dict(key=key, txt=txt, wav=waveform, sample_rate=sample_rate)
|
|
|
|
|
def filter(
    data,
    max_length=10240,
    min_length=10,
    token_max_length=200,
    token_min_length=1,
    min_output_input_ratio=0.0005,
    max_output_input_ratio=1,
):
    """Filter samples by audio duration, label length and their ratio.

    Inplace operation: samples are either yielded unchanged or dropped.

    Duration is measured in 10ms frames as
    ``wav.size(1) / sample_rate * 100`` (dim 1 presumably being the sample
    axis). The ratio checks are skipped when that duration is exactly 0.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        max_length: drop utterances longer than this many 10ms frames
        min_length: drop utterances shorter than this many 10ms frames
        token_max_length: drop utterances with more label tokens than this,
            especially useful with char units for English modeling
        token_min_length: drop utterances with fewer label tokens than this
        min_output_input_ratio: minimal ratio of
            token_length / feats_length(10ms)
        max_output_input_ratio: maximum ratio of
            token_length / feats_length(10ms)

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        assert "label" in sample

        frames = sample["wav"].size(1) / sample["sample_rate"] * 100
        if frames < min_length or frames > max_length:
            continue

        n_tokens = len(sample["label"])
        if n_tokens < token_min_length or n_tokens > token_max_length:
            continue

        if frames != 0:
            ratio = n_tokens / frames
            if ratio < min_output_input_ratio or ratio > max_output_input_ratio:
                continue

        yield sample
|
|
|
|
|
def resample(data, resample_rate=16000):
    """Resample data.

    Inplace operation: for samples whose ``sample_rate`` differs from
    ``resample_rate``, ``wav`` is replaced by the resampled waveform and
    ``sample_rate`` is set to ``resample_rate``. Other samples pass through
    untouched.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        resample_rate: target resample rate

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    # torchaudio's Resample transform precomputes its kernel, so build one
    # per distinct source rate instead of one per sample. (The former
    # per-sample debug prints are removed; they spammed stdout.)
    resamplers = {}
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        sample_rate = sample["sample_rate"]
        if sample_rate != resample_rate:
            resampler = resamplers.get(sample_rate)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=resample_rate
                )
                resamplers[sample_rate] = resampler
            sample["sample_rate"] = resample_rate
            sample["wav"] = resampler(sample["wav"])
        yield sample
|
|
|
|
|
def speed_perturb(data, speeds=None):
    """Randomly change the playback speed of each waveform.

    Inplace operation: for every sample one factor is drawn with
    ``random.choice``. When that factor is not 1.0, ``wav`` is replaced by the
    output of sox's ``speed`` effect followed by ``rate`` back to the original
    sample rate, so ``sample_rate`` stays unchanged.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        speeds(List[float]): candidate speed factors; defaults to
            [0.9, 1.0, 1.1]

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    candidates = [0.9, 1.0, 1.1] if speeds is None else speeds
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        rate = sample["sample_rate"]
        factor = random.choice(candidates)
        if factor == 1.0:
            yield sample
            continue
        effects = [["speed", str(factor)], ["rate", str(rate)]]
        perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(
            sample["wav"], rate, effects
        )
        sample["wav"] = perturbed
        yield sample
|
|
|
|
|
def compute_fbank(data, num_mel_bins=23, frame_length=25, frame_shift=10, dither=0.0):
    """Extract Kaldi-compatible log-mel filterbank features.

    The waveform is multiplied by 2**15 before ``kaldi.fbank`` (presumably to
    match Kaldi's int16-scaled sample convention). The output sample keeps
    only ``key``, ``label`` and the new ``feat``.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        num_mel_bins: number of mel bins
        frame_length: frame length (passed to kaldi.fbank)
        frame_shift: frame shift (passed to kaldi.fbank)
        dither: dithering constant (0.0 disables dithering)

    Returns:
        Iterable[{key, feat, label}]
    """
    fbank_opts = dict(
        num_mel_bins=num_mel_bins,
        frame_length=frame_length,
        frame_shift=frame_shift,
        dither=dither,
        energy_floor=0.0,
    )
    for sample in data:
        for field in ("sample_rate", "wav", "key", "label"):
            assert field in sample
        scaled = sample["wav"] * (1 << 15)
        feat = kaldi.fbank(
            scaled, sample_frequency=sample["sample_rate"], **fbank_opts
        )
        yield dict(key=sample["key"], label=sample["label"], feat=feat)
|
|
|
|
|
def compute_mfcc(
    data,
    num_mel_bins=23,
    frame_length=25,
    frame_shift=10,
    dither=0.0,
    num_ceps=40,
    high_freq=0.0,
    low_freq=20.0,
):
    """Extract Kaldi-compatible MFCC features.

    The waveform is multiplied by 2**15 before ``kaldi.mfcc`` (presumably to
    match Kaldi's int16-scaled sample convention). The output sample keeps
    only ``key``, ``label`` and the new ``feat``.

    NOTE(review): torchaudio's kaldi.mfcc is believed to require
    num_ceps <= num_mel_bins, which the defaults here (40 vs 23) violate.
    Confirm this and pass compatible values explicitly.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        num_mel_bins: number of mel bins
        frame_length: frame length (passed to kaldi.mfcc)
        frame_shift: frame shift (passed to kaldi.mfcc)
        dither: dithering constant (0.0 disables dithering)
        num_ceps: number of cepstral coefficients
        high_freq: high cutoff frequency (passed to kaldi.mfcc)
        low_freq: low cutoff frequency (passed to kaldi.mfcc)

    Returns:
        Iterable[{key, feat, label}]
    """
    mfcc_opts = dict(
        num_mel_bins=num_mel_bins,
        frame_length=frame_length,
        frame_shift=frame_shift,
        dither=dither,
        num_ceps=num_ceps,
        high_freq=high_freq,
        low_freq=low_freq,
    )
    for sample in data:
        for field in ("sample_rate", "wav", "key", "label"):
            assert field in sample
        scaled = sample["wav"] * (1 << 15)
        feat = kaldi.mfcc(
            scaled, sample_frequency=sample["sample_rate"], **mfcc_opts
        )
        yield dict(key=sample["key"], label=sample["label"], feat=feat)
|
|
|
|
|
def __tokenize_by_bpe_model(sp, txt):
    """Tokenize mixed CJK / non-CJK text.

    The text is upper-cased and split around every character in the CJK
    Unified Ideographs range U+4E00..U+9FFF. Each such character becomes its
    own token. Every other non-blank run is encoded into pieces with
    ``sp.encode_as_pieces``.

    Args:
        sp: object providing ``encode_as_pieces(str) -> list`` (a
            SentencePiece processor in ``tokenize``)
        txt: text to tokenize

    Returns:
        list of tokens
    """
    cjk_char = re.compile(r"([\u4e00-\u9fff])")
    pieces = []
    for chunk in cjk_char.split(txt.upper()):
        if not chunk.strip():
            continue
        if cjk_char.fullmatch(chunk):
            pieces.append(chunk)
        else:
            pieces.extend(sp.encode_as_pieces(chunk))
    return pieces
|
|
|
|
|
def tokenize(
    data, symbol_table, bpe_model=None, non_lang_syms=None, split_with_space=False
):
    """Decode text to chars or BPE tokens and map them to label ids.

    Inplace operation: adds ``tokens`` and ``label`` to each sample.

    When ``non_lang_syms`` is given, the stripped text is upper-cased and split
    around bracketed symbols (``[..]``, ``<..>``, ``{..}``); segments found in
    ``non_lang_syms`` are kept as single tokens. Other segments are tokenized
    with the SentencePiece model (if ``bpe_model`` is set), split on spaces
    (if ``split_with_space``), or split into characters with spaces mapped to
    "▁". Tokens missing from ``symbol_table`` map to ``<unk>`` when present
    and are otherwise dropped from ``label``.

    Args:
        data: Iterable[{key, wav, txt, sample_rate}]
        symbol_table: mapping token -> id
        bpe_model: optional path to a SentencePiece model
        non_lang_syms: optional collection of non-language symbols
        split_with_space: split on spaces instead of into characters

    Returns:
        Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    if non_lang_syms is None:
        non_lang_syms = {}
        splitter = None
    else:
        splitter = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")

    sp = None
    if bpe_model is not None:
        import sentencepiece as spm

        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)

    for sample in data:
        assert "txt" in sample
        text = sample["txt"].strip()
        if splitter is None:
            segments = [text]
        else:
            segments = [seg for seg in splitter.split(text.upper()) if seg.strip()]

        tokens = []
        for seg in segments:
            if seg in non_lang_syms:
                tokens.append(seg)
            elif bpe_model is not None:
                tokens.extend(__tokenize_by_bpe_model(sp, seg))
            else:
                units = seg.split(" ") if split_with_space else seg
                tokens.extend("▁" if unit == " " else unit for unit in units)

        label = []
        for tok in tokens:
            if tok in symbol_table:
                label.append(symbol_table[tok])
            elif "<unk>" in symbol_table:
                label.append(symbol_table["<unk>"])

        sample["tokens"] = tokens
        sample["label"] = label
        yield sample
|
|
|
|
|
def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
    """Do spec augmentation (random time and frequency masking).

    Inplace operation: ``sample["feat"]`` is replaced by a masked copy. The
    original tensor is not modified. Dim 0 of ``feat`` is treated as time and
    dim 1 as frequency. An axis of length 0 is left unmasked, because drawing
    a start position on it would raise.

    Args:
        data: Iterable[{key, feat, label}]
        num_t_mask: number of time masks to apply
        num_f_mask: number of freq masks to apply
        max_t: max width of a time mask
        max_f: max width of a freq mask
        max_w: max width of time warp. NOTE: accepted for interface
            compatibility but currently unused (no time warping is done)

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "feat" in sample
        x = sample["feat"]
        assert isinstance(x, torch.Tensor)
        y = x.clone().detach()
        max_frames = y.size(0)
        max_freq = y.size(1)

        # random.randint(0, -1) raises ValueError, so skip empty axes.
        if max_frames > 0:
            for _ in range(num_t_mask):
                start = random.randint(0, max_frames - 1)
                length = random.randint(1, max_t)
                end = min(max_frames, start + length)
                y[start:end, :] = 0

        if max_freq > 0:
            for _ in range(num_f_mask):
                start = random.randint(0, max_freq - 1)
                length = random.randint(1, max_f)
                end = min(max_freq, start + length)
                y[:, start:end] = 0

        sample["feat"] = y
        yield sample
|
|
|
|
|
def spec_sub(data, max_t=20, num_t_sub=3):
    """Do spec substitute.

    Inplace operation: ``sample["feat"]`` is replaced by a copy in which up
    to ``num_t_sub`` random time spans are overwritten with equally long
    spans taken from earlier (or the same) positions of the ORIGINAL
    features. Features with zero frames are passed through as a copy,
    because drawing a start position on an empty axis would raise.
    ref: U2++, section 3.2.3 [https://arxiv.org/abs/2106.05642]

    Args:
        data: Iterable[{key, feat, label}]
        max_t: max width of time substitute
        num_t_sub: number of time substitute to apply

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "feat" in sample
        x = sample["feat"]
        assert isinstance(x, torch.Tensor)
        y = x.clone().detach()
        max_frames = y.size(0)
        # random.randint(0, -1) raises ValueError on an empty time axis.
        if max_frames > 0:
            for _ in range(num_t_sub):
                start = random.randint(0, max_frames - 1)
                length = random.randint(1, max_t)
                end = min(max_frames, start + length)
                # Shift back by pos <= start so the source slice stays in range.
                pos = random.randint(0, start)
                y[start:end, :] = x[start - pos : end - pos, :]
        sample["feat"] = y
        yield sample
|
|
|
|
|
def spec_trim(data, max_t=20):
    """Randomly drop trailing frames (TrimTail).

    Inplace operation: a cut length is drawn from [1, max_t]. When it is
    smaller than half the number of frames, ``sample["feat"]`` is replaced by
    a copy without that many trailing frames. Otherwise the sample is left
    untouched.
    ref: TrimTail [https://arxiv.org/abs/2211.00522]

    Args:
        data: Iterable[{key, feat, label}]
        max_t: max width of length trimming

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "feat" in sample
        feat = sample["feat"]
        assert isinstance(feat, torch.Tensor)
        total = feat.size(0)
        cut = random.randint(1, max_t)
        # Only trim when less than half of the utterance would be removed.
        if cut < total / 2:
            sample["feat"] = feat.clone().detach()[: total - cut]
        yield sample
|
|
|
|
|
def shuffle(data, shuffle_size=10000):
    """Locally shuffle the data.

    Samples are collected into a buffer of ``shuffle_size`` items. Each full
    buffer is shuffled with ``random.shuffle`` and emitted, and any remainder
    is shuffled and emitted at the end. Items therefore only move within
    their own window.

    Args:
        data: Iterable[{key, feat, label}]
        shuffle_size: buffer size for shuffle

    Returns:
        Iterable[{key, feat, label}]
    """
    window = []
    for item in data:
        window.append(item)
        if len(window) < shuffle_size:
            continue
        random.shuffle(window)
        yield from window
        window = []
    random.shuffle(window)
    yield from window
|
|
|
|
|
def sort(data, sort_size=500):
    """Sort the data by feature length within windows of ``sort_size``.

    Sort is used after shuffle and before batch, so that utterances of
    similar length end up in the same batch. ``sort_size`` should be smaller
    than ``shuffle_size``. Each window is sorted stably by ascending
    ``feat.size(0)``, and the trailing partial window is sorted too.

    Args:
        data: Iterable[{key, feat, label}]
        sort_size: buffer size for sort

    Returns:
        Iterable[{key, feat, label}]
    """

    def _num_frames(item):
        return item["feat"].size(0)

    window = []
    for item in data:
        window.append(item)
        if len(window) < sort_size:
            continue
        yield from sorted(window, key=_num_frames)
        window = []
    yield from sorted(window, key=_num_frames)
|
|
|
|
|
def static_batch(data, batch_size=16):
    """Group samples into fixed-size batches of ``batch_size``.

    The last batch may be smaller. It is emitted only if it is non-empty.

    Args:
        data: Iterable[{key, feat, label}]
        batch_size: batch size

    Returns:
        Iterable[List[{key, feat, label}]]
    """
    pending = []
    for item in data:
        pending.append(item)
        if len(pending) < batch_size:
            continue
        full, pending = pending, []
        yield full
    if pending:
        yield pending
|
|
|
|
|
def dynamic_batch(data, max_frames_in_batch=12000):
    """Dynamic batch the data until the padded frame count would exceed
    ``max_frames_in_batch``.

    The padded size of a batch is ``longest_frames * batch_len``. A sample
    that alone exceeds the budget still forms its own single-item batch;
    empty batches are never emitted.

    Args:
        data: Iterable[{key, feat, label}]
        max_frames_in_batch: max_frames in one batch

    Returns:
        Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert "feat" in sample
        assert isinstance(sample["feat"], torch.Tensor)
        new_sample_frames = sample["feat"].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        # Only flush a non-empty buffer: flushing an empty one would emit []
        # when the very first sample alone exceeds the budget.
        if buf and frames_after_padding > max_frames_in_batch:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf
|
|
|
|
|
def batch(data, batch_type="static", batch_size=16, max_frames_in_batch=12000):
    """Wrapper for static/dynamic batch.

    Args:
        data: Iterable[{key, feat, label}]
        batch_type: "static" (fixed ``batch_size``) or "dynamic" (bounded by
            ``max_frames_in_batch`` padded frames)
        batch_size: batch size used by static batching
        max_frames_in_batch: frame budget used by dynamic batching

    Returns:
        Iterable[List[{key, feat, label}]]

    Raises:
        ValueError: if ``batch_type`` is not "static" or "dynamic". Previously
            this only logged and returned None, so the error surfaced later
            as an obscure failure.
    """
    if batch_type == "static":
        return static_batch(data, batch_size)
    if batch_type == "dynamic":
        return dynamic_batch(data, max_frames_in_batch)
    logging.fatal("Unsupported batch type {}".format(batch_type))
    raise ValueError("Unsupported batch type {}".format(batch_type))
|
|
|
|
|
def padding(data):
    """Pad each batch into training tensors.

    Samples in a batch are reordered by descending ``feat.size(0)``
    (``torch.argsort`` with ``descending=True``). Features are zero-padded
    and labels are padded with -1, both batch-first.

    Args:
        data: Iterable[List[{key, feat, label}]]

    Returns:
        Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)],
        where both length tensors are int32 and labels are int64.
    """
    for batch_items in data:
        assert isinstance(batch_items, list)
        lengths = torch.tensor(
            [item["feat"].size(0) for item in batch_items], dtype=torch.int32
        )
        order = torch.argsort(lengths, descending=True)
        ranked = [batch_items[i] for i in order]

        feats_lengths = lengths[order]
        keys = [item["key"] for item in ranked]
        feats = [item["feat"] for item in ranked]
        labels = [torch.tensor(item["label"], dtype=torch.int64) for item in ranked]
        label_lengths = torch.tensor(
            [lab.size(0) for lab in labels], dtype=torch.int32
        )

        yield (
            keys,
            pad_sequence(feats, batch_first=True, padding_value=0),
            pad_sequence(labels, batch_first=True, padding_value=-1),
            feats_lengths,
            label_lengths,
        )
|
|