Upload model
Browse files
- config.json +1 -1
- model.safetensors +3 -0
- modeling_bestrq_conformer.py +770 -0
config.json
CHANGED
@@ -1,7 +1,7 @@
 {
   "activation_dropout": 0.0,
   "architectures": [
-    "
+    "MeralionBestRqModelForCTC"
   ],
   "attention_dropout": 0.0,
   "auto_map": {
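Note on loading: `architectures` now names the custom `MeralionBestRqModelForCTC` class, which `auto_map` (entries truncated above) resolves to the bundled modeling_bestrq_conformer.py, so the checkpoint loads through the remote-code path. A minimal loading sketch, assuming the auto_map registers the class for AutoModelForCTC; the repo id below is a placeholder, not part of this commit:

    from transformers import AutoConfig, AutoModelForCTC

    repo_id = "<namespace>/<this-repo>"  # placeholder
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCTC.from_pretrained(repo_id, trust_remote_code=True)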
model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b902d5c175decdd5502800af2599b4d018c741160144e4e6a4596c82cd2fa333
+size 2541162484
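The three pointer fields are enough to sanity-check the roughly 2.5 GB weight file once Git LFS has materialized it. A minimal verification sketch (the local path is hypothetical); `load_file` is the same safetensors helper the modeling code imports:

    import hashlib
    from safetensors.torch import load_file

    path = "model.safetensors"  # after `git lfs pull` or a hub download
    sha = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)
    assert sha.hexdigest() == "b902d5c175decdd5502800af2599b4d018c741160144e4e6a4596c82cd2fa333"
    state_dict = load_file(path)  # dict of tensor name -> torch.Tensor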
modeling_bestrq_conformer.py
ADDED
@@ -0,0 +1,770 @@
import os
import torch
import math
from torch import nn
from typing import Optional, Tuple, Union

from transformers.modeling_utils import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, Wav2Vec2BaseModelOutput, CausalLMOutput
from safetensors.torch import load_file

from .configuration_bestrq_conformer import MeralionBestRqConformerEncoderConfig


_HIDDEN_STATES_START_POSITION = 2


def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor:
    bsz, max_lens = lens.size(0), torch.max(lens).item()
    mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
    mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
    return mask


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask


class Conv2dSubsampling(nn.Module):
    """
    Convolutional 2D subsampling (to 1/4 length)
    For feature extraction/downsampling of input mel spectrogram

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution

    Inputs:
        inputs (batch, time, dim): Tensor containing sequence of inputs
        input_lengths (batch): Tensor containing input_length for each item in batch

    Returns:
        outputs (batch, time, dim): Tensor produced by the convolution
        output_lengths (batch): Tensor containing output_length for each item in batch
    """
    def __init__(self, config):
        super().__init__()
        self.sequential = nn.Sequential(
            nn.Conv2d(config.input_channels, config.hidden_size, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(config.hidden_size, config.hidden_size, kernel_size=3, stride=2),
            nn.ReLU(),
        )

    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
        _, max_seq_len, _ = inputs.size()
        outputs = self.sequential(inputs.unsqueeze(1))
        batch_size, channels, subsampled_lengths, subsampled_dim = outputs.size()

        outputs = outputs.permute(0, 2, 1, 3)
        outputs = outputs.contiguous().view(batch_size, subsampled_lengths, channels * subsampled_dim)

        subsampling_factor = int(max_seq_len * 1.0 / subsampled_lengths + 0.5)
        input_len_0 = (input_lengths.float() / subsampling_factor).ceil().long()
        input_len_1 = outputs.size(1) * torch.ones([input_lengths.size(0)]).long().to(
            input_len_0.device
        )
        output_lengths = torch.min(input_len_0, input_len_1)

        return outputs, output_lengths
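# Illustrative shape note (comment only, no runtime effect): the two kernel-3, stride-2
# convolutions above map (B, T, input_dim) features to
# (B, ((T - 1) // 2 - 1) // 2, hidden_size * (((input_dim - 1) // 2 - 1) // 2)).
# For example, assuming 80-dim log-mel inputs, each subsampled frame is flattened to
# 19 * hidden_size features, which matches the `subsample_embed_dim` computed in
# ConformerInputFeatureProjection further below.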


class ConformerRelPositionalEmbedding(nn.Module):
    """Relative positional encoding module (new implementation).

    Args:
        d_model: Embedding dimension.
        dropout_rate: Dropout rate.
        max_len: Maximum input length.
    """
    def __init__(self, config):
        super().__init__()
        self.max_len = config.max_source_positions
        self.d_model = config.hidden_size
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` is the position of the query vector and `j` the position of the
        # key vector. We use positive relative positions when keys are to the left
        # (i>j) and negative relative positions otherwise (i<j).
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in https://arxiv.org/abs/1901.02860
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.
        Args:
            x : Input tensor T X B X C.
        Returns:
            torch.Tensor: Encoded tensor T X B X C.

        """
        x = x.transpose(0, 1)  # Change TBC to BTC
        self.extend_pe(x)
        pos_emb = self.pe[
            :,
            self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
        ]
        pos_emb = pos_emb.transpose(0, 1)  # change to TBC
        return pos_emb


class ConformerRotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding
    Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
    """

    def __init__(self, config):
        super().__init__()
        dim = config.hidden_size // config.num_attention_heads
        base = config.rotary_embedding_base

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_sequence_length = None
        self.cached_rotary_positional_embedding = None

    def forward(self, hidden_states):
        sequence_length = hidden_states.shape[1]

        if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
            return self.cached_rotary_positional_embedding

        self.cached_sequence_length = sequence_length
        # Embeddings are computed in the dtype of the inv_freq constant
        time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
        embeddings = torch.cat((freqs, freqs), dim=-1)

        cos_embeddings = embeddings.cos()[:, None, None, :]
        sin_embeddings = embeddings.sin()[:, None, None, :]
        # Computed embeddings are cast to the dtype of the hidden state inputs
        self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
        return self.cached_rotary_positional_embedding


class ConformerInputFeatureProjection(nn.Module):
    def __init__(self, config):
        super().__init__()
        subsample_embed_dim = config.hidden_size * (((config.input_dim - 1) // 2 - 1) // 2)
        #self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
        self.projection = nn.Linear(subsample_embed_dim, config.hidden_size)
        self.dropout = nn.Dropout(config.feat_proj_dropout)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: Input Tensor of shape T X B X C
        Returns:
            Tensor of shape T X B X C
        """
        # non-projected hidden states are needed for quantization
        #norm_hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.projection(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class ConformerFeedForward(nn.Module):
    """Positionwise feed forward layer used in conformer"""
    def __init__(self, config):
        super().__init__()

        #self.layer_norm = torch.nn.LayerNorm(config.hidden_size, eps=1e-5, elementwise_affine=True)

        self.intermediate_dropout = nn.Dropout(config.activation_dropout)

        self.intermediate_dense = nn.Linear(config.hidden_size, config.ffn_dim)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        self.output_dense = nn.Linear(config.ffn_dim, config.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: Input Tensor of shape T X B X C
        Returns:
            Tensor of shape T X B X C
        """
        hidden_states = self.intermediate_dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.intermediate_dropout(hidden_states)
        hidden_states = self.output_dense(hidden_states)
        hidden_states = self.output_dropout(hidden_states)
        return hidden_states


class ConformerConvolutionModule(nn.Module):
    """Convolution block used in the conformer block"""

    def __init__(self, config):
        super().__init__()
        if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
            raise ValueError("`config.conv_depthwise_kernel_size` should be an odd number for 'SAME' padding")
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.pointwise_conv1 = nn.Conv1d(
            config.hidden_size,
            2 * config.hidden_size,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.glu = nn.GLU(dim=1)
        self.depthwise_conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            config.conv_depthwise_kernel_size,
            stride=1,
            padding=(config.conv_depthwise_kernel_size - 1) // 2,
            groups=config.hidden_size,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm1d(config.hidden_size)
        self.activation = ACT2FN[config.hidden_act]
        self.pointwise_conv2 = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.dropout = nn.Dropout(config.conformer_conv_dropout)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: Input of shape B X T X C
        Returns:
            Tensor of shape B X T X C
        """
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = hidden_states.transpose(1, 2)

        # GLU mechanism
        # => (batch, 2*channel, dim)
        hidden_states = self.pointwise_conv1(hidden_states)
        # => (batch, channel, dim)
        hidden_states = self.glu(hidden_states)

        # 1D Depthwise Conv
        hidden_states = self.depthwise_conv(hidden_states)
        hidden_states = self.batch_norm(hidden_states)
        hidden_states = self.activation(hidden_states)

        hidden_states = self.pointwise_conv2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = hidden_states.transpose(1, 2)
        return hidden_states


class ConformerSelfAttention(nn.Module):
    """ConformerSelfAttention object.
    Can be enhanced with rotary or relative position embeddings.
    """

    def __init__(self, config):
        super().__init__()

        self.head_size = config.hidden_size // config.num_attention_heads
        self.num_heads = config.num_attention_heads
        self.position_embeddings_type = config.position_embeddings_type

        self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
        self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
        self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
        self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)

        self.dropout = nn.Dropout(p=config.attention_dropout)

        if self.position_embeddings_type == "relative":
            # linear transformation for positional encoding
            self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
            # these two learnable bias are used in matrix c and matrix d
            # as described in https://arxiv.org/abs/1901.02860 Section 3.3
            self.pos_bias_u = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
            self.pos_bias_v = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)  ##
            torch.nn.init.xavier_uniform_(self.pos_bias_v)  ##

    def forward(
        self,
        hidden_states: torch.Tensor,  # [T, B, C]
        attention_mask: Optional[torch.Tensor] = None,
        relative_position_embeddings: Optional[torch.Tensor] = None,  # [T, B, C]
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # self-attention mechanism
        hidden_states = hidden_states.transpose(0, 1)  # [B, T, C]
        relative_position_embeddings = relative_position_embeddings.transpose(0, 1)  # [B, T, C]
        batch_size, sequence_length, hidden_size = hidden_states.size()

        # make sure query/key states can be != value states
        query_key_states = hidden_states
        value_states = hidden_states

        if self.position_embeddings_type == "rotary":
            if relative_position_embeddings is None:
                raise ValueError(
                    "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
                )
            query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)

        # project query_key_states and value_states
        query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
        key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
        value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)

        # => (batch, head, time1, d_k)
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        if self.position_embeddings_type == "relative":
            if relative_position_embeddings is None:
                raise ValueError(
                    "`relative_position_embeddings` has to be defined when `self.position_embeddings_type =="
                    " 'relative'"
                )
            # apply relative_position_embeddings to qk scores
            # as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860
            scores = self._apply_relative_embeddings(
                query=query, key=key, relative_position_embeddings=relative_position_embeddings
            )
        else:
            scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)

        # apply attention_mask if necessary
        if attention_mask is not None:
            scores = scores.masked_fill(
                attention_mask.unsqueeze(1).unsqueeze(2).to(bool),
                float("-inf"),  # (batch, head, time1, time2)
            )

        # => (batch, head, time1, time2)
        probs = torch.softmax(scores, dim=-1)
        probs = self.dropout(probs)

        # => (batch, head, time1, d_k)
        hidden_states = torch.matmul(probs, value)

        # => (batch, time1, hidden_size)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
        hidden_states = self.linear_out(hidden_states)

        # => (time1, batch, hidden_size)
        hidden_states = hidden_states.transpose(0, 1)

        return hidden_states, probs

    def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
        batch_size, sequence_length, hidden_size = hidden_states.size()
        hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)

        cos = relative_position_embeddings[0, :sequence_length, ...]
        sin = relative_position_embeddings[1, :sequence_length, ...]

        # rotate hidden_states with rotary embeddings
        hidden_states = hidden_states.transpose(0, 1)
        rotated_states_begin = hidden_states[..., : self.head_size // 2]
        rotated_states_end = hidden_states[..., self.head_size // 2 :]
        rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
        hidden_states = (hidden_states * cos) + (rotated_states * sin)
        hidden_states = hidden_states.transpose(0, 1)

        hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)

        return hidden_states

    def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
        # 1. project positional embeddings
        # => (batch, head, d_k, 2*time1-1)
        proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
        proj_relative_position_embeddings = proj_relative_position_embeddings.view(
            relative_position_embeddings.size(0), -1, self.num_heads, self.head_size  # (batch, 2*time1-1, head, d_k)
        )
        proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
        proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)  # (batch, head, d_k, 2*time1-1)

        # 2. Add bias to query
        # => (batch, head, time1, d_k)
        query = query.transpose(1, 2)  # (batch, time1, head, d_k)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)

        # 3. attention score: first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # => (batch, head, time1, time2)
        scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))

        # 4. then compute matrix b and matrix d
        # => (batch, head, time1, 2*time1-1)
        scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)

        # 5. shift matrix b and matrix d
        zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
        scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
        scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
        scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
        scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
        scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]

        # 6. sum matrices
        # => (batch, head, time1, time2)
        scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)

        return scores


class ConformerEncoderLayer(nn.Module):
    """Conformer block based on https://arxiv.org/abs/2005.08100."""

    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        dropout = config.attention_dropout

        # Feed-forward 1
        self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
        self.ffn1 = ConformerFeedForward(config)

        # Self-Attention
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
        self.self_attn_dropout = nn.Dropout(dropout)
        self.self_attn = ConformerSelfAttention(config)

        # Conformer Convolution
        self.conv_module = ConformerConvolutionModule(config)

        # Feed-forward 2
        self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
        self.ffn2 = ConformerFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def forward(
        self,
        hidden_states,  # [T, B, C]
        attention_mask: Optional[torch.Tensor] = None,
        relative_position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        hidden_states = hidden_states

        # 1. Feed-Forward 1 layer
        residual = hidden_states
        hidden_states = self.ffn1_layer_norm(hidden_states)
        hidden_states = self.ffn1(hidden_states)
        hidden_states = hidden_states * 0.5 + residual
        residual = hidden_states

        # 2. Self-Attention layer
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            relative_position_embeddings=relative_position_embeddings,
            output_attentions=output_attentions,
        )
        hidden_states = self.self_attn_dropout(hidden_states)
        hidden_states = hidden_states + residual

        # 3. Convolutional Layer
        residual = hidden_states
        hidden_states = hidden_states.transpose(0, 1)  # [T,B,C] to [B,T,C]
        hidden_states = self.conv_module(hidden_states)
        hidden_states = hidden_states.transpose(0, 1)  # [B,T,C] to [T,B,C]
        hidden_states = residual + hidden_states

        # 4. Feed-Forward 2 Layer
        residual = hidden_states
        hidden_states = self.ffn2_layer_norm(hidden_states)
        hidden_states = self.ffn2(hidden_states)
        hidden_states = hidden_states * 0.5 + residual
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states, attn_weights


class ConformerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_scale = math.sqrt(config.hidden_size)
        if config.no_scale_embedding:
            self.embed_scale = 1.0

        if config.position_embeddings_type == "relative":
            self.embed_positions = ConformerRelPositionalEmbedding(config)
        elif config.position_embeddings_type == "rotary":
            self.embed_positions = ConformerRotaryPositionalEmbedding(config)
        else:
            self.embed_positions = None

        self.input_projection = ConformerInputFeatureProjection(config)  # [T,B,C]

        self.layers = nn.ModuleList([ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,  # conv_out
        attention_mask=None,  # encoder_padding_mask
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        hidden_states = self.embed_scale * hidden_states

        if self.embed_positions is not None:
            relative_position_embeddings = self.embed_positions(hidden_states)  # [T,B,C]
        else:
            relative_position_embeddings = None

        hidden_states = self.input_projection(hidden_states)  # [T,B,C]
        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.transpose(0, 1),)  # [T,B,C] -> [B,T,C]

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
            if not skip_the_layer:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    relative_position_embeddings=relative_position_embeddings,
                    output_attentions=output_attentions,
                )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        hidden_states = hidden_states.transpose(0, 1)  # [B,T,C]
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class MeralionBestRqModel(PreTrainedModel):
    config_class = MeralionBestRqConformerEncoderConfig
    base_model_prefix = "bestrq_encoder"

    def __init__(self, config: MeralionBestRqConformerEncoderConfig):
        super().__init__(config)
        self.config = config
        self.conv_subsample = Conv2dSubsampling(config)

        self.encoder = ConformerEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_values: Optional[torch.Tensor],  # [B,C,T]
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        input_lengths: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_values = input_values.transpose(2, 1)  # [B,C,T] -> [B,T,C]
        conv_outputs, output_lengths = self.conv_subsample(input_values, input_lengths)  # returns [B,T,C]
        x = conv_outputs.transpose(0, 1)  # [T,B,C]

        encoder_padding_mask = make_pad_mask(output_lengths, max_len=x.shape[0])

        encoder_outputs = self.encoder(
            x,
            attention_mask=encoder_padding_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if not return_dict:
            return (hidden_states, conv_outputs) + encoder_outputs[1:]

        output = Wav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=conv_outputs,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
        output["output_lengths"] = output_lengths
        return output


class MeralionBestRqModelForCTC(PreTrainedModel):
    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
    config_class = MeralionBestRqConformerEncoderConfig
    base_model_prefix = "bestrq_encoder"

    def __init__(self, config, target_lang: Optional[str] = None, **kwargs):
        super().__init__(config)

        self.bestrq_encoder = MeralionBestRqModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        self.target_lang = target_lang

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        input_lengths: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")

        outputs = self.bestrq_encoder(
            input_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            input_lengths=input_lengths,
        )

        hidden_states = outputs.last_hidden_state
        hidden_states = self.dropout(hidden_states)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    outputs.output_lengths,  # lengths after initial CNN downsampling
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
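Taken together, `MeralionBestRqModelForCTC` expects batched filterbank features of shape [B, C, T] plus per-utterance frame counts, and returns frame-level logits over the vocabulary. A minimal forward-pass sketch, reusing `model` and `config` from the loading note above; the random features stand in for log-mel extraction (not part of this commit) and the greedy decode is illustrative only:

    import torch

    model.eval()
    feats = torch.randn(1, config.input_dim, 500)   # [B, C, T] stand-in for log-mel features
    lengths = torch.tensor([500])                   # valid frames per utterance

    with torch.no_grad():
        out = model(input_values=feats, input_lengths=lengths)

    pred_ids = out.logits.argmax(dim=-1)            # greedy CTC path; blank id == config.pad_token_id
    # Collapse repeats and strip blanks with the repo's tokenizer/processor to obtain text.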