huzy0 committed
Commit f46b5d4 · verified · 1 Parent(s): a225830

Upload model

Files changed (3)
  1. config.json +1 -1
  2. model.safetensors +3 -0
  3. modeling_bestrq_conformer.py +770 -0
config.json CHANGED
@@ -1,7 +1,7 @@
 {
   "activation_dropout": 0.0,
   "architectures": [
-    "MeralionBestRqModel"
+    "MeralionBestRqModelForCTC"
   ],
   "attention_dropout": 0.0,
   "auto_map": {
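The only functional change in config.json is the declared architecture: the checkpoint is now exported as the CTC model rather than the bare encoder. A minimal sketch of checking this before loading; the repository id below is a placeholder, and trust_remote_code=True is assumed because the config/model classes ship with the repository (via auto_map) rather than with transformers.

from transformers import AutoConfig

# Hypothetical repository id; replace with the repo this commit belongs to.
repo_id = "huzy0/<this-repo>"

# trust_remote_code=True assumed: the classes are defined by files in the repo.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)

# After this commit the declared architecture should be the CTC variant.
print(config.architectures)  # expected: ['MeralionBestRqModelForCTC']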
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b902d5c175decdd5502800af2599b4d018c741160144e4e6a4596c82cd2fa333
+size 2541162484
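model.safetensors is tracked through Git LFS, so the file committed to git is only the pointer above; the roughly 2.5 GB weight file is resolved at download time. A minimal sketch, assuming a locally downloaded copy at a placeholder path, for checking it against the pointer's oid and size:

import hashlib
import os

# Placeholder path to a locally downloaded model.safetensors.
path = "model.safetensors"

# Values copied from the LFS pointer in this commit.
expected_sha256 = "b902d5c175decdd5502800af2599b4d018c741160144e4e6a4596c82cd2fa333"
expected_size = 2541162484

sha = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)

assert os.path.getsize(path) == expected_size, "size mismatch with LFS pointer"
assert sha.hexdigest() == expected_sha256, "sha256 mismatch with LFS pointer"
print("model.safetensors matches the LFS pointer")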
modeling_bestrq_conformer.py ADDED
@@ -0,0 +1,770 @@
+import os
+import torch
+import math
+from torch import nn
+from typing import Optional, Tuple, Union
+
+from transformers.modeling_utils import PreTrainedModel
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput, Wav2Vec2BaseModelOutput, CausalLMOutput
+from safetensors.torch import load_file
+
+from .configuration_bestrq_conformer import MeralionBestRqConformerEncoderConfig
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+
+def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor:
+    bsz, max_lens = lens.size(0), torch.max(lens).item()
+    mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
+    mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
+    return mask
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+    """Make mask tensor containing indices of padded part.
+
+    See description of make_non_pad_mask.
+
+    Args:
+        lengths (torch.Tensor): Batch of lengths (B,).
+    Returns:
+        torch.Tensor: Mask tensor containing indices of padded part.
+
+    Examples:
+        >>> lengths = [5, 3, 2]
+        >>> make_pad_mask(lengths)
+        masks = [[0, 0, 0, 0, 0],
+                 [0, 0, 0, 1, 1],
+                 [0, 0, 1, 1, 1]]
+    """
+    batch_size = lengths.size(0)
+    max_len = max_len if max_len > 0 else lengths.max().item()
+    seq_range = torch.arange(0,
+                             max_len,
+                             dtype=torch.int64,
+                             device=lengths.device)
+    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+    seq_length_expand = lengths.unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+    return mask
+
+
+class Conv2dSubsampling(nn.Module):
+    """
+    Convolutional 2D subsampling (to 1/4 length).
+    Used for feature extraction/downsampling of the input mel spectrogram.
+
+    Args:
+        in_channels (int): Number of channels in the input image
+        out_channels (int): Number of channels produced by the convolution
+
+    Inputs:
+        inputs (batch, time, dim): Tensor containing sequence of inputs
+        input_lengths (batch): Tensor containing input_length for each item in batch
+
+    Returns:
+        outputs (batch, time, dim): Tensor produced by the convolution
+        output_lengths (batch): Tensor containing output_length for each item in batch
+    """
+    def __init__(self, config):
+        super().__init__()
+        self.sequential = nn.Sequential(
+            nn.Conv2d(config.input_channels, config.hidden_size, kernel_size=3, stride=2),
+            nn.ReLU(),
+            nn.Conv2d(config.hidden_size, config.hidden_size, kernel_size=3, stride=2),
+            nn.ReLU(),
+        )
+
+    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
+        _, max_seq_len, _ = inputs.size()
+        outputs = self.sequential(inputs.unsqueeze(1))
+        batch_size, channels, subsampled_lengths, subsampled_dim = outputs.size()
+
+        outputs = outputs.permute(0, 2, 1, 3)
+        outputs = outputs.contiguous().view(batch_size, subsampled_lengths, channels * subsampled_dim)
+
+        subsampling_factor = int(max_seq_len * 1.0 / subsampled_lengths + 0.5)
+        input_len_0 = (input_lengths.float() / subsampling_factor).ceil().long()
+        input_len_1 = outputs.size(1) * torch.ones([input_lengths.size(0)]).long().to(
+            input_len_0.device
+        )
+        output_lengths = torch.min(input_len_0, input_len_1)
+
+        return outputs, output_lengths
+
+
+class ConformerRelPositionalEmbedding(nn.Module):
+    """Relative positional encoding module (new implementation).
+
+    Args:
+        d_model: Embedding dimension.
+        dropout_rate: Dropout rate.
+        max_len: Maximum input length.
+    """
+    def __init__(self, config):
+        super().__init__()
+        self.max_len = config.max_source_positions
+        self.d_model = config.hidden_size
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+
+    def extend_pe(self, x):
+        """Reset the positional encodings."""
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x.size(1) * 2 - 1:
+                if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` is the position of the query vector and `j` the position
+        # of the key vector. We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in https://arxiv.org/abs/1901.02860
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor):
+        """Add positional encoding.
+        Args:
+            x : Input tensor T X B X C.
+        Returns:
+            torch.Tensor: Encoded tensor T X B X C.
+        """
+        x = x.transpose(0, 1)  # Change TBC to BTC
+        self.extend_pe(x)
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
+        ]
+        pos_emb = pos_emb.transpose(0, 1)  # change to TBC
+        return pos_emb
+
+
+class ConformerRotaryPositionalEmbedding(nn.Module):
+    """Rotary positional embedding
+    Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        dim = config.hidden_size // config.num_attention_heads
+        base = config.rotary_embedding_base
+
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
+        self.cached_sequence_length = None
+        self.cached_rotary_positional_embedding = None
+
+    def forward(self, hidden_states):
+        sequence_length = hidden_states.shape[1]
+
+        if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
+            return self.cached_rotary_positional_embedding
+
+        self.cached_sequence_length = sequence_length
+        # Embeddings are computed in the dtype of the inv_freq constant
+        time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
+        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
+        embeddings = torch.cat((freqs, freqs), dim=-1)
+
+        cos_embeddings = embeddings.cos()[:, None, None, :]
+        sin_embeddings = embeddings.sin()[:, None, None, :]
+        # Computed embeddings are cast to the dtype of the hidden state inputs
+        self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
+        return self.cached_rotary_positional_embedding
+
+
+class ConformerInputFeatureProjection(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        subsample_embed_dim = config.hidden_size * (((config.input_dim - 1) // 2 - 1) // 2)
+        # self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+        self.projection = nn.Linear(subsample_embed_dim, config.hidden_size)
+        self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+    def forward(self, hidden_states):
+        """
+        Args:
+            hidden_states: Input Tensor of shape T X B X C
+        Returns:
+            Tensor of shape T X B X C
+        """
+        # non-projected hidden states are needed for quantization
+        # norm_hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.projection(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+class ConformerFeedForward(nn.Module):
+    """Positionwise feed forward layer used in conformer"""
+    def __init__(self, config):
+        super().__init__()
+
+        # self.layer_norm = torch.nn.LayerNorm(config.hidden_size, eps=1e-5, elementwise_affine=True)
+
+        self.intermediate_dropout = nn.Dropout(config.activation_dropout)
+
+        self.intermediate_dense = nn.Linear(config.hidden_size, config.ffn_dim)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+        self.output_dense = nn.Linear(config.ffn_dim, config.hidden_size)
+        self.output_dropout = nn.Dropout(config.hidden_dropout)
+
+    def forward(self, hidden_states):
+        """
+        Args:
+            hidden_states: Input Tensor of shape T X B X C
+        Returns:
+            Tensor of shape T X B X C
+        """
+        hidden_states = self.intermediate_dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        hidden_states = self.intermediate_dropout(hidden_states)
+        hidden_states = self.output_dense(hidden_states)
+        hidden_states = self.output_dropout(hidden_states)
+        return hidden_states
+
+
+class ConformerConvolutionModule(nn.Module):
+    """Convolution block used in the conformer block"""
+
+    def __init__(self, config):
+        super().__init__()
+        if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
+            raise ValueError("`config.conv_depthwise_kernel_size` should be an odd number for 'SAME' padding")
+        self.layer_norm = nn.LayerNorm(config.hidden_size)
+        self.pointwise_conv1 = nn.Conv1d(
+            config.hidden_size,
+            2 * config.hidden_size,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+        )
+        self.glu = nn.GLU(dim=1)
+        self.depthwise_conv = nn.Conv1d(
+            config.hidden_size,
+            config.hidden_size,
+            config.conv_depthwise_kernel_size,
+            stride=1,
+            padding=(config.conv_depthwise_kernel_size - 1) // 2,
+            groups=config.hidden_size,
+            bias=False,
+        )
+        self.batch_norm = nn.BatchNorm1d(config.hidden_size)
+        self.activation = ACT2FN[config.hidden_act]
+        self.pointwise_conv2 = nn.Conv1d(
+            config.hidden_size,
+            config.hidden_size,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+        )
+        self.dropout = nn.Dropout(config.conformer_conv_dropout)
+
+    def forward(self, hidden_states):
+        """
+        Args:
+            hidden_states: Input of shape B X T X C
+        Returns:
+            Tensor of shape B X T X C
+        """
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = hidden_states.transpose(1, 2)
+
+        # GLU mechanism
+        # => (batch, 2*channel, dim)
+        hidden_states = self.pointwise_conv1(hidden_states)
+        # => (batch, channel, dim)
+        hidden_states = self.glu(hidden_states)
+
+        # 1D Depthwise Conv
+        hidden_states = self.depthwise_conv(hidden_states)
+        hidden_states = self.batch_norm(hidden_states)
+        hidden_states = self.activation(hidden_states)
+
+        hidden_states = self.pointwise_conv2(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = hidden_states.transpose(1, 2)
+        return hidden_states
+
+
+class ConformerSelfAttention(nn.Module):
+    """ConformerSelfAttention object.
+    Can be enhanced with rotary or relative position embeddings.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.head_size = config.hidden_size // config.num_attention_heads
+        self.num_heads = config.num_attention_heads
+        self.position_embeddings_type = config.position_embeddings_type
+
+        self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
+        self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
+        self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
+        self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
+
+        self.dropout = nn.Dropout(p=config.attention_dropout)
+
+        if self.position_embeddings_type == "relative":
+            # linear transformation for positional encoding
+            self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+            # these two learnable biases are used in matrix c and matrix d
+            # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+            self.pos_bias_u = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
+            self.pos_bias_v = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
+            torch.nn.init.xavier_uniform_(self.pos_bias_u)
+            torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,  # [T, B, C]
+        attention_mask: Optional[torch.Tensor] = None,
+        relative_position_embeddings: Optional[torch.Tensor] = None,  # [T, B, C]
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # self-attention mechanism
+        hidden_states = hidden_states.transpose(0, 1)  # [B, T, C]
+        relative_position_embeddings = relative_position_embeddings.transpose(0, 1)  # [B, T, C]
+        batch_size, sequence_length, hidden_size = hidden_states.size()
+
+        # make sure query/key states can be != value states
+        query_key_states = hidden_states
+        value_states = hidden_states
+
+        if self.position_embeddings_type == "rotary":
+            if relative_position_embeddings is None:
+                raise ValueError(
+                    "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'`"
+                )
+            query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
+
+        # project query_key_states and value_states
+        query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+        key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+        value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
+
+        # => (batch, head, time1, d_k)
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        if self.position_embeddings_type == "relative":
+            if relative_position_embeddings is None:
+                raise ValueError(
+                    "`relative_position_embeddings` has to be defined when `self.position_embeddings_type =="
+                    " 'relative'`"
+                )
+            # apply relative_position_embeddings to qk scores
+            # as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860
+            scores = self._apply_relative_embeddings(
+                query=query, key=key, relative_position_embeddings=relative_position_embeddings
+            )
+        else:
+            scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)
+
+        # apply attention_mask if necessary
+        if attention_mask is not None:
+            scores = scores.masked_fill(
+                attention_mask.unsqueeze(1).unsqueeze(2).to(bool),
+                float("-inf"),  # (batch, head, time1, time2)
+            )
+
+        # => (batch, head, time1, time2)
+        probs = torch.softmax(scores, dim=-1)
+        probs = self.dropout(probs)
+
+        # => (batch, head, time1, d_k)
+        hidden_states = torch.matmul(probs, value)
+
+        # => (batch, time1, hidden_size)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
+        hidden_states = self.linear_out(hidden_states)
+
+        # => (time1, batch, hidden_size)
+        hidden_states = hidden_states.transpose(0, 1)
+
+        return hidden_states, probs
+
+    def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
+        batch_size, sequence_length, hidden_size = hidden_states.size()
+        hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
+
+        cos = relative_position_embeddings[0, :sequence_length, ...]
+        sin = relative_position_embeddings[1, :sequence_length, ...]
+
+        # rotate hidden_states with rotary embeddings
+        hidden_states = hidden_states.transpose(0, 1)
+        rotated_states_begin = hidden_states[..., : self.head_size // 2]
+        rotated_states_end = hidden_states[..., self.head_size // 2 :]
+        rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
+        hidden_states = (hidden_states * cos) + (rotated_states * sin)
+        hidden_states = hidden_states.transpose(0, 1)
+
+        hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
+
+        return hidden_states
+
+    def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
+        # 1. project positional embeddings
+        # => (batch, head, d_k, 2*time1-1)
+        proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
+        proj_relative_position_embeddings = proj_relative_position_embeddings.view(
+            relative_position_embeddings.size(0), -1, self.num_heads, self.head_size  # (batch, 2*time1-1, head, d_k)
+        )
+        proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
+        proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)  # (batch, head, d_k, 2*time1-1)
+
+        # 2. Add bias to query
+        # => (batch, head, time1, d_k)
+        query = query.transpose(1, 2)  # (batch, time1, head, d_k)
+        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+        # 3. attention score: first compute matrix a and matrix c
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        # => (batch, head, time1, time2)
+        scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+        # 4. then compute matrix b and matrix d
+        # => (batch, head, time1, 2*time1-1)
+        scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
+
+        # 5. shift matrix b and matrix d
+        zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
+        scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
+        scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
+        scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
+        scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
+        scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
+
+        # 6. sum matrices
+        # => (batch, head, time1, time2)
+        scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
+
+        return scores
+
+
+class ConformerEncoderLayer(nn.Module):
+    """Conformer block based on https://arxiv.org/abs/2005.08100."""
+
+    def __init__(self, config):
+        super().__init__()
+        embed_dim = config.hidden_size
+        dropout = config.attention_dropout
+
+        # Feed-forward 1
+        self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
+        self.ffn1 = ConformerFeedForward(config)
+
+        # Self-Attention
+        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
+        self.self_attn_dropout = nn.Dropout(dropout)
+        self.self_attn = ConformerSelfAttention(config)
+
+        # Conformer Convolution
+        self.conv_module = ConformerConvolutionModule(config)
+
+        # Feed-forward 2
+        self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
+        self.ffn2 = ConformerFeedForward(config)
+        self.final_layer_norm = nn.LayerNorm(embed_dim)
+
+    def forward(
+        self,
+        hidden_states,  # [T, B, C]
+        attention_mask: Optional[torch.Tensor] = None,
+        relative_position_embeddings: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ):
+        # 1. Feed-Forward 1 layer
+        residual = hidden_states
+        hidden_states = self.ffn1_layer_norm(hidden_states)
+        hidden_states = self.ffn1(hidden_states)
+        hidden_states = hidden_states * 0.5 + residual
+        residual = hidden_states
+
+        # 2. Self-Attention layer
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            relative_position_embeddings=relative_position_embeddings,
+            output_attentions=output_attentions,
+        )
+        hidden_states = self.self_attn_dropout(hidden_states)
+        hidden_states = hidden_states + residual
+
+        # 3. Convolutional Layer
+        residual = hidden_states
+        hidden_states = hidden_states.transpose(0, 1)  # [T,B,C] to [B,T,C]
+        hidden_states = self.conv_module(hidden_states)
+        hidden_states = hidden_states.transpose(0, 1)  # [B,T,C] to [T,B,C]
+        hidden_states = residual + hidden_states
+
+        # 4. Feed-Forward 2 Layer
+        residual = hidden_states
+        hidden_states = self.ffn2_layer_norm(hidden_states)
+        hidden_states = self.ffn2(hidden_states)
+        hidden_states = hidden_states * 0.5 + residual
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        return hidden_states, attn_weights
+
+
+class ConformerEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.embed_scale = math.sqrt(config.hidden_size)
+        if config.no_scale_embedding:
+            self.embed_scale = 1.0
+
+        if config.position_embeddings_type == "relative":
+            self.embed_positions = ConformerRelPositionalEmbedding(config)
+        elif config.position_embeddings_type == "rotary":
+            self.embed_positions = ConformerRotaryPositionalEmbedding(config)
+        else:
+            self.embed_positions = None
+
+        self.input_projection = ConformerInputFeatureProjection(config)  # [T,B,C]
+
+        self.layers = nn.ModuleList([ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states,  # conv_out
+        attention_mask=None,  # encoder_padding_mask
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        hidden_states = self.embed_scale * hidden_states
+
+        if self.embed_positions is not None:
+            relative_position_embeddings = self.embed_positions(hidden_states)  # [T,B,C]
+        else:
+            relative_position_embeddings = None
+
+        hidden_states = self.input_projection(hidden_states)  # [T,B,C]
+        for i, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states.transpose(0, 1),)  # [T,B,C] -> [B,T,C]
+
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            dropout_probability = torch.rand([])
+
+            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
+            if not skip_the_layer:
+                layer_outputs = layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    relative_position_embeddings=relative_position_embeddings,
+                    output_attentions=output_attentions,
+                )
+                hidden_states = layer_outputs[0]
+
+            if skip_the_layer:
+                layer_outputs = (None, None)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        hidden_states = hidden_states.transpose(0, 1)  # [B,T,C]
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
+class MeralionBestRqModel(PreTrainedModel):
+    config_class = MeralionBestRqConformerEncoderConfig
+    base_model_prefix = "bestrq_encoder"
+
+    def __init__(self, config: MeralionBestRqConformerEncoderConfig):
+        super().__init__(config)
+        self.config = config
+        self.conv_subsample = Conv2dSubsampling(config)
+
+        self.encoder = ConformerEncoder(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],  # [B,C,T]
+        attention_mask: Optional[torch.Tensor] = None,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        input_lengths: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        input_values = input_values.transpose(2, 1)  # [B,C,T] -> [B,T,C]
+        conv_outputs, output_lengths = self.conv_subsample(input_values, input_lengths)  # returns [B,T,C]
+        x = conv_outputs.transpose(0, 1)  # [T,B,C]
+
+        encoder_padding_mask = make_pad_mask(output_lengths, max_len=x.shape[0])
+
+        encoder_outputs = self.encoder(
+            x,
+            attention_mask=encoder_padding_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = encoder_outputs[0]
+
+        if not return_dict:
+            return (hidden_states, conv_outputs) + encoder_outputs[1:]
+
+        output = Wav2Vec2BaseModelOutput(
+            last_hidden_state=hidden_states,
+            extract_features=conv_outputs,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+        output["output_lengths"] = output_lengths
+        return output
+
+
+class MeralionBestRqModelForCTC(PreTrainedModel):
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+    config_class = MeralionBestRqConformerEncoderConfig
+    base_model_prefix = "bestrq_encoder"
+
+    def __init__(self, config, target_lang: Optional[str] = None, **kwargs):
+        super().__init__(config)
+
+        self.bestrq_encoder = MeralionBestRqModel(config)
+        self.dropout = nn.Dropout(config.final_dropout)
+
+        self.target_lang = target_lang
+
+        if config.vocab_size is None:
+            raise ValueError(
+                f"You are trying to instantiate {self.__class__} with a configuration that "
+                "does not define the vocabulary size of the language model head. Please "
+                "instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)` "
+                "or define `vocab_size` of your model's configuration."
+            )
+        output_hidden_size = (
+            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+        )
+        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        input_lengths: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, CausalLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
+            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+            config.vocab_size - 1]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if labels is not None and labels.max() >= self.config.vocab_size:
+            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+        outputs = self.bestrq_encoder(
+            input_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            input_lengths=input_lengths,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        hidden_states = self.dropout(hidden_states)
+
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # assuming that padded tokens are filled with -100
+            # when not being attended to
+            labels_mask = labels >= 0
+            target_lengths = labels_mask.sum(-1)
+            flattened_targets = labels.masked_select(labels_mask)
+
+            # ctc_loss doesn't support fp16
+            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+            with torch.backends.cudnn.flags(enabled=False):
+                loss = nn.functional.ctc_loss(
+                    log_probs,
+                    flattened_targets,
+                    outputs.output_lengths,  # lengths after initial CNN downsampling
+                    target_lengths,
+                    blank=self.config.pad_token_id,
+                    reduction=self.config.ctc_loss_reduction,
+                    zero_infinity=self.config.ctc_zero_infinity,
+                )
+
+        if not return_dict:
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutput(
+            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+        )
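Taken together, the forward path added by this file is: log-mel features of shape [B, C, T] plus per-item input_lengths pass through Conv2dSubsampling (roughly 4x downsampling), then the Conformer encoder, and, for MeralionBestRqModelForCTC, a linear lm_head trained with CTC loss. A rough usage sketch under stated assumptions: the repository id and the random "features" below are placeholders (a real front end must produce config.input_dim mel bins), and trust_remote_code=True is assumed because the model class is defined by this uploaded file rather than by transformers.

import torch
from transformers import AutoModelForCTC

# Hypothetical repository id; trust_remote_code=True assumed (custom modeling code in the repo).
model = AutoModelForCTC.from_pretrained("huzy0/<this-repo>", trust_remote_code=True)
model.eval()

# Placeholder input: batch of 2 log-mel spectrograms, [B, C, T] with C == config.input_dim bins.
batch_size, num_mels, num_frames = 2, model.config.input_dim, 400
input_values = torch.randn(batch_size, num_mels, num_frames)
input_lengths = torch.tensor([400, 320])  # valid frames per item before subsampling

with torch.no_grad():
    outputs = model(input_values, input_lengths=input_lengths)

# Greedy CTC sketch: argmax over the vocabulary per frame; collapsing repeats and
# stripping blanks would require the tokenizer that accompanies this checkpoint.
predicted_ids = outputs.logits.argmax(dim=-1)
print(predicted_ids.shape)  # [B, T'] where T' is the subsampled length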