waveletdeboshir commited on
Commit
519a560
·
verified ·
1 Parent(s): 2133832

Upload 2 files

Browse files
Files changed (2) hide show
  1. encoder.py +601 -0
  2. gigaam_transformers.py +1 -1
encoder.py ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Copied from https://github.com/salute-developers/GigaAM/blob/main/gigaam/encoder.py"""
2
+ import math
3
+ from abc import ABC, abstractmethod
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from torch import Tensor, nn
8
+
9
+ try:
10
+ from flash_attn import flash_attn_func
11
+
12
+ IMPORT_FLASH = True
13
+ except Exception as err:
14
+ IMPORT_FLASH = False
15
+ IMPORT_FLASH_ERR = err
16
+
17
+ # from .utils import apply_masked_flash_attn, apply_rotary_pos_emb
18
+
19
+
20
+ def rtt_half(x: Tensor) -> Tensor:
21
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
22
+ return torch.cat([-x2, x1], dim=x1.ndim - 1)
23
+
24
+
25
+ def apply_rotary_pos_emb(
26
+ q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, offset: int = 0
27
+ ) -> Tuple[Tensor, Tensor]:
28
+ """
29
+ Applies Rotary Position Embeddings to query and key tensors.
30
+ """
31
+ cos, sin = (
32
+ cos[offset : q.shape[0] + offset, ...],
33
+ sin[offset : q.shape[0] + offset, ...],
34
+ )
35
+ return (q * cos) + (rtt_half(q) * sin), (k * cos) + (rtt_half(k) * sin)
36
+
37
+
38
+ def apply_masked_flash_attn(
39
+ q: Tensor,
40
+ k: Tensor,
41
+ v: Tensor,
42
+ mask: Tensor,
43
+ h: int,
44
+ d_k: int,
45
+ ) -> Tensor:
46
+ """
47
+ Applies Flash Attention with padding masks.
48
+ """
49
+
50
+ from einops import rearrange
51
+ from flash_attn import flash_attn_varlen_func
52
+ from flash_attn.bert_padding import pad_input, unpad_input
53
+
54
+ pad_mask = ~mask[:, 0, :]
55
+ b, t = pad_mask.shape
56
+ q = q.view(b, t, h * d_k)
57
+ k = k.view(b, t, h * d_k)
58
+ v = v.view(b, t, h * d_k)
59
+
60
+ q_unpad, indices_q, _, max_seqlen_q = unpad_input(q, pad_mask)[:4]
61
+ q_unpad = rearrange(q_unpad, "nnz (h d) -> nnz h d", h=h)
62
+
63
+ k_unpad = unpad_input(k, pad_mask)[0]
64
+ k_unpad = rearrange(k_unpad, "nnz (h d) -> nnz h d", h=h)
65
+
66
+ v_unpad = unpad_input(v, pad_mask)[0]
67
+ v_unpad = rearrange(v_unpad, "nnz (h d) -> nnz h d", h=h)
68
+
69
+ lengths_q = pad_mask.sum(1).to(torch.int32).to(q.device)
70
+ cu_seqlens_q = F.pad(lengths_q.cumsum(0), (1, 0), value=0).to(torch.int32)
71
+ max_seqlen_q = torch.max(lengths_q)
72
+
73
+ output_unpad = flash_attn_varlen_func(
74
+ q_unpad,
75
+ k_unpad,
76
+ v_unpad,
77
+ cu_seqlens_q,
78
+ cu_seqlens_q,
79
+ max_seqlen_q,
80
+ max_seqlen_q,
81
+ )
82
+
83
+ scores = pad_input(
84
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"),
85
+ indices_q,
86
+ b,
87
+ t,
88
+ )
89
+
90
+ return scores
91
+
92
+
93
+ class StridingSubsampling(nn.Module):
94
+ """
95
+ Strided Subsampling layer used to reduce the sequence length.
96
+ """
97
+
98
+ def __init__(
99
+ self,
100
+ subsampling_factor: int,
101
+ feat_in: int,
102
+ feat_out: int,
103
+ conv_channels: int,
104
+ ):
105
+ super().__init__()
106
+ self._sampling_num = int(math.log(subsampling_factor, 2))
107
+ self._stride = 2
108
+ self._kernel_size = 3
109
+ self._padding = (self._kernel_size - 1) // 2
110
+
111
+ layers: List[nn.Module] = []
112
+ in_channels = 1
113
+ for _ in range(self._sampling_num):
114
+ layers.append(
115
+ torch.nn.Conv2d(
116
+ in_channels=in_channels,
117
+ out_channels=conv_channels,
118
+ kernel_size=self._kernel_size,
119
+ stride=self._stride,
120
+ padding=self._padding,
121
+ )
122
+ )
123
+ layers.append(nn.ReLU())
124
+ in_channels = conv_channels
125
+
126
+ out_length = self.calc_output_length(torch.tensor(feat_in))
127
+ self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
128
+ self.conv = torch.nn.Sequential(*layers)
129
+
130
+ def calc_output_length(self, lengths: Tensor) -> Tensor:
131
+ """
132
+ Calculates the output length after applying the subsampling.
133
+ """
134
+ lengths = lengths.to(torch.float)
135
+ add_pad = 2 * self._padding - self._kernel_size
136
+ for _ in range(self._sampling_num):
137
+ lengths = torch.div(lengths + add_pad, self._stride) + 1.0
138
+ lengths = torch.floor(lengths)
139
+ return lengths.to(dtype=torch.int)
140
+
141
+ def forward(self, x: Tensor, lengths: Tensor) -> Tuple[Tensor, Tensor]:
142
+ x = self.conv(x.unsqueeze(1))
143
+ b, _, t, _ = x.size()
144
+ x = self.out(x.transpose(1, 2).reshape(b, t, -1))
145
+ return x, self.calc_output_length(lengths)
146
+
147
+
148
+ class MultiHeadAttention(nn.Module, ABC):
149
+ """
150
+ Base class of Multi-Head Attention Mechanisms.
151
+ """
152
+
153
+ def __init__(self, n_head: int, n_feat: int, flash_attn=False):
154
+ super().__init__()
155
+ assert n_feat % n_head == 0
156
+ self.d_k = n_feat // n_head
157
+ self.h = n_head
158
+ self.linear_q = nn.Linear(n_feat, n_feat)
159
+ self.linear_k = nn.Linear(n_feat, n_feat)
160
+ self.linear_v = nn.Linear(n_feat, n_feat)
161
+ self.linear_out = nn.Linear(n_feat, n_feat)
162
+ self.flash_attn = flash_attn
163
+ if self.flash_attn and not IMPORT_FLASH:
164
+ raise RuntimeError(
165
+ f"flash_attn_func was imported with err {IMPORT_FLASH_ERR}. "
166
+ "Please install flash_attn or use --no_flash flag. "
167
+ "If you have already done this, "
168
+ "--force-reinstall flag might be useful"
169
+ )
170
+
171
+ def forward_qkv(
172
+ self, query: Tensor, key: Tensor, value: Tensor
173
+ ) -> Tuple[Tensor, Tensor, Tensor]:
174
+ """
175
+ Projects the inputs into queries, keys, and values for multi-head attention.
176
+ """
177
+ b = query.size(0)
178
+ q = self.linear_q(query).view(b, -1, self.h, self.d_k)
179
+ k = self.linear_k(key).view(b, -1, self.h, self.d_k)
180
+ v = self.linear_v(value).view(b, -1, self.h, self.d_k)
181
+ if self.flash_attn:
182
+ return q, k, v
183
+ return q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
184
+
185
+ def forward_attention(
186
+ self, value: Tensor, scores: Tensor, mask: Optional[Tensor]
187
+ ) -> Tensor:
188
+ """
189
+ Computes the scaled dot-product attention given the projected values and scores.
190
+ """
191
+ b = value.size(0)
192
+ if mask is not None:
193
+ mask = mask.unsqueeze(1)
194
+ scores = scores.masked_fill(mask, -10000.0)
195
+ attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
196
+ else:
197
+ attn = torch.softmax(scores, dim=-1)
198
+ x = torch.matmul(attn, value)
199
+ x = x.transpose(1, 2).reshape(b, -1, self.h * self.d_k)
200
+ return self.linear_out(x)
201
+
202
+
203
+ class RelPositionMultiHeadAttention(MultiHeadAttention):
204
+ """
205
+ Relative Position Multi-Head Attention module.
206
+ """
207
+
208
+ def __init__(self, n_head: int, n_feat: int):
209
+ super().__init__(n_head, n_feat)
210
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
211
+ self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
212
+ self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
213
+
214
+ def rel_shift(self, x: Tensor) -> Tensor:
215
+ b, h, qlen, pos_len = x.size()
216
+ x = torch.nn.functional.pad(x, pad=(1, 0))
217
+ x = x.view(b, h, -1, qlen)
218
+ return x[:, :, 1:].view(b, h, qlen, pos_len)
219
+
220
+ def forward(
221
+ self,
222
+ query: Tensor,
223
+ key: Tensor,
224
+ value: Tensor,
225
+ pos_emb: Tensor,
226
+ mask: Optional[Tensor] = None,
227
+ ) -> Tensor:
228
+ q, k, v = self.forward_qkv(query, key, value)
229
+ q = q.transpose(1, 2)
230
+ p = self.linear_pos(pos_emb)
231
+ p = p.view(pos_emb.shape[0], -1, self.h, self.d_k).transpose(1, 2)
232
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
233
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
234
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
235
+ matrix_bd = self.rel_shift(matrix_bd)
236
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
237
+ matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
238
+ scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
239
+ return self.forward_attention(v, scores, mask)
240
+
241
+
242
+ class RotaryPositionMultiHeadAttention(MultiHeadAttention):
243
+ """
244
+ Rotary Position Multi-Head Attention module.
245
+ """
246
+
247
+ def forward(
248
+ self,
249
+ query: Tensor,
250
+ key: Tensor,
251
+ value: Tensor,
252
+ pos_emb: List[Tensor],
253
+ mask: Optional[Tensor] = None,
254
+ ) -> Tensor:
255
+ b, t, _ = value.size()
256
+ query = query.transpose(0, 1).view(t, b, self.h, self.d_k)
257
+ key = key.transpose(0, 1).view(t, b, self.h, self.d_k)
258
+ value = value.transpose(0, 1).view(t, b, self.h, self.d_k)
259
+
260
+ cos, sin = pos_emb
261
+ query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0)
262
+
263
+ q, k, v = self.forward_qkv(
264
+ query.view(t, b, self.h * self.d_k).transpose(0, 1),
265
+ key.view(t, b, self.h * self.d_k).transpose(0, 1),
266
+ value.view(t, b, self.h * self.d_k).transpose(0, 1),
267
+ )
268
+
269
+ if not self.flash_attn:
270
+ scores = torch.matmul(q, k.transpose(-2, -1) / math.sqrt(self.d_k))
271
+ out = self.forward_attention(v, scores, mask)
272
+ else:
273
+ if mask is None:
274
+ scores = flash_attn_func(q, k, v)
275
+ else:
276
+ scores = apply_masked_flash_attn(q, k, v, mask, self.h, self.d_k)
277
+
278
+ scores = scores.view(b, -1, self.h * self.d_k)
279
+ out = self.linear_out(scores)
280
+
281
+ return out
282
+
283
+
284
+ class PositionalEncoding(nn.Module, ABC):
285
+ """
286
+ Base class of Positional Encodings.
287
+ """
288
+
289
+ def __init__(self, dim: int, base: int):
290
+ super().__init__()
291
+ self.dim = dim
292
+ self.base = base
293
+
294
+ @abstractmethod
295
+ def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
296
+ pass
297
+
298
+ def extend_pe(self, length: int, device: torch.device):
299
+ """
300
+ Extends the positional encoding buffer to process longer sequences.
301
+ """
302
+ pe = self.create_pe(length, device)
303
+ if pe is None:
304
+ return
305
+ if hasattr(self, "pe"):
306
+ self.pe = pe
307
+ else:
308
+ self.register_buffer("pe", pe, persistent=False)
309
+
310
+
311
+ class RelPositionalEmbedding(PositionalEncoding):
312
+ """
313
+ Relative Positional Embedding module.
314
+ """
315
+
316
+ def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
317
+ """
318
+ Creates the relative positional encoding matrix.
319
+ """
320
+ if hasattr(self, "pe") and self.pe.shape[1] >= 2 * length - 1:
321
+ return None
322
+ positions = torch.arange(length - 1, -length, -1, device=device).unsqueeze(1)
323
+ pos_length = positions.size(0)
324
+ pe = torch.zeros(pos_length, self.dim, device=positions.device)
325
+ div_term = torch.exp(
326
+ torch.arange(0, self.dim, 2, device=pe.device)
327
+ * -(math.log(10000.0) / self.dim)
328
+ )
329
+ pe[:, 0::2] = torch.sin(positions * div_term)
330
+ pe[:, 1::2] = torch.cos(positions * div_term)
331
+ return pe.unsqueeze(0)
332
+
333
+ def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
334
+ input_len = x.size(1)
335
+ center_pos = self.pe.size(1) // 2 + 1
336
+ start_pos = center_pos - input_len
337
+ end_pos = center_pos + input_len - 1
338
+ return x, self.pe[:, start_pos:end_pos]
339
+
340
+
341
+ class RotaryPositionalEmbedding(PositionalEncoding):
342
+ """
343
+ Rotary Positional Embedding module.
344
+ """
345
+
346
+ def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
347
+ """
348
+ Creates or extends the rotary positional encoding matrix.
349
+ """
350
+ if hasattr(self, "pe") and self.pe.size(0) >= 2 * length:
351
+ return None
352
+ positions = torch.arange(0, length, dtype=torch.float32, device=device)
353
+ inv_freq = 1.0 / (
354
+ self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
355
+ )
356
+ t = torch.arange(length, device=positions.device).type_as(inv_freq)
357
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
358
+ emb = torch.cat((freqs, freqs), dim=-1).to(positions.device)
359
+ return torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]])
360
+
361
+ def forward(self, x: torch.Tensor) -> Tuple[Tensor, List[Tensor]]:
362
+ cos_emb = self.pe[0 : x.shape[1]]
363
+ half_pe = self.pe.shape[0] // 2
364
+ sin_emb = self.pe[half_pe : half_pe + x.shape[1]]
365
+ return x, [cos_emb, sin_emb]
366
+
367
+
368
+ class ConformerConvolution(nn.Module):
369
+ """
370
+ Conformer Convolution module.
371
+ """
372
+
373
+ def __init__(
374
+ self,
375
+ d_model: int,
376
+ kernel_size: int,
377
+ ):
378
+ super().__init__()
379
+ assert (kernel_size - 1) % 2 == 0
380
+ self.pointwise_conv1 = nn.Conv1d(d_model, d_model * 2, kernel_size=1)
381
+ self.depthwise_conv = nn.Conv1d(
382
+ in_channels=d_model,
383
+ out_channels=d_model,
384
+ kernel_size=kernel_size,
385
+ padding=(kernel_size - 1) // 2,
386
+ groups=d_model,
387
+ bias=True,
388
+ )
389
+ self.batch_norm = nn.BatchNorm1d(d_model)
390
+ self.activation = nn.SiLU()
391
+ self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
392
+
393
+ def forward(self, x: Tensor, pad_mask: Optional[Tensor] = None) -> Tensor:
394
+ x = x.transpose(1, 2)
395
+ x = self.pointwise_conv1(x)
396
+ x = nn.functional.glu(x, dim=1)
397
+ if pad_mask is not None:
398
+ x = x.masked_fill(pad_mask.unsqueeze(1), 0.0)
399
+ x = self.depthwise_conv(x)
400
+ x = self.batch_norm(x)
401
+ x = self.activation(x)
402
+ x = self.pointwise_conv2(x)
403
+ return x.transpose(1, 2)
404
+
405
+
406
+ class ConformerFeedForward(nn.Module):
407
+ """
408
+ Conformer Feed Forward module.
409
+ """
410
+
411
+ def __init__(self, d_model: int, d_ff: int, use_bias=True):
412
+ super().__init__()
413
+ self.linear1 = nn.Linear(d_model, d_ff, bias=use_bias)
414
+ self.activation = nn.SiLU()
415
+ self.linear2 = nn.Linear(d_ff, d_model, bias=use_bias)
416
+
417
+ def forward(self, x: Tensor) -> Tensor:
418
+ return self.linear2(self.activation(self.linear1(x)))
419
+
420
+
421
+ class ConformerLayer(nn.Module):
422
+ """
423
+ Conformer Layer module.
424
+ This module combines several submodules including feed forward networks,
425
+ depthwise separable convolution, and multi-head self-attention
426
+ to form a single Conformer block.
427
+ """
428
+
429
+ def __init__(
430
+ self,
431
+ d_model: int,
432
+ d_ff: int,
433
+ self_attention_model: str,
434
+ n_heads: int = 16,
435
+ conv_kernel_size: int = 31,
436
+ flash_attn: bool = False,
437
+ ):
438
+ super().__init__()
439
+ self.fc_factor = 0.5
440
+ self.norm_feed_forward1 = nn.LayerNorm(d_model)
441
+ self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff)
442
+ self.norm_conv = nn.LayerNorm(d_model)
443
+ self.conv = ConformerConvolution(
444
+ d_model=d_model,
445
+ kernel_size=conv_kernel_size,
446
+ )
447
+ self.norm_self_att = nn.LayerNorm(d_model)
448
+ if self_attention_model == "rotary":
449
+ self.self_attn: nn.Module = RotaryPositionMultiHeadAttention(
450
+ n_head=n_heads,
451
+ n_feat=d_model,
452
+ flash_attn=flash_attn,
453
+ )
454
+ else:
455
+ assert not flash_attn, "Not supported flash_attn for rel_pos"
456
+ self.self_attn = RelPositionMultiHeadAttention(
457
+ n_head=n_heads,
458
+ n_feat=d_model,
459
+ )
460
+ self.norm_feed_forward2 = nn.LayerNorm(d_model)
461
+ self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff)
462
+ self.norm_out = nn.LayerNorm(d_model)
463
+
464
+ def forward(
465
+ self,
466
+ x: Tensor,
467
+ pos_emb: Union[Tensor, List[Tensor]],
468
+ att_mask: Optional[Tensor] = None,
469
+ pad_mask: Optional[Tensor] = None,
470
+ ) -> Tensor:
471
+ residual = x
472
+ x = self.norm_feed_forward1(x)
473
+ x = self.feed_forward1(x)
474
+ residual = residual + x * self.fc_factor
475
+
476
+ x = self.norm_self_att(residual)
477
+ x = self.self_attn(x, x, x, pos_emb, mask=att_mask)
478
+ residual = residual + x
479
+
480
+ x = self.norm_conv(residual)
481
+ x = self.conv(x, pad_mask=pad_mask)
482
+ residual = residual + x
483
+
484
+ x = self.norm_feed_forward2(residual)
485
+ x = self.feed_forward2(x)
486
+ residual = residual + x * self.fc_factor
487
+
488
+ x = self.norm_out(residual)
489
+ return x
490
+
491
+
492
+ class ConformerEncoder(nn.Module):
493
+ """
494
+ Conformer Encoder module.
495
+ This module encapsulates the entire Conformer encoder architecture,
496
+ consisting of a StridingSubsampling layer, positional embeddings, and
497
+ a stack of Conformer Layers.
498
+ It serves as the main component responsible for processing speech features.
499
+ """
500
+
501
+ def __init__(
502
+ self,
503
+ feat_in: int = 64,
504
+ n_layers: int = 16,
505
+ d_model: int = 768,
506
+ subsampling_factor: int = 4,
507
+ ff_expansion_factor: int = 4,
508
+ self_attention_model: str = "rotary",
509
+ n_heads: int = 16,
510
+ pos_emb_max_len: int = 5000,
511
+ conv_kernel_size: int = 31,
512
+ flash_attn: bool = False,
513
+ ):
514
+ super().__init__()
515
+ self.feat_in = feat_in
516
+ assert self_attention_model in [
517
+ "rotary",
518
+ "rel_pos",
519
+ ], f"Not supported attn = {self_attention_model}"
520
+
521
+ self.pre_encode = StridingSubsampling(
522
+ subsampling_factor=subsampling_factor,
523
+ feat_in=feat_in,
524
+ feat_out=d_model,
525
+ conv_channels=d_model,
526
+ )
527
+
528
+ if self_attention_model == "rotary":
529
+ self.pos_enc: nn.Module = RotaryPositionalEmbedding(
530
+ d_model // n_heads, pos_emb_max_len
531
+ )
532
+ else:
533
+ self.pos_enc = RelPositionalEmbedding(d_model, pos_emb_max_len)
534
+
535
+ self.layers = nn.ModuleList()
536
+ for _ in range(n_layers):
537
+ layer = ConformerLayer(
538
+ d_model=d_model,
539
+ d_ff=d_model * ff_expansion_factor,
540
+ self_attention_model=self_attention_model,
541
+ n_heads=n_heads,
542
+ conv_kernel_size=conv_kernel_size,
543
+ flash_attn=flash_attn,
544
+ )
545
+ self.layers.append(layer)
546
+
547
+ self.pos_enc.extend_pe(pos_emb_max_len, next(self.parameters()).device)
548
+
549
+ def input_example(
550
+ self,
551
+ batch_size: int = 1,
552
+ seqlen: int = 200,
553
+ ):
554
+ device = next(self.parameters()).device
555
+ features = torch.zeros(batch_size, self.feat_in, seqlen)
556
+ feature_lengths = torch.full([batch_size], features.shape[-1])
557
+ return features.float().to(device), feature_lengths.to(device)
558
+
559
+ def input_names(self):
560
+ return ["audio_signal", "length"]
561
+
562
+ def output_names(self):
563
+ return ["encoded", "encoded_len"]
564
+
565
+ def dynamic_axes(self):
566
+ return {
567
+ "audio_signal": {0: "batch_size", 2: "seq_len"},
568
+ "length": {0: "batch_size"},
569
+ "encoded": {0: "batch_size", 1: "seq_len"},
570
+ "encoded_len": {0: "batch_size"},
571
+ }
572
+
573
+ def forward(self, audio_signal: Tensor, length: Tensor) -> Tuple[Tensor, Tensor]:
574
+ audio_signal, length = self.pre_encode(
575
+ x=audio_signal.transpose(1, 2), lengths=length
576
+ )
577
+
578
+ max_len = audio_signal.size(1)
579
+ audio_signal, pos_emb = self.pos_enc(x=audio_signal)
580
+
581
+ pad_mask = torch.arange(0, max_len, device=audio_signal.device).expand(
582
+ length.size(0), -1
583
+ ) < length.unsqueeze(-1)
584
+
585
+ att_mask = None
586
+ if audio_signal.shape[0] > 1:
587
+ att_mask = pad_mask.unsqueeze(1).repeat([1, max_len, 1])
588
+ att_mask = torch.logical_and(att_mask, att_mask.transpose(1, 2))
589
+ att_mask = ~att_mask
590
+
591
+ pad_mask = ~pad_mask
592
+
593
+ for layer in self.layers:
594
+ audio_signal = layer(
595
+ x=audio_signal,
596
+ pos_emb=pos_emb,
597
+ att_mask=att_mask,
598
+ pad_mask=pad_mask,
599
+ )
600
+
601
+ return audio_signal.transpose(1, 2), length
gigaam_transformers.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
  import torch
5
  import torch.nn as nn
6
  import torchaudio
7
- from gigaam.encoder import ConformerEncoder
8
  from torch import Tensor
9
  from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor
10
  from transformers.configuration_utils import PretrainedConfig
 
4
  import torch
5
  import torch.nn as nn
6
  import torchaudio
7
+ from .encoder import ConformerEncoder
8
  from torch import Tensor
9
  from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor
10
  from transformers.configuration_utils import PretrainedConfig