|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function |
|
|
|
import argparse |
|
import copy |
|
import logging |
|
import os |
|
import sys |
|
|
|
import torch |
|
import yaml |
|
from torch.utils.data import DataLoader |
|
from textgrid import TextGrid, IntervalTier |
|
|
|
from wenet.dataset.dataset import Dataset |
|
from wenet.utils.checkpoint import load_checkpoint |
|
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols |
|
from wenet.utils.ctc_util import forced_align |
|
from wenet.utils.common import get_subsample |
|
from wenet.utils.init_model import init_model |
|
|
|
|
|
def generator_textgrid(maxtime, lines, output):
    """Write a Praat TextGrid file with a single interval tier.

    Args:
        maxtime: total duration in seconds covered by the tier.
        lines: iterable of "start end word" strings (whitespace separated),
            as produced by get_labformat().
        output: path of the TextGrid file to write.
    """
    # Small offset so adjacent intervals never share an identical boundary,
    # which the textgrid library rejects as overlapping.
    margin = 0.0001

    tg = TextGrid(maxTime=maxtime)
    linetier = IntervalTier(name="line", maxTime=maxtime)

    for entry in lines:
        start, end, word = entry.split()
        linetier.add(minTime=float(start) + margin, maxTime=float(end), mark=word)

    tg.append(linetier)
    print("successfully generated {}".format(output))
    tg.write(output)
|
|
|
|
|
def get_frames_timestamp(alignment):
    """Split a frame-level CTC alignment into per-token frame segments.

    Each segment is the run of frames belonging to one emitted token,
    including the blank (id 0) frames that precede it; trailing blanks after
    the last token are appended to the final segment, so concatenating all
    segments reproduces ``alignment``.

    Args:
        alignment: list of int token ids, one per encoder frame (0 = blank).

    Returns:
        List of lists of frame ids, one list per emitted token.
    """
    timestamp = []
    start = 0
    end = 0
    while end < len(alignment):
        # Skip the run of blank frames leading up to the next token.
        while end < len(alignment) and alignment[end] == 0:
            end += 1
        if end == len(alignment):
            # Only blanks remain: attach them to the last token's segment.
            # Guard against an all-blank alignment, which previously raised
            # IndexError on timestamp[-1].
            if timestamp:
                timestamp[-1] += alignment[start:]
            else:
                timestamp.append(alignment[start:])
            break
        end += 1
        # Extend over consecutive frames that repeat the same token id.
        while end < len(alignment) and alignment[end - 1] == alignment[end]:
            end += 1
        timestamp.append(alignment[start:end])
        start = end
    return timestamp
|
|
|
|
|
def get_labformat(timestamp, subsample, char_dict=None):
    """Convert per-token frame segments into Praat .lab label lines.

    Args:
        timestamp: output of get_frames_timestamp() — one list of frame ids
            per token, blank frames (id 0) included.
        subsample: encoder subsampling factor; each frame spans
            0.01 * subsample seconds (assumes a 10 ms frame shift — confirm
            against the frontend config).
        char_dict: optional id -> token string mapping. Defaults to the
            module-level ``char_dict`` built in __main__ (original behavior),
            but can be passed explicitly for testability.

    Returns:
        List of "begin end token\n" strings, times in seconds.
    """
    if char_dict is None:
        # Fall back to the dict the __main__ section builds from --dict.
        char_dict = globals()["char_dict"]
    begin = 0
    labformat = []
    for idx, t in enumerate(timestamp):
        # NOTE: the original recomputed subsample from the global config on
        # every iteration, silently overriding the parameter; the passed-in
        # value is now honored.
        duration = len(t) * 0.01 * subsample
        if idx < len(timestamp) - 1:
            # Non-final segments end on the token's frames, so the last
            # frame id is the token itself.
            token = t[-1]
        else:
            # The final segment may carry trailing blanks; use the first
            # non-blank frame. Fall back to blank (0) if the segment is all
            # blanks — previously `token` could be unbound here.
            token = next((i for i in t if i != 0), 0)
        print("{:.2f} {:.2f} {}".format(begin, begin + duration, char_dict[token]))
        labformat.append(
            "{:.2f} {:.2f} {}\n".format(begin, begin + duration, char_dict[token])
        )
        begin = begin + duration
    return labformat
