from __future__ import print_function |
|
|
|
import os |
|
import sys |
|
import copy |
|
import math |
|
import yaml |
|
import logging |
|
from typing import Tuple |
|
|
|
import torch |
|
import numpy as np |
|
|
|
from wenet.transformer.embedding import NoPositionalEncoding |
|
from wenet.utils.checkpoint import load_checkpoint |
|
from wenet.utils.init_model import init_model |
|
from wenet.bin.export_onnx_cpu import get_args, to_numpy, print_input_output_info |
|
|
|
|
|
try: |
|
import onnx |
|
import onnxruntime |
|
except ImportError: |
|
print("Please install onnx and onnxruntime!") |
|
sys.exit(1) |
|
|
|
|
|
logger = logging.getLogger(__file__) |
|
logger.setLevel(logging.INFO) |
|
|
|
|
|
class BPULayerNorm(torch.nn.Module): |
|
"""Refactor torch.nn.LayerNorm to meet 4-D dataflow.""" |
|
|
|
def __init__(self, module, chunk_size=8, run_on_bpu=False): |
|
super().__init__() |
|
original = copy.deepcopy(module) |
|
self.hidden = module.weight.size(0) |
|
self.chunk_size = chunk_size |
|
self.run_on_bpu = run_on_bpu |
|
|
|
if self.run_on_bpu: |
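            # LayerNorm over the channel dim rebuilt from BPU-friendly ops:
            # mean_conv_1 (a 1x1 conv whose weights are all 1/hidden) gives
            # the per-position mean, mean_conv_2 averages the squared
            # deviations (the variance), and weight/bias apply the affine
            # transform, so the whole op stays in 4-D NCHW dataflow.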
|
self.weight = torch.nn.Parameter( |
|
module.weight.reshape(1, self.hidden, 1, 1).repeat(1, 1, 1, chunk_size) |
|
) |
|
self.bias = torch.nn.Parameter( |
|
module.bias.reshape(1, self.hidden, 1, 1).repeat(1, 1, 1, chunk_size) |
|
) |
|
            self.negative = torch.nn.Parameter(
|
torch.ones((1, self.hidden, 1, chunk_size)) * -1.0 |
|
) |
|
self.eps = torch.nn.Parameter( |
|
torch.zeros((1, self.hidden, 1, chunk_size)) + module.eps |
|
) |
|
self.mean_conv_1 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False) |
|
self.mean_conv_1.weight = torch.nn.Parameter( |
|
torch.ones(self.hidden, self.hidden, 1, 1) / (1.0 * self.hidden) |
|
) |
|
self.mean_conv_2 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False) |
|
self.mean_conv_2.weight = torch.nn.Parameter( |
|
torch.ones(self.hidden, self.hidden, 1, 1) / (1.0 * self.hidden) |
|
) |
|
else: |
|
self.norm = module |
|
|
|
self.check_equal(original) |
|
|
|
def check_equal(self, module): |
|
random_data = torch.randn(1, self.chunk_size, self.hidden) |
|
orig_out = module(random_data) |
|
new_out = self.forward(random_data.transpose(1, 2).unsqueeze(2)) |
|
np.testing.assert_allclose( |
|
to_numpy(orig_out), |
|
to_numpy(new_out.squeeze(2).transpose(1, 2)), |
|
rtol=1e-02, |
|
atol=1e-03, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
if self.run_on_bpu: |
|
u = self.mean_conv_1(x) |
|
            numerator = x + u * self.negative
|
s = torch.pow(numerator, 2) |
|
s = self.mean_conv_2(s) |
|
denominator = torch.sqrt(s + self.eps) |
|
x = torch.div(numerator, denominator) |
|
x = x * self.weight + self.bias |
|
else: |
|
x = x.squeeze(2).transpose(1, 2).contiguous() |
|
x = self.norm(x) |
|
x = x.transpose(1, 2).contiguous().unsqueeze(2) |
|
return x |
|
|
|
|
|
class BPUIdentity(torch.nn.Module): |
|
"""Refactor torch.nn.Identity(). |
|
    Used to insert a BPU node whose input == output.
|
""" |
|
|
|
def __init__(self, channels): |
|
super().__init__() |
|
self.channels = channels |
|
self.identity_conv = torch.nn.Conv2d( |
|
channels, channels, 1, groups=channels, bias=False |
|
) |
|
torch.nn.init.dirac_(self.identity_conv.weight.data, groups=channels) |
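        # A depthwise 1x1 conv initialized with dirac_ is an exact
        # pass-through; it only exists to give the graph a BPU op whose
        # output equals its input.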
|
|
|
self.check_equal() |
|
|
|
def check_equal(self): |
|
random_data = torch.randn(1, self.channels, 1, 10) |
|
result = self.forward(random_data) |
|
np.testing.assert_allclose( |
|
to_numpy(random_data), to_numpy(result), rtol=1e-02, atol=1e-03 |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Identity with 4-D dataflow, input == output. |
|
Args: |
|
x (torch.Tensor): (batch, in_channel, 1, time) |
|
|
|
Returns: |
|
(torch.Tensor): (batch, in_channel, 1, time). |
|
""" |
|
return self.identity_conv(x) |
|
|
|
|
|
class BPULinear(torch.nn.Module): |
|
"""Refactor torch.nn.Linear or pointwise_conv""" |
|
|
|
def __init__(self, module, is_pointwise_conv=False): |
|
super().__init__() |
|
|
|
original = copy.deepcopy(module) |
|
self.idim = module.weight.size(1) |
|
self.odim = module.weight.size(0) |
|
self.is_pointwise_conv = is_pointwise_conv |
|
|
|
|
|
self.linear = torch.nn.Conv2d(self.idim, self.odim, 1, 1) |
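        # A Linear (or pointwise Conv1d) maps directly onto a 1x1 Conv2d;
        # only the weight needs reshaping to (odim, idim, 1, 1).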
|
if is_pointwise_conv: |
|
|
|
self.linear.weight = torch.nn.Parameter(module.weight.unsqueeze(-1)) |
|
else: |
|
|
|
self.linear.weight = torch.nn.Parameter( |
|
module.weight.unsqueeze(2).unsqueeze(3) |
|
) |
|
self.linear.bias = module.bias |
|
|
|
self.check_equal(original) |
|
|
|
def check_equal(self, module): |
|
random_data = torch.randn(1, 8, self.idim) |
|
if self.is_pointwise_conv: |
|
random_data = random_data.transpose(1, 2) |
|
original_result = module(random_data) |
|
if self.is_pointwise_conv: |
|
random_data = random_data.transpose(1, 2) |
|
original_result = original_result.transpose(1, 2) |
|
random_data = random_data.transpose(1, 2).unsqueeze(2) |
|
new_result = self.forward(random_data) |
|
np.testing.assert_allclose( |
|
to_numpy(original_result), |
|
to_numpy(new_result.squeeze(2).transpose(1, 2)), |
|
rtol=1e-02, |
|
atol=1e-03, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Linear with 4-D dataflow. |
|
Args: |
|
x (torch.Tensor): (batch, in_channel, 1, time) |
|
Returns: |
|
(torch.Tensor): (batch, out_channel, 1, time). |
|
""" |
|
return self.linear(x) |
|
|
|
|
|
class BPUGlobalCMVN(torch.nn.Module): |
|
"""Refactor wenet/transformer/cmvn.py::GlobalCMVN""" |
|
|
|
def __init__(self, module): |
|
super().__init__() |
|
|
|
self.norm_var = module.norm_var |
|
|
|
|
|
self.mean = module.mean.unsqueeze(-1).unsqueeze(0).unsqueeze(0) |
|
self.istd = module.istd.unsqueeze(-1).unsqueeze(0).unsqueeze(0) |
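        # (mel_dim,) stats reshaped to (1, 1, mel_dim, 1) so they broadcast
        # over inputs of shape (batch, 1, mel_dim, time).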
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
"""CMVN with 4-D dataflow. |
|
Args: |
|
x (torch.Tensor): (batch, 1, mel_dim, time) |
|
Returns: |
|
(torch.Tensor): normalized feature with same shape. |
|
""" |
|
x = x - self.mean |
|
if self.norm_var: |
|
x = x * self.istd |
|
return x |
|
|
|
|
|
class BPUConv2dSubsampling8(torch.nn.Module): |
|
"""Refactor wenet/transformer/subsampling.py::Conv2dSubsampling8 |
|
|
|
NOTE(xcsong): Only support pos_enc_class == NoPositionalEncoding |
|
""" |
|
|
|
def __init__(self, module): |
|
super().__init__() |
|
|
|
original = copy.deepcopy(module) |
|
self.right_context = module.right_context |
|
self.subsampling_rate = module.subsampling_rate |
|
assert isinstance(module.pos_enc, NoPositionalEncoding) |
|
|
|
|
|
|
|
|
|
self.conv = module.conv |
|
for idx in [0, 2, 4]: |
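            # The 4-D input here is (batch, 1, mel_dim, time) instead of the
            # original (batch, 1, time, mel_dim), so swap each kernel's H/W.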
|
self.conv[idx].weight = torch.nn.Parameter( |
|
module.conv[idx].weight.transpose(2, 3) |
|
) |
|
|
|
|
|
|
|
|
|
self.linear = torch.nn.ModuleList() |
|
odim = module.linear.weight.size(0) |
|
freq = module.linear.weight.size(1) // odim |
|
self.odim, self.freq = odim, freq |
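        # The post-subsampling Linear(odim * freq -> odim) is re-expressed as
        # a sum of Conv2d ops over slices of the frequency axis, each at most
        # 7 rows tall (assumed BPU kernel-height limit). forward() adds the
        # partial outputs, so only the first slice keeps the original bias.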
|
weight = module.linear.weight.reshape( |
|
odim, odim, freq, 1 |
|
) |
|
self.split_size = [] |
|
num_split = (freq - 1) // 7 + 1 |
|
slice_begin = 0 |
|
for idx in range(num_split): |
|
kernel_size = min(freq, (idx + 1) * 7) - idx * 7 |
|
conv_ele = torch.nn.Conv2d(odim, odim, (kernel_size, 1), (kernel_size, 1)) |
|
conv_ele.weight = torch.nn.Parameter( |
|
weight[:, :, slice_begin : slice_begin + kernel_size, :] |
|
) |
|
conv_ele.bias = torch.nn.Parameter(torch.zeros_like(conv_ele.bias)) |
|
self.linear.append(conv_ele) |
|
self.split_size.append(kernel_size) |
|
slice_begin += kernel_size |
|
self.linear[0].bias = torch.nn.Parameter(module.linear.bias) |
|
|
|
self.check_equal(original) |
|
|
|
def check_equal(self, module): |
|
random_data = torch.randn(1, 67, 80) |
|
mask = torch.zeros(1, 1, 67) |
|
original_result, _, _ = module(random_data, mask) |
|
random_data = random_data.transpose(1, 2).unsqueeze(0) |
|
new_result = self.forward(random_data) |
|
np.testing.assert_allclose( |
|
to_numpy(original_result), |
|
to_numpy(new_result.squeeze(2).transpose(1, 2)), |
|
rtol=1e-02, |
|
atol=1e-03, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Subsample x with 4-D dataflow. |
|
Args: |
|
x (torch.Tensor): Input tensor (#batch, 1, mel_dim, time). |
|
|
|
Returns: |
|
torch.Tensor: Subsampled tensor (#batch, odim, 1, time'), |
|
where time' = time // 8. |
|
""" |
|
x = self.conv(x) |
|
x_out = torch.zeros(x.size(0), self.odim, 1, x.size(3)) |
|
x = torch.split(x, self.split_size, dim=2) |
|
        for x_part, layer in zip(x, self.linear):
|
x_out += layer(x_part) |
|
return x_out |
|
|
|
|
|
class BPUMultiHeadedAttention(torch.nn.Module): |
|
"""Refactor wenet/transformer/attention.py::MultiHeadedAttention |
|
|
|
NOTE(xcsong): Only support attention_class == MultiHeadedAttention, |
|
we do not consider RelPositionMultiHeadedAttention currently. |
|
""" |
|
|
|
def __init__(self, module, chunk_size, left_chunks): |
|
super().__init__() |
|
|
|
original = copy.deepcopy(module) |
|
self.d_k = module.d_k |
|
self.h = module.h |
|
n_feat = self.d_k * self.h |
|
self.chunk_size = chunk_size |
|
self.left_chunks = left_chunks |
|
self.time = chunk_size * (left_chunks + 1) |
|
self.activation = torch.nn.Softmax(dim=-1) |
|
|
|
|
|
self.linear_q = BPULinear(module.linear_q) |
|
self.linear_k = BPULinear(module.linear_k) |
|
self.linear_v = BPULinear(module.linear_v) |
|
self.linear_out = BPULinear(module.linear_out) |
|
|
|
self.register_buffer( |
|
"denom", torch.full((1, self.h, 1, 1), 1.0 / math.sqrt(self.d_k)) |
|
) |
|
|
|
self.check_equal(original) |
|
|
|
def check_equal(self, module): |
|
random_data = torch.randn(1, self.chunk_size, self.d_k * self.h) |
|
mask = torch.ones((1, self.h, self.chunk_size, self.time), dtype=torch.bool) |
|
cache = torch.zeros(1, self.h, self.chunk_size * self.left_chunks, self.d_k * 2) |
|
original_out, original_cache = module( |
|
random_data, |
|
random_data, |
|
random_data, |
|
mask[:, 0, :, :], |
|
torch.empty(0), |
|
cache, |
|
) |
|
random_data = random_data.transpose(1, 2).unsqueeze(2) |
|
cache = cache.reshape( |
|
1, self.h, self.d_k * 2, self.chunk_size * self.left_chunks |
|
) |
|
new_out, new_cache = self.forward( |
|
random_data, random_data, random_data, mask, cache |
|
) |
|
np.testing.assert_allclose( |
|
to_numpy(original_out), |
|
to_numpy(new_out.squeeze(2).transpose(1, 2)), |
|
rtol=1e-02, |
|
atol=1e-03, |
|
) |
|
np.testing.assert_allclose( |
|
to_numpy(original_cache), |
|
to_numpy(new_cache.transpose(2, 3)), |
|
rtol=1e-02, |
|
atol=1e-03, |
|
) |
|
|
|
def forward( |
|
self, |
|
q: torch.Tensor, |
|
k: torch.Tensor, |
|
v: torch.Tensor, |
|
mask: torch.Tensor, |
|
cache: torch.Tensor, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
"""Compute scaled dot product attention. |
|
|
|
Args: |
|
q (torch.Tensor): Query tensor (#batch, size, 1, chunk_size). |
|
k (torch.Tensor): Key tensor (#batch, size, 1, chunk_size). |
|
v (torch.Tensor): Value tensor (#batch, size, 1, chunk_size). |
|
mask (torch.Tensor): Mask tensor, |
|
(#batch, head, chunk_size, cache_t + chunk_size). |
|
cache (torch.Tensor): Cache tensor |
|
(1, head, d_k * 2, cache_t), |
|
where `cache_t == chunk_size * left_chunks`. |
|
|
|
|
|
Returns: |
|
torch.Tensor: Output tensor (#batch, size, 1, chunk_size). |
|
torch.Tensor: Cache tensor |
|
(1, head, d_k * 2, cache_t + chunk_size) |
|
where `cache_t == chunk_size * left_chunks` |
|
""" |
|
|
|
q = self.linear_q(q) |
|
k = self.linear_k(k) |
|
v = self.linear_v(v) |
|
q = q.view(1, self.h, self.d_k, self.chunk_size) |
|
k = k.view(1, self.h, self.d_k, self.chunk_size) |
|
v = v.view(1, self.h, self.d_k, self.chunk_size) |
|
q = q.transpose(2, 3) |
|
k_cache, v_cache = torch.split(cache, cache.size(2) // 2, dim=2) |
|
k = torch.cat((k_cache, k), dim=3) |
|
v = torch.cat((v_cache, v), dim=3) |
|
new_cache = torch.cat((k, v), dim=2) |
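        # The cache stores K and V stacked along dim 2:
        # (1, head, d_k * 2, cache_t + chunk_size).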
|
|
|
scores = torch.matmul(q, k) * self.denom |
|
|
|
mask = mask.eq(0) |
|
scores = scores.masked_fill(mask, -float("inf")) |
|
attn = self.activation(scores).masked_fill(mask, 0.0) |
|
attn = attn.transpose(2, 3) |
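        # Compute V @ P^T instead of P @ V so the result keeps the
        # channels-first layout (1, head, d_k, chunk_size).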
|
x = torch.matmul(v, attn) |
|
x = x.view(1, self.d_k * self.h, 1, self.chunk_size) |
|
x_out = self.linear_out(x) |
|
return x_out, new_cache |
|
|
|
|
|
class BPUConvolution(torch.nn.Module): |
|
"""Refactor wenet/transformer/convolution.py::ConvolutionModule |
|
|
|
    NOTE(xcsong): Only support use_layer_norm == False
|
""" |
|
|
|
def __init__(self, module): |
|
super().__init__() |
|
|
|
original = copy.deepcopy(module) |
|
self.lorder = module.lorder |
|
self.use_layer_norm = False |
|
self.activation = module.activation |
|
channels = module.pointwise_conv1.weight.size(1) |
|
self.channels = channels |
|
kernel_size = module.depthwise_conv.weight.size(2) |
|
assert module.use_layer_norm is False |
|
|
|
|
|
self.pointwise_conv1 = BPULinear(module.pointwise_conv1, True) |
|
|
|
|
|
self.depthwise_conv = torch.nn.Conv2d( |
|
channels, channels, (1, kernel_size), stride=1, groups=channels |
|
) |
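        # The original depthwise Conv1d kernel (channels, 1, kernel_size)
        # becomes a (1, kernel_size) Conv2d kernel acting on the time axis.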
|
self.depthwise_conv.weight = torch.nn.Parameter( |
|
module.depthwise_conv.weight.unsqueeze(-2) |
|
) |
|
self.depthwise_conv.bias = torch.nn.Parameter(module.depthwise_conv.bias) |
|
|
|
|
|
self.norm = torch.nn.BatchNorm2d(channels) |
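        # Rebuild the trained BatchNorm1d as an eval-mode BatchNorm2d,
        # copying the affine parameters and running statistics so the
        # normalization itself is unchanged.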
|
self.norm.training = False |
|
self.norm.num_features = module.norm.num_features |
|
self.norm.eps = module.norm.eps |
|
self.norm.momentum = module.norm.momentum |
|
self.norm.weight = torch.nn.Parameter(module.norm.weight) |
|
self.norm.bias = torch.nn.Parameter(module.norm.bias) |
|
self.norm.running_mean = module.norm.running_mean |
|
self.norm.running_var = module.norm.running_var |
|
|
|
|
|
self.pointwise_conv2 = BPULinear(module.pointwise_conv2, True) |
|
|
|
|
|
self.identity = BPUIdentity(channels) |
|
|
|
self.check_equal(original) |
|
|
|
def check_equal(self, module): |
|
random_data = torch.randn(1, 8, self.channels) |
|
cache = torch.zeros((1, self.channels, self.lorder)) |
|
original_out, original_cache = module(random_data, cache=cache) |
|
random_data = random_data.transpose(1, 2).unsqueeze(2) |
|
cache = cache.unsqueeze(2) |
|
new_out, new_cache = self.forward(random_data, cache) |
|
np.testing.assert_allclose( |
|
to_numpy(original_out), |
|
to_numpy(new_out.squeeze(2).transpose(1, 2)), |
|
rtol=1e-02, |
|
atol=1e-03, |
|
) |
|
np.testing.assert_allclose( |
|
to_numpy(original_cache), |
|
to_numpy(new_cache.squeeze(2)), |
|
rtol=1e-02, |
|
atol=1e-03, |
|
) |
|
|
|
def forward( |
|
self, x: torch.Tensor, cache: torch.Tensor |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
"""Compute convolution module. |
|
Args: |
|
x (torch.Tensor): Input tensor (#batch, channels, 1, chunk_size). |
|
cache (torch.Tensor): left context cache, it is only |
|
used in causal convolution (#batch, channels, 1, cache_t). |
|
Returns: |
|
torch.Tensor: Output tensor (#batch, channels, 1, chunk_size). |
|
torch.Tensor: Cache tensor (#batch, channels, 1, cache_t). |
|
""" |
|
|
|
x = torch.cat((self.identity(cache), self.identity(x)), dim=3) |
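        # Keep the last `lorder` frames of (cache + chunk) as the causal
        # left context for the next chunk.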
|
new_cache = x[:, :, :, -self.lorder :] |
|
|
|
|
|
x = self.pointwise_conv1(x) |
|
x = torch.nn.functional.glu(x, dim=1) |
|
|
|
|
|
x = self.depthwise_conv(x) |
|
x = self.activation(self.norm(x)) |
|
x = self.pointwise_conv2(x) |
|
return x, new_cache |
|
|
|
|
|
class BPUFFN(torch.nn.Module): |
|
"""Refactor wenet/transformer/positionwise_feed_forward.py::PositionwiseFeedForward""" |
|
|
|
def __init__(self, module): |
|
super().__init__() |
|
|
|
original = copy.deepcopy(module) |
|
self.activation = module.activation |
|
|
|
|
|
self.w_1 = BPULinear(module.w_1) |
|
self.w_2 = BPULinear(module.w_2) |
|
|
|
self.check_equal(original) |
|
|
|
def check_equal(self, module): |
|
random_data = torch.randn(1, 8, self.w_1.idim) |
|
original_out = module(random_data) |
|
random_data = random_data.transpose(1, 2).unsqueeze(2) |
|
new_out = self.forward(random_data) |
|
np.testing.assert_allclose( |
|
to_numpy(original_out), |
|
to_numpy(new_out.squeeze(2).transpose(1, 2)), |
|
rtol=1e-02, |
|
atol=1e-03, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Forward function. |
|
|
|
Args: |
|
xs: input tensor (B, D, 1, L) |
|
Returns: |
|
output tensor, (B, D, 1, L) |
|
""" |
|
return self.w_2(self.activation(self.w_1(x))) |
|
|
|
|
|
class BPUConformerEncoderLayer(torch.nn.Module): |
|
"""Refactor wenet/transformer/encoder_layer.py::ConformerEncoderLayer""" |
|
|
|
def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False): |
|
super().__init__() |
|
|
|
original = copy.deepcopy(module) |
|
self.size = module.size |
|
assert module.normalize_before is True |
|
assert module.concat_after is False |
|
|
|
|
|
self.feed_forward_macaron = BPUFFN(module.feed_forward_macaron) |
|
self.self_attn = BPUMultiHeadedAttention( |
|
module.self_attn, chunk_size, left_chunks |
|
) |
|
self.conv_module = BPUConvolution(module.conv_module) |
|
self.feed_forward = BPUFFN(module.feed_forward) |
|
|
|
|
|
self.norm_ff = BPULayerNorm(module.norm_ff, chunk_size, ln_run_on_bpu) |
|
self.norm_mha = BPULayerNorm(module.norm_mha, chunk_size, ln_run_on_bpu) |
|
        self.norm_ff_macaron = BPULayerNorm(
|
module.norm_ff_macaron, chunk_size, ln_run_on_bpu |
|
) |
|
self.norm_conv = BPULayerNorm(module.norm_conv, chunk_size, ln_run_on_bpu) |
|
self.norm_final = BPULayerNorm(module.norm_final, chunk_size, ln_run_on_bpu) |
|
|
|
|
|
self.register_buffer( |
|
"ff_scale", torch.full((1, self.size, 1, 1), module.ff_scale) |
|
) |
|
|
|
self.check_equal(original) |
|
|
|
def check_equal(self, module): |
|
time1 = self.self_attn.chunk_size |
|
time2 = self.self_attn.time |
|
h, d_k = self.self_attn.h, self.self_attn.d_k |
|
random_x = torch.randn(1, time1, self.size) |
|
att_mask = torch.ones(1, h, time1, time2) |
|
att_cache = torch.zeros(1, h, time2 - time1, d_k * 2) |
|
cnn_cache = torch.zeros(1, self.size, self.conv_module.lorder) |
|
original_x, _, original_att_cache, original_cnn_cache = module( |
|
random_x, |
|
att_mask[:, 0, :, :], |
|
torch.empty(0), |
|
att_cache=att_cache, |
|
cnn_cache=cnn_cache, |
|
) |
|
random_x = random_x.transpose(1, 2).unsqueeze(2) |
|
att_cache = att_cache.reshape(1, h, d_k * 2, time2 - time1) |
|
cnn_cache = cnn_cache.unsqueeze(2) |
|
new_x, new_att_cache, new_cnn_cache = self.forward( |
|
random_x, att_mask, att_cache, cnn_cache |
|
) |
|
np.testing.assert_allclose( |
|
to_numpy(original_att_cache), |
|
to_numpy(new_att_cache.transpose(2, 3)), |
|
rtol=1e-02, |
|
atol=1e-03, |
|
) |
|
np.testing.assert_allclose( |
|
to_numpy(original_x), |
|
to_numpy(new_x.squeeze(2).transpose(1, 2)), |
|
rtol=1e-02, |
|
atol=1e-03, |
|
) |
|
np.testing.assert_allclose( |
|
to_numpy(original_cnn_cache), |
|
to_numpy(new_cnn_cache.squeeze(2)), |
|
rtol=1e-02, |
|
atol=1e-03, |
|
) |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
att_mask: torch.Tensor, |
|
att_cache: torch.Tensor, |
|
cnn_cache: torch.Tensor, |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
"""Compute encoded features. |
|
|
|
Args: |
|
x (torch.Tensor): (#batch, size, 1, chunk_size) |
|
att_mask (torch.Tensor): Mask tensor for the input |
|
(#batch, head, chunk_size, cache_t1 + chunk_size), |
|
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE |
|
(#batch=1, head, d_k * 2, cache_t1), head * d_k == size. |
|
cnn_cache (torch.Tensor): Convolution cache in conformer layer |
|
(#batch=1, size, 1, cache_t2) |
|
Returns: |
|
torch.Tensor: Output tensor (#batch, size, 1, chunk_size). |
|
torch.Tensor: att_cache tensor, |
|
(1, head, d_k * 2, cache_t1 + chunk_size). |
|
            torch.Tensor: cnn_cache tensor (#batch, size, 1, cache_t2).
|
""" |
|
|
|
residual = x |
|
        x = self.norm_ff_macaron(x)
|
x = residual + self.ff_scale * self.feed_forward_macaron(x) |
|
|
|
|
|
residual = x |
|
x = self.norm_mha(x) |
|
x_att, new_att_cache = self.self_attn(x, x, x, att_mask, att_cache) |
|
x = residual + x_att |
|
|
|
|
|
residual = x |
|
x = self.norm_conv(x) |
|
x, new_cnn_cache = self.conv_module(x, cnn_cache) |
|
x = residual + x |
|
|
|
|
|
residual = x |
|
x = self.norm_ff(x) |
|
x = residual + self.ff_scale * self.feed_forward(x) |
|
|
|
|
|
x = self.norm_final(x) |
|
|
|
return x, new_att_cache, new_cnn_cache |
|
|
|
|
|
class BPUConformerEncoder(torch.nn.Module): |
|
"""Refactor wenet/transformer/encoder.py::ConformerEncoder""" |
|
|
|
def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False): |
|
super().__init__() |
|
|
|
original = copy.deepcopy(module) |
|
output_size = module.output_size() |
|
self._output_size = module.output_size() |
|
self.after_norm = module.after_norm |
|
self.chunk_size = chunk_size |
|
self.left_chunks = left_chunks |
|
self.head = module.encoders[0].self_attn.h |
|
self.layers = len(module.encoders) |
|
|
|
|
|
self.global_cmvn = BPUGlobalCMVN(module.global_cmvn) |
|
self.embed = BPUConv2dSubsampling8(module.embed) |
|
self.encoders = torch.nn.ModuleList() |
|
for layer in module.encoders: |
|
self.encoders.append( |
|
BPUConformerEncoderLayer(layer, chunk_size, left_chunks, ln_run_on_bpu) |
|
) |
|
|
|
|
|
self.identity_cnncache = BPUIdentity(output_size) |
|
|
|
self.check_equal(original) |
|
|
|
def check_equal(self, module): |
|
time1 = self.encoders[0].self_attn.chunk_size |
|
time2 = self.encoders[0].self_attn.time |
|
layers = self.layers |
|
h, d_k = self.head, self.encoders[0].self_attn.d_k |
|
decoding_window = ( |
|
(self.chunk_size - 1) * module.embed.subsampling_rate |
|
+ module.embed.right_context |
|
+ 1 |
|
) |
|
lorder = self.encoders[0].conv_module.lorder |
|
random_x = torch.randn(1, decoding_window, 80) |
|
att_mask = torch.ones(1, h, time1, time2) |
|
att_cache = torch.zeros(layers, h, time2 - time1, d_k * 2) |
|
cnn_cache = torch.zeros(layers, 1, self._output_size, lorder) |
|
orig_x, orig_att_cache, orig_cnn_cache = module.forward_chunk( |
|
random_x, |
|
0, |
|
time2 - time1, |
|
att_mask=att_mask[:, 0, :, :], |
|
att_cache=att_cache, |
|
cnn_cache=cnn_cache, |
|
) |
|
random_x = random_x.unsqueeze(0) |
|
att_cache = att_cache.reshape(1, h * layers, d_k * 2, time2 - time1) |
|
cnn_cache = cnn_cache.reshape(1, self._output_size, layers, lorder) |
|
new_x, new_att_cache, new_cnn_cache = self.forward( |
|
random_x, att_cache, cnn_cache, att_mask |
|
) |
|
caches = torch.split(new_att_cache, h, dim=1) |
|
caches = [c.transpose(2, 3) for c in caches] |
|
np.testing.assert_allclose( |
|
to_numpy(orig_att_cache), |
|
to_numpy(torch.cat(caches, dim=0)), |
|
rtol=1e-02, |
|
atol=1e-03, |
|
) |
|
np.testing.assert_allclose( |
|
to_numpy(orig_x), |
|
to_numpy(new_x.squeeze(2).transpose(1, 2)), |
|
rtol=1e-02, |
|
atol=1e-03, |
|
) |
|
np.testing.assert_allclose( |
|
to_numpy(orig_cnn_cache), |
|
to_numpy(new_cnn_cache.transpose(0, 2).transpose(1, 2)), |
|
rtol=1e-02, |
|
atol=1e-03, |
|
) |
|
|
|
def forward( |
|
self, |
|
xs: torch.Tensor, |
|
att_cache: torch.Tensor, |
|
cnn_cache: torch.Tensor, |
|
att_mask: torch.Tensor, |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
""" Forward just one chunk |
|
|
|
Args: |
|
xs (torch.Tensor): chunk input, with shape (b=1, 1, time, mel-dim), |
|
where `time == (chunk_size - 1) * subsample_rate + \ |
|
subsample.right_context + 1` |
|
att_cache (torch.Tensor): cache tensor for KEY & VALUE in |
|
transformer/conformer attention, with shape |
|
(1, head * elayers, d_k * 2, cache_t1), where |
|
`head * d_k == hidden-dim` and |
|
`cache_t1 == chunk_size * left_chunks`. |
|
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, |
|
(1, hidden-dim, elayers, cache_t2), where |
|
                `cache_t2 == cnn.lorder`
|
att_mask (torch.Tensor): Mask tensor for the input |
|
(#batch, head, chunk_size, cache_t1 + chunk_size), |
|
|
|
Returns: |
|
torch.Tensor: output of current input xs, |
|
with shape (b=1, hidden-dim, 1, chunk_size). |
|
torch.Tensor: new attention cache required for next chunk, with |
|
same shape as the original att_cache. |
|
torch.Tensor: new conformer cnn cache required for next chunk, with |
|
same shape as the original cnn_cache. |
|
""" |
|
|
|
xs = xs.transpose(2, 3) |
|
xs = self.global_cmvn(xs) |
|
|
|
xs = self.embed(xs) |
|
|
|
att_cache = torch.split(att_cache, self.head, dim=1) |
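        # att_cache stacks all layers along dim 1 (head * elayers); split it
        # back into one (1, head, d_k * 2, cache_t1) tensor per layer.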
|
cnn_cache = self.identity_cnncache(cnn_cache) |
|
cnn_cache = torch.split(cnn_cache, 1, dim=2) |
|
r_att_cache = [] |
|
r_cnn_cache = [] |
|
for i, layer in enumerate(self.encoders): |
|
xs, new_att_cache, new_cnn_cache = layer( |
|
xs, att_mask, att_cache=att_cache[i], cnn_cache=cnn_cache[i] |
|
) |
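            # Drop the oldest chunk_size frames so the attention cache keeps
            # exactly cache_t1 frames for the next chunk.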
|
r_att_cache.append(new_att_cache[:, :, :, self.chunk_size :]) |
|
r_cnn_cache.append(new_cnn_cache) |
|
r_att_cache = torch.cat(r_att_cache, dim=1) |
|
r_cnn_cache = self.identity_cnncache(torch.cat(r_cnn_cache, dim=2)) |
|
|
|
xs = xs.squeeze(2).transpose(1, 2).contiguous() |
|
xs = self.after_norm(xs) |
|
|
|
xs = xs.transpose(1, 2).contiguous().unsqueeze(2) |
|
|
|
return (xs, r_att_cache, r_cnn_cache) |
|
|
|
|
|
class BPUCTC(torch.nn.Module): |
|
"""Refactor wenet/transformer/ctc.py::CTC""" |
|
|
|
def __init__(self, module): |
|
super().__init__() |
|
|
|
original = copy.deepcopy(module) |
|
self.idim = module.ctc_lo.weight.size(1) |
|
num_class = module.ctc_lo.weight.size(0) |
|
|
|
|
|
|
|
self.ctc_lo = torch.nn.ModuleList() |
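        # Split the CTC output projection into 1x1 convs of at most 2048
        # output channels each (an assumed per-op channel limit on the BPU);
        # forward() concatenates the pieces along the class dimension.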
|
self.split_size = [] |
|
num_split = (num_class - 1) // 2048 + 1 |
|
for idx in range(num_split): |
|
out_channel = min(num_class, (idx + 1) * 2048) - idx * 2048 |
|
conv_ele = torch.nn.Conv2d(self.idim, out_channel, 1, 1) |
|
self.ctc_lo.append(conv_ele) |
|
self.split_size.append(out_channel) |
|
orig_weight = torch.split(module.ctc_lo.weight, self.split_size, dim=0) |
|
orig_bias = torch.split(module.ctc_lo.bias, self.split_size, dim=0) |
|
for i, (w, b) in enumerate(zip(orig_weight, orig_bias)): |
|
w = w.unsqueeze(2).unsqueeze(3) |
|
self.ctc_lo[i].weight = torch.nn.Parameter(w) |
|
self.ctc_lo[i].bias = torch.nn.Parameter(b) |
|
|
|
self.check_equal(original) |
|
|
|
def check_equal(self, module): |
|
random_data = torch.randn(1, 100, self.idim) |
|
original_result = module.ctc_lo(random_data) |
|
random_data = random_data.transpose(1, 2).unsqueeze(2) |
|
new_result = self.forward(random_data) |
|
np.testing.assert_allclose( |
|
to_numpy(original_result), |
|
to_numpy(new_result.squeeze(2).transpose(1, 2)), |
|
rtol=1e-02, |
|
atol=1e-03, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
"""frame activations, without softmax. |
|
|
|
Args: |
|
Tensor x: 4d tensor (B, hidden_dim, 1, chunk_size) |
|
Returns: |
|
torch.Tensor: (B, num_class, 1, chunk_size) |
|
""" |
|
out = [] |
|
        for layer in self.ctc_lo:
|
out.append(layer(x)) |
|
out = torch.cat(out, dim=1) |
|
return out |
|
|
|
|
|
def export_encoder(asr_model, args): |
|
logger.info("Stage-1: export encoder") |
|
decode_window, mel_dim = args.decoding_window, args.feature_size |
|
encoder = BPUConformerEncoder( |
|
asr_model.encoder, |
|
args.chunk_size, |
|
args.num_decoding_left_chunks, |
|
args.ln_run_on_bpu, |
|
) |
|
encoder.eval() |
|
encoder_outpath = os.path.join(args.output_dir, "encoder.onnx") |
|
|
|
logger.info("Stage-1.1: prepare inputs for encoder") |
|
chunk = torch.randn((1, 1, decode_window, mel_dim)) |
|
required_cache_size = encoder.chunk_size * encoder.left_chunks |
|
kv_time = required_cache_size + encoder.chunk_size |
|
hidden, layers = encoder._output_size, len(encoder.encoders) |
|
head = encoder.encoders[0].self_attn.h |
|
d_k = hidden // head |
|
lorder = encoder.encoders[0].conv_module.lorder |
|
att_cache = torch.zeros(1, layers * head, d_k * 2, required_cache_size) |
|
att_mask = torch.ones((1, head, encoder.chunk_size, kv_time)) |
|
att_mask[:, :, :, :required_cache_size] = 0 |
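    # The attention cache starts empty, so mask out all cached positions;
    # the consistency check below re-enables them chunk by chunk.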
|
cnn_cache = torch.zeros((1, hidden, layers, lorder)) |
|
inputs = (chunk, att_cache, cnn_cache, att_mask) |
|
logger.info( |
|
"chunk.size(): {} att_cache.size(): {} " |
|
"cnn_cache.size(): {} att_mask.size(): {}".format( |
|
list(chunk.size()), |
|
list(att_cache.size()), |
|
list(cnn_cache.size()), |
|
list(att_mask.size()), |
|
) |
|
) |
|
|
|
logger.info("Stage-1.2: torch.onnx.export") |
|
|
|
|
|
attributes = {} |
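    # These attributes are written into the ONNX metadata below; presumably
    # they describe the fixed input layout/shapes for the Horizon BPU
    # conversion toolchain rather than anything used by onnxruntime itself.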
|
attributes["input_name"] = "chunk;att_cache;cnn_cache;att_mask" |
|
attributes["output_name"] = "output;r_att_cache;r_cnn_cache" |
|
attributes["input_type"] = "featuremap;featuremap;featuremap;featuremap" |
|
attributes["norm_type"] = "no_preprocess;no_preprocess;no_preprocess;no_preprocess" |
|
attributes["input_layout_train"] = "NCHW;NCHW;NCHW;NCHW" |
|
attributes["input_layout_rt"] = "NCHW;NCHW;NCHW;NCHW" |
|
attributes[ |
|
"input_shape" |
|
] = "{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{}".format( |
|
chunk.size(0), |
|
chunk.size(1), |
|
chunk.size(2), |
|
chunk.size(3), |
|
att_cache.size(0), |
|
att_cache.size(1), |
|
att_cache.size(2), |
|
att_cache.size(3), |
|
cnn_cache.size(0), |
|
cnn_cache.size(1), |
|
cnn_cache.size(2), |
|
cnn_cache.size(3), |
|
att_mask.size(0), |
|
att_mask.size(1), |
|
att_mask.size(2), |
|
att_mask.size(3), |
|
) |
|
torch.onnx.export( |
|
encoder, |
|
inputs, |
|
encoder_outpath, |
|
opset_version=11, |
|
export_params=True, |
|
do_constant_folding=True, |
|
input_names=attributes["input_name"].split(";"), |
|
output_names=attributes["output_name"].split(";"), |
|
dynamic_axes=None, |
|
verbose=False, |
|
) |
|
onnx_encoder = onnx.load(encoder_outpath) |
|
for k in vars(args): |
|
meta = onnx_encoder.metadata_props.add() |
|
meta.key, meta.value = str(k), str(getattr(args, k)) |
|
for k in attributes: |
|
meta = onnx_encoder.metadata_props.add() |
|
meta.key, meta.value = str(k), str(attributes[k]) |
|
onnx.checker.check_model(onnx_encoder) |
|
onnx.helper.printable_graph(onnx_encoder.graph) |
|
onnx.save(onnx_encoder, encoder_outpath) |
|
print_input_output_info(onnx_encoder, "onnx_encoder") |
|
logger.info("Export onnx_encoder, done! see {}".format(encoder_outpath)) |
|
|
|
logger.info("Stage-1.3: check onnx_encoder and torch_encoder") |
|
torch_output = [] |
|
torch_chunk, torch_att_mask = copy.deepcopy(chunk), copy.deepcopy(att_mask) |
|
torch_att_cache = copy.deepcopy(att_cache) |
|
torch_cnn_cache = copy.deepcopy(cnn_cache) |
|
for i in range(10): |
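        # Simulate streaming: feed the same chunk ten times, letting the
        # attention mask expose one more cached chunk on every iteration.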
|
logger.info( |
|
"torch chunk-{}: {}, att_cache: {}, cnn_cache: {}" |
|
", att_mask: {}".format( |
|
i, |
|
list(torch_chunk.size()), |
|
list(torch_att_cache.size()), |
|
list(torch_cnn_cache.size()), |
|
list(torch_att_mask.size()), |
|
) |
|
) |
|
torch_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)) :] = 1 |
|
out, torch_att_cache, torch_cnn_cache = encoder( |
|
torch_chunk, torch_att_cache, torch_cnn_cache, torch_att_mask |
|
) |
|
torch_output.append(out) |
|
torch_output = torch.cat(torch_output, dim=-1) |
|
|
|
onnx_output = [] |
|
onnx_chunk, onnx_att_mask = to_numpy(chunk), to_numpy(att_mask) |
|
onnx_att_cache = to_numpy(att_cache) |
|
onnx_cnn_cache = to_numpy(cnn_cache) |
|
ort_session = onnxruntime.InferenceSession(encoder_outpath) |
|
input_names = [node.name for node in onnx_encoder.graph.input] |
|
for i in range(10): |
|
logger.info( |
|
"onnx chunk-{}: {}, att_cache: {}, cnn_cache: {}," |
|
" att_mask: {}".format( |
|
i, |
|
onnx_chunk.shape, |
|
onnx_att_cache.shape, |
|
onnx_cnn_cache.shape, |
|
onnx_att_mask.shape, |
|
) |
|
) |
|
onnx_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)) :] = 1 |
|
ort_inputs = { |
|
"chunk": onnx_chunk, |
|
"att_cache": onnx_att_cache, |
|
"cnn_cache": onnx_cnn_cache, |
|
"att_mask": onnx_att_mask, |
|
} |
|
ort_outs = ort_session.run(None, ort_inputs) |
|
onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2] |
|
onnx_output.append(ort_outs[0]) |
|
onnx_output = np.concatenate(onnx_output, axis=-1) |
|
|
|
np.testing.assert_allclose( |
|
to_numpy(torch_output), onnx_output, rtol=1e-03, atol=1e-04 |
|
) |
|
meta = ort_session.get_modelmeta() |
|
logger.info("custom_metadata_map={}".format(meta.custom_metadata_map)) |
|
logger.info("Check onnx_encoder, pass!") |
|
return encoder, ort_session |
|
|
|
|
|
def export_ctc(asr_model, args): |
|
logger.info("Stage-2: export ctc") |
|
ctc = BPUCTC(asr_model.ctc).eval() |
|
ctc_outpath = os.path.join(args.output_dir, "ctc.onnx") |
|
|
|
logger.info("Stage-2.1: prepare inputs for ctc") |
|
hidden = torch.randn((1, args.output_size, 1, args.chunk_size)) |
|
|
|
logger.info("Stage-2.2: torch.onnx.export") |
|
|
|
|
|
attributes = {} |
|
attributes["input_name"], attributes["input_type"] = "hidden", "featuremap" |
|
attributes["norm_type"] = "no_preprocess" |
|
attributes["input_layout_train"] = "NCHW" |
|
attributes["input_layout_rt"] = "NCHW" |
|
attributes["input_shape"] = "{}x{}x{}x{}".format( |
|
hidden.size(0), |
|
hidden.size(1), |
|
hidden.size(2), |
|
hidden.size(3), |
|
) |
|
torch.onnx.export( |
|
ctc, |
|
hidden, |
|
ctc_outpath, |
|
opset_version=11, |
|
export_params=True, |
|
do_constant_folding=True, |
|
input_names=["hidden"], |
|
output_names=["probs"], |
|
dynamic_axes=None, |
|
verbose=False, |
|
) |
|
onnx_ctc = onnx.load(ctc_outpath) |
|
for k in vars(args): |
|
meta = onnx_ctc.metadata_props.add() |
|
meta.key, meta.value = str(k), str(getattr(args, k)) |
|
for k in attributes: |
|
meta = onnx_ctc.metadata_props.add() |
|
meta.key, meta.value = str(k), str(attributes[k]) |
|
onnx.checker.check_model(onnx_ctc) |
|
onnx.helper.printable_graph(onnx_ctc.graph) |
|
onnx.save(onnx_ctc, ctc_outpath) |
|
print_input_output_info(onnx_ctc, "onnx_ctc") |
|
logger.info("Export onnx_ctc, done! see {}".format(ctc_outpath)) |
|
|
|
logger.info("Stage-2.3: check onnx_ctc and torch_ctc") |
|
torch_output = ctc(hidden) |
|
ort_session = onnxruntime.InferenceSession(ctc_outpath) |
|
onnx_output = ort_session.run(None, {"hidden": to_numpy(hidden)}) |
|
|
|
np.testing.assert_allclose( |
|
to_numpy(torch_output), onnx_output[0], rtol=1e-03, atol=1e-04 |
|
) |
|
meta = ort_session.get_modelmeta() |
|
logger.info("custom_metadata_map={}".format(meta.custom_metadata_map)) |
|
logger.info("Check onnx_ctc, pass!") |
|
return ctc, ort_session |
|
|
|
|
|
def export_decoder(asr_model, args): |
|
logger.info("Currently, Decoder is not supported.") |
|
|
|
|
|
if __name__ == "__main__": |
|
torch.manual_seed(777) |
|
args = get_args() |
|
args.ln_run_on_bpu = False |
|
|
|
assert args.chunk_size > 0 |
|
assert args.num_decoding_left_chunks > 0 |
|
    os.makedirs(args.output_dir, exist_ok=True)
|
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" |
|
|
|
with open(args.config, "r") as fin: |
|
configs = yaml.load(fin, Loader=yaml.FullLoader) |
|
|
|
model = init_model(configs) |
|
load_checkpoint(model, args.checkpoint) |
|
model.eval() |
|
print(model) |
|
|
|
args.feature_size = configs["input_dim"] |
|
args.output_size = model.encoder.output_size() |
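    # Number of raw feature frames needed to produce chunk_size frames
    # after subsampling.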
|
args.decoding_window = ( |
|
(args.chunk_size - 1) * model.encoder.embed.subsampling_rate |
|
+ model.encoder.embed.right_context |
|
+ 1 |
|
) |
|
|
|
export_encoder(model, args) |
|
export_ctc(model, args) |
|
export_decoder(model, args) |
|
|