|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function |
|
|
|
import argparse |
|
import copy |
|
import logging |
|
import os |
|
import sys |
|
|
|
import torch |
|
import yaml |
|
from torch.utils.data import DataLoader |
|
|
|
from wenet.dataset.dataset import Dataset |
|
from wenet.utils.common import IGNORE_ID |
|
from wenet.utils.file_utils import read_symbol_table |
|
from wenet.utils.config import override_config |
|
|
|
import onnxruntime as rt |
|
import multiprocessing |
|
import numpy as np |
|
|
|
# The SWIG-wrapped CTC decoders are a hard requirement for every decoding
# mode in this script, so fail fast with install instructions if missing.
try:
    from swig_decoders import (
        map_batch,
        ctc_beam_search_decoder_batch,
        TrieVector,
        PathTrie,
    )
except ImportError:
    # Fixed typo ("refering") and route the error to stderr, where
    # diagnostics belong, before exiting non-zero.
    print(
        "Please install ctc decoders first by referring to\n"
        + "https://github.com/Slyne/ctc_decoder.git",
        file=sys.stderr,
    )
    sys.exit(1)
|
|
|
|
|
def get_args():
    """Parse and return the command-line arguments for ONNX recognition.

    Returns:
        argparse.Namespace with all recognition options. The parsed
        namespace is also echoed to stdout for run logging.
    """
    parser = argparse.ArgumentParser(description="recognize with your model")
    parser.add_argument("--config", required=True, help="config file")
    parser.add_argument("--test_data", required=True, help="test data file")
    parser.add_argument(
        "--data_type",
        default="raw",
        choices=["raw", "shard"],
        help="train and cv data type",
    )
    parser.add_argument(
        "--gpu", type=int, default=-1, help="gpu id for this rank, -1 for cpu"
    )
    parser.add_argument("--dict", required=True, help="dict file")
    parser.add_argument("--encoder_onnx", required=True, help="encoder onnx file")
    parser.add_argument("--decoder_onnx", required=True, help="decoder onnx file")
    parser.add_argument("--result_file", required=True, help="asr result file")
    # Help text previously said "asr result file" — a copy-paste error.
    parser.add_argument(
        "--batch_size", type=int, default=32, help="batch size for decoding"
    )
    parser.add_argument(
        "--mode",
        choices=["ctc_greedy_search", "ctc_prefix_beam_search", "attention_rescoring"],
        default="attention_rescoring",
        help="decoding mode",
    )
    parser.add_argument(
        "--bpe_model", default=None, type=str, help="bpe model for english part"
    )
    parser.add_argument(
        "--override_config", action="append", default=[], help="override yaml config"
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="whether to export fp16 model, default false",
    )
    args = parser.parse_args()
    print(args)
    return args