|
|
|
|
|
if __name__ == "__main__":
    # Command-line driver: run CTC forced alignment over a dataset and
    # optionally emit Praat .lab / .TextGrid files per utterance.
    parser = argparse.ArgumentParser(description="use ctc to generate alignment")
    parser.add_argument("--config", required=True, help="config file")
    parser.add_argument("--input_file", required=True, help="format data file")
    parser.add_argument(
        "--data_type",
        default="raw",
        choices=["raw", "shard"],
        help="train and cv data type",
    )
    parser.add_argument(
        "--gpu", type=int, default=-1, help="gpu id for this rank, -1 for cpu"
    )
    parser.add_argument("--checkpoint", required=True, help="checkpoint model")
    parser.add_argument("--dict", required=True, help="dict file")
    parser.add_argument(
        "--non_lang_syms", help="non-linguistic symbol file. One symbol per line."
    )
    parser.add_argument("--result_file", required=True, help="alignment result file")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    parser.add_argument(
        "--gen_praat", action="store_true", help="convert alignment to a praat format"
    )
    parser.add_argument(
        "--bpe_model", default=None, type=str, help="bpe model for english part"
    )

    args = parser.parse_args()
    print(args)
    logging.basicConfig(
        level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s"
    )
    # Restrict CUDA to the requested device (a value of -1 selects no GPU,
    # which combined with the use_cuda check below falls back to CPU).
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    # Forced alignment processes one utterance at a time.
    if args.batch_size > 1:
        logging.fatal("alignment mode must be running with batch_size == 1")
        sys.exit(1)

    # Load the model/dataset YAML configuration.
    with open(args.config, "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    # Build the id -> token mapping from the dict file ("token id" per line).
    # get_labformat() reads this module-level dict.
    char_dict = {}
    with open(args.dict, "r") as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            char_dict[int(arr[1])] = arr[0]
    eos = len(char_dict) - 1  # NOTE(review): unused below — kept for parity.

    symbol_table = read_symbol_table(args.dict)

    # Derive an alignment-time dataset config from the training config:
    # disable all filtering, augmentation, shuffling, and sorting so every
    # utterance passes through unmodified and in order.
    ali_conf = copy.deepcopy(configs["dataset_conf"])

    ali_conf["filter_conf"]["max_length"] = 102400
    ali_conf["filter_conf"]["min_length"] = 0
    ali_conf["filter_conf"]["token_max_length"] = 102400
    ali_conf["filter_conf"]["token_min_length"] = 0
    ali_conf["filter_conf"]["max_output_input_ratio"] = 102400
    ali_conf["filter_conf"]["min_output_input_ratio"] = 0
    ali_conf["speed_perturb"] = False
    ali_conf["spec_aug"] = False
    ali_conf["shuffle"] = False
    ali_conf["sort"] = False
    ali_conf["fbank_conf"]["dither"] = 0.0  # deterministic features
    ali_conf["batch_conf"]["batch_type"] = "static"
    ali_conf["batch_conf"]["batch_size"] = args.batch_size
    non_lang_syms = read_non_lang_symbols(args.non_lang_syms)

    ali_dataset = Dataset(
        args.data_type,
        args.input_file,
        symbol_table,
        ali_conf,
        args.bpe_model,
        non_lang_syms,
        partition=False,
    )

    # Dataset already yields batches, so the DataLoader does no batching.
    ali_data_loader = DataLoader(ali_dataset, batch_size=None, num_workers=0)

    # Build the model from config and restore trained weights.
    model = init_model(configs)

    load_checkpoint(model, args.checkpoint)
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = model.to(device)

    model.eval()
    with torch.no_grad(), open(args.result_file, "w", encoding="utf-8") as fout:
        for batch_idx, batch in enumerate(ali_data_loader):
            print("#" * 80)
            key, feat, target, feats_length, target_length = batch
            print(key)

            feat = feat.to(device)
            target = target.to(device)
            feats_length = feats_length.to(device)
            target_length = target_length.to(device)

            # NOTE(review): calls the model's private encoder entry point;
            # relies on wenet's ASR model internals.
            encoder_out, encoder_mask = model._forward_encoder(
                feat, feats_length
            )
            maxlen = encoder_out.size(1)
            ctc_probs = model.ctc.log_softmax(encoder_out)

            # batch_size == 1, so drop the batch dimension before aligning.
            ctc_probs = ctc_probs.squeeze(0)
            target = target.squeeze(0)
            alignment = forced_align(ctc_probs, target)
            print(alignment)
            # One line per utterance: "<key> <frame-level alignment>".
            fout.write("{} {}\n".format(key[0], alignment))

            if args.gen_praat:
                # Convert the frame-level alignment into per-token segments,
                # then into Praat label lines and a TextGrid, written next to
                # the result file.
                timestamp = get_frames_timestamp(alignment)
                print(timestamp)
                subsample = get_subsample(configs)
                labformat = get_labformat(timestamp, subsample)

                lab_path = os.path.join(
                    os.path.dirname(args.result_file), key[0] + ".lab"
                )
                with open(lab_path, "w", encoding="utf-8") as f:
                    f.writelines(labformat)

                textgrid_path = os.path.join(
                    os.path.dirname(args.result_file), key[0] + ".TextGrid"
                )
                generator_textgrid(
                    # Total duration: frames (plus one) times frame shift
                    # scaled by the encoder subsampling factor.
                    maxtime=(len(alignment) + 1) * 0.01 * subsample,
                    lines=labformat,
                    output=textgrid_path,
                )
|
|