|
|
|
|
|
def main():
    """Run batch ASR decoding with exported ONNX encoder/decoder models.

    Reads the dataset described by --test_data/--config, runs the encoder
    ONNX model on each batch, decodes with one of three CTC-based modes
    (greedy, prefix beam search, or attention rescoring via the decoder
    ONNX model), and writes "key transcript" lines to --result_file.
    """
    args = get_args()
    logging.basicConfig(
        level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s"
    )
    # Pin this process to the requested GPU (-1 hides all GPUs -> CPU).
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    with open(args.config, "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)

    # reverse_weight > 0 means the decoder was exported with a right-to-left
    # branch and expects the reversed hypotheses as an extra input below.
    reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
    symbol_table = read_symbol_table(args.dict)
    # Build an inference-time dataset config: disable all filtering,
    # augmentation, shuffling and dithering so every utterance is decoded
    # deterministically, and use a fixed (static) batch size.
    test_conf = copy.deepcopy(configs["dataset_conf"])
    test_conf["filter_conf"]["max_length"] = 102400
    test_conf["filter_conf"]["min_length"] = 0
    test_conf["filter_conf"]["token_max_length"] = 102400
    test_conf["filter_conf"]["token_min_length"] = 0
    test_conf["filter_conf"]["max_output_input_ratio"] = 102400
    test_conf["filter_conf"]["min_output_input_ratio"] = 0
    test_conf["speed_perturb"] = False
    test_conf["spec_aug"] = False
    test_conf["spec_sub"] = False
    test_conf["spec_trim"] = False
    test_conf["shuffle"] = False
    test_conf["sort"] = False
    test_conf["fbank_conf"]["dither"] = 0.0
    test_conf["batch_conf"]["batch_type"] = "static"
    test_conf["batch_conf"]["batch_size"] = args.batch_size

    test_dataset = Dataset(
        args.data_type,
        args.test_data,
        symbol_table,
        test_conf,
        args.bpe_model,
        partition=False,
    )

    # batch_size=None: the Dataset already yields whole batches.
    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

    # Prefer CUDA when a GPU id was given and torch can see one; onnxruntime
    # falls back to the CPU provider otherwise.
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        EP_list = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        EP_list = ["CPUExecutionProvider"]

    encoder_ort_session = rt.InferenceSession(args.encoder_onnx, providers=EP_list)
    decoder_ort_session = None
    # The decoder model is only needed for attention rescoring.
    if args.mode == "attention_rescoring":
        decoder_ort_session = rt.InferenceSession(args.decoder_onnx, providers=EP_list)

    # Load the vocabulary: each dict line is "<symbol> <id>".
    vocabulary = []
    char_dict = {}
    with open(args.dict, "r") as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            char_dict[int(arr[1])] = arr[0]
            vocabulary.append(arr[0])
    # Highest id doubles as both <sos> and <eos> — presumably the WeNet
    # dict convention; verify against the exported model's vocab.
    eos = sos = len(char_dict) - 1
    with torch.no_grad(), open(args.result_file, "w") as fout:
        for _, batch in enumerate(test_data_loader):
            keys, feats, _, feats_lengths, _ = batch
            feats, feats_lengths = feats.numpy(), feats_lengths.numpy()
            if args.fp16:
                feats = feats.astype(np.float16)
            ort_inputs = {
                encoder_ort_session.get_inputs()[0].name: feats,
                encoder_ort_session.get_inputs()[1].name: feats_lengths,
            }
            ort_outs = encoder_ort_session.run(None, ort_inputs)
            # Output order fixed by the export script; beam_log_probs[_idx]
            # are the per-frame top-`beam_size` CTC probabilities/token ids.
            (
                encoder_out,
                encoder_out_lens,
                ctc_log_probs,
                beam_log_probs,
                beam_log_probs_idx,
            ) = ort_outs
            beam_size = beam_log_probs.shape[-1]
            batch_size = beam_log_probs.shape[0]
            num_processes = min(multiprocessing.cpu_count(), batch_size)
            if args.mode == "ctc_greedy_search":
                # NOTE(review): when beam_size == 1, log_probs_idx is never
                # assigned and the loop below would raise NameError — confirm
                # whether exported models always have beam_size > 1.
                if beam_size != 1:
                    log_probs_idx = beam_log_probs_idx[:, :, 0]
                batch_sents = []
                for idx, seq in enumerate(log_probs_idx):
                    # Trim padding frames using the encoder output length.
                    batch_sents.append(seq[0 : encoder_out_lens[idx]].tolist())
                # True -> greedy mode (collapse repeats / remove blank id 0).
                hyps = map_batch(batch_sents, vocabulary, num_processes, True, 0)
            elif args.mode in ("ctc_prefix_beam_search", "attention_rescoring"):
                batch_log_probs_seq_list = beam_log_probs.tolist()
                batch_log_probs_idx_list = beam_log_probs_idx.tolist()
                batch_len_list = encoder_out_lens.tolist()
                batch_log_probs_seq = []
                batch_log_probs_ids = []
                batch_start = []
                batch_root = TrieVector()
                root_dict = {}
                # One fresh prefix trie per utterance; keep only the valid
                # (non-padded) frames of each sequence.
                for i in range(len(batch_len_list)):
                    num_sent = batch_len_list[i]
                    batch_log_probs_seq.append(batch_log_probs_seq_list[i][0:num_sent])
                    batch_log_probs_ids.append(batch_log_probs_idx_list[i][0:num_sent])
                    root_dict[i] = PathTrie()
                    batch_root.append(root_dict[i])
                    batch_start.append(True)
                # Trailing positional args: presumably blank id 0, space
                # id -2 (none) and cutoff prob 0.99999 — confirm against
                # the swig_decoders signature.
                score_hyps = ctc_beam_search_decoder_batch(
                    batch_log_probs_seq,
                    batch_log_probs_ids,
                    batch_root,
                    batch_start,
                    beam_size,
                    num_processes,
                    0,
                    -2,
                    0.99999,
                )
                if args.mode == "ctc_prefix_beam_search":
                    # Keep only the best-scoring prefix of each utterance.
                    hyps = []
                    for cand_hyps in score_hyps:
                        hyps.append(cand_hyps[0][1])
                    hyps = map_batch(hyps, vocabulary, num_processes, False, 0)
                if args.mode == "attention_rescoring":
                    ctc_score, all_hyps = [], []
                    max_len = 0
                    # NOTE(review): the loop variable `hyps` shadows the
                    # final result variable of the same name; harmless here
                    # because it is reassigned below, but fragile.
                    for hyps in score_hyps:
                        cur_len = len(hyps)
                        # Pad short beams so every utterance has exactly
                        # beam_size candidates for the decoder input.
                        if len(hyps) < beam_size:
                            hyps += (beam_size - cur_len) * [(-float("INF"), (0,))]
                        cur_ctc_score = []
                        for hyp in hyps:
                            cur_ctc_score.append(hyp[0])
                            all_hyps.append(list(hyp[1]))
                            if len(hyp[1]) > max_len:
                                max_len = len(hyp[1])
                        ctc_score.append(cur_ctc_score)
                    if args.fp16:
                        ctc_score = np.array(ctc_score, dtype=np.float16)
                    else:
                        ctc_score = np.array(ctc_score, dtype=np.float32)
                    # (batch, beam, max_len + 2) buffers padded with
                    # IGNORE_ID; +2 leaves room for <sos> and <eos>.
                    hyps_pad_sos_eos = (
                        np.ones((batch_size, beam_size, max_len + 2), dtype=np.int64)
                        * IGNORE_ID
                    )
                    r_hyps_pad_sos_eos = (
                        np.ones((batch_size, beam_size, max_len + 2), dtype=np.int64)
                        * IGNORE_ID
                    )
                    hyps_lens_sos = np.ones((batch_size, beam_size), dtype=np.int32)
                    # Flattened all_hyps index: k = i * beam_size + j.
                    k = 0
                    for i in range(batch_size):
                        for j in range(beam_size):
                            cand = all_hyps[k]
                            l = len(cand) + 2
                            hyps_pad_sos_eos[i][j][0:l] = [sos] + cand + [eos]
                            # Reversed copy feeds the right-to-left branch.
                            r_hyps_pad_sos_eos[i][j][0:l] = [sos] + cand[::-1] + [eos]
                            hyps_lens_sos[i][j] = len(cand) + 1
                            k += 1
                    decoder_ort_inputs = {
                        decoder_ort_session.get_inputs()[0].name: encoder_out,
                        decoder_ort_session.get_inputs()[1].name: encoder_out_lens,
                        decoder_ort_session.get_inputs()[2].name: hyps_pad_sos_eos,
                        decoder_ort_session.get_inputs()[3].name: hyps_lens_sos,
                        decoder_ort_session.get_inputs()[-1].name: ctc_score,
                    }
                    # Only bidirectional-decoder exports take the reversed
                    # hypotheses (input slot 4).
                    if reverse_weight > 0:
                        r_hyps_pad_sos_eos_name = decoder_ort_session.get_inputs()[4].name
                        decoder_ort_inputs[r_hyps_pad_sos_eos_name] = r_hyps_pad_sos_eos
                    # The decoder model directly outputs the best beam index
                    # per utterance.
                    best_index = decoder_ort_session.run(None, decoder_ort_inputs)[0]
                    best_sents = []
                    k = 0
                    for idx in best_index:
                        cur_best_sent = all_hyps[k : k + beam_size][idx]
                        best_sents.append(cur_best_sent)
                        k += beam_size
                    hyps = map_batch(best_sents, vocabulary, num_processes)

            # Emit one "key transcript" line per utterance.
            for i, key in enumerate(keys):
                content = hyps[i]
                logging.info("{} {}".format(key, content))
                fout.write("{} {}\n".format(key, content))
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|