waveletdeboshir committed (verified)
Commit 9b9aeff · 1 Parent(s): 02a9456

Upload 9 files
added_tokens.json ADDED
@@ -0,0 +1,5 @@
+{
+  "</s>": 35,
+  "<s>": 34
+}
+
config.json ADDED
@@ -0,0 +1,73 @@
+{
+  "auto_map": {
+    "AutoConfig": "gigaam_transformers.GigaAMConfig",
+    "AutoModel": "gigaam_transformers.GigaAMRNNTHF",
+    "AutoProcessor": "gigaam_transformers.GigaAMProcessor",
+    "AutoTokenizer": "gigaam_transformers.GigaAMTokenizer",
+    "AutoFeatureExtractor": "gigaam_transformers.GigaAMFeatureExtractor"
+  },
+
+  "encoder": {
+    "feat_in": 64,
+    "n_layers": 16,
+    "d_model": 768,
+    "subsampling_factor": 4,
+    "ff_expansion_factor": 4,
+    "self_attention_model": "rotary",
+    "pos_emb_max_len": 5000,
+    "n_heads": 16,
+    "conv_kernel_size": 31,
+    "flash_attn": false
+  },
+  "head": {
+    "decoder": {
+      "pred_hidden": 320,
+      "pred_rnn_layers": 1,
+      "num_classes": 34
+    },
+    "joint": {
+      "enc_hidden": 768,
+      "pred_hidden": 320,
+      "joint_hidden": 320,
+      "num_classes": 34
+    }
+  },
+  "labels": [
+    " ",
+    "а",
+    "б",
+    "в",
+    "г",
+    "д",
+    "е",
+    "ж",
+    "з",
+    "и",
+    "й",
+    "к",
+    "л",
+    "м",
+    "н",
+    "о",
+    "п",
+    "р",
+    "с",
+    "т",
+    "у",
+    "ф",
+    "х",
+    "ц",
+    "ч",
+    "ш",
+    "щ",
+    "ъ",
+    "ы",
+    "ь",
+    "э",
+    "ю",
+    "я"
+  ],
+  "blank_id": 33,
+  "max_symbols": 1000,
+  "model_type": "gigaam-rnnt"
+}
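Note: the encoder and head blocks above are splatted directly into ConformerEncoder(**...) and RNNTHead(**...) by gigaam_transformers.py, so blank_id, labels, and the hidden sizes have to stay mutually consistent. A minimal consistency check, purely illustrative and not part of the upload (it assumes config.json sits in the working directory):

import json

with open("config.json") as f:
    config = json.load(f)

# The joint network consumes encoder states, so enc_hidden must equal d_model.
assert config["head"]["joint"]["enc_hidden"] == config["encoder"]["d_model"]
# 33 characters plus one blank symbol; RNNTDecoder derives blank_id as num_classes - 1.
assert config["head"]["decoder"]["num_classes"] == len(config["labels"]) + 1
assert config["blank_id"] == len(config["labels"])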
encoder.py ADDED
@@ -0,0 +1,604 @@
+"""Copied from https://github.com/salute-developers/GigaAM/blob/main/gigaam/encoder.py"""
+import math
+from abc import ABC, abstractmethod
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor, nn
+
+# try:
+#     from flash_attn import flash_attn_func
+
+#     IMPORT_FLASH = True
+# except Exception as err:
+#     IMPORT_FLASH = False
+#     IMPORT_FLASH_ERR = err
+
+IMPORT_FLASH = False
+IMPORT_FLASH_ERR = "Flash Attention not installed."
+
+# from .utils import apply_masked_flash_attn, apply_rotary_pos_emb
+
+
+def rtt_half(x: Tensor) -> Tensor:
+    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+    return torch.cat([-x2, x1], dim=x1.ndim - 1)
+
+
+def apply_rotary_pos_emb(
+    q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, offset: int = 0
+) -> Tuple[Tensor, Tensor]:
+    """
+    Applies Rotary Position Embeddings to query and key tensors.
+    """
+    cos, sin = (
+        cos[offset : q.shape[0] + offset, ...],
+        sin[offset : q.shape[0] + offset, ...],
+    )
+    return (q * cos) + (rtt_half(q) * sin), (k * cos) + (rtt_half(k) * sin)
+
+
+# def apply_masked_flash_attn(
+#     q: Tensor,
+#     k: Tensor,
+#     v: Tensor,
+#     mask: Tensor,
+#     h: int,
+#     d_k: int,
+# ) -> Tensor:
+#     """
+#     Applies Flash Attention with padding masks.
+#     """
+
+#     from einops import rearrange
+#     from flash_attn import flash_attn_varlen_func
+#     from flash_attn.bert_padding import pad_input, unpad_input
+
+#     pad_mask = ~mask[:, 0, :]
+#     b, t = pad_mask.shape
+#     q = q.view(b, t, h * d_k)
+#     k = k.view(b, t, h * d_k)
+#     v = v.view(b, t, h * d_k)
+
+#     q_unpad, indices_q, _, max_seqlen_q = unpad_input(q, pad_mask)[:4]
+#     q_unpad = rearrange(q_unpad, "nnz (h d) -> nnz h d", h=h)
+
+#     k_unpad = unpad_input(k, pad_mask)[0]
+#     k_unpad = rearrange(k_unpad, "nnz (h d) -> nnz h d", h=h)
+
+#     v_unpad = unpad_input(v, pad_mask)[0]
+#     v_unpad = rearrange(v_unpad, "nnz (h d) -> nnz h d", h=h)
+
+#     lengths_q = pad_mask.sum(1).to(torch.int32).to(q.device)
+#     cu_seqlens_q = F.pad(lengths_q.cumsum(0), (1, 0), value=0).to(torch.int32)
+#     max_seqlen_q = torch.max(lengths_q)
+
+#     output_unpad = flash_attn_varlen_func(
+#         q_unpad,
+#         k_unpad,
+#         v_unpad,
+#         cu_seqlens_q,
+#         cu_seqlens_q,
+#         max_seqlen_q,
+#         max_seqlen_q,
+#     )
+
+#     scores = pad_input(
+#         rearrange(output_unpad, "nnz h d -> nnz (h d)"),
+#         indices_q,
+#         b,
+#         t,
+#     )
+
+#     return scores
+
+
+class StridingSubsampling(nn.Module):
+    """
+    Strided Subsampling layer used to reduce the sequence length.
+    """
+
+    def __init__(
+        self,
+        subsampling_factor: int,
+        feat_in: int,
+        feat_out: int,
+        conv_channels: int,
+    ):
+        super().__init__()
+        self._sampling_num = int(math.log(subsampling_factor, 2))
+        self._stride = 2
+        self._kernel_size = 3
+        self._padding = (self._kernel_size - 1) // 2
+
+        layers: List[nn.Module] = []
+        in_channels = 1
+        for _ in range(self._sampling_num):
+            layers.append(
+                torch.nn.Conv2d(
+                    in_channels=in_channels,
+                    out_channels=conv_channels,
+                    kernel_size=self._kernel_size,
+                    stride=self._stride,
+                    padding=self._padding,
+                )
+            )
+            layers.append(nn.ReLU())
+            in_channels = conv_channels
+
+        out_length = self.calc_output_length(torch.tensor(feat_in))
+        self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
+        self.conv = torch.nn.Sequential(*layers)
+
+    def calc_output_length(self, lengths: Tensor) -> Tensor:
+        """
+        Calculates the output length after applying the subsampling.
+        """
+        lengths = lengths.to(torch.float)
+        add_pad = 2 * self._padding - self._kernel_size
+        for _ in range(self._sampling_num):
+            lengths = torch.div(lengths + add_pad, self._stride) + 1.0
+            lengths = torch.floor(lengths)
+        return lengths.to(dtype=torch.int)
+
+    def forward(self, x: Tensor, lengths: Tensor) -> Tuple[Tensor, Tensor]:
+        x = self.conv(x.unsqueeze(1))
+        b, _, t, _ = x.size()
+        x = self.out(x.transpose(1, 2).reshape(b, t, -1))
+        return x, self.calc_output_length(lengths)
+
+
+class MultiHeadAttention(nn.Module, ABC):
+    """
+    Base class of Multi-Head Attention Mechanisms.
+    """
+
+    def __init__(self, n_head: int, n_feat: int, flash_attn=False):
+        super().__init__()
+        assert n_feat % n_head == 0
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_q = nn.Linear(n_feat, n_feat)
+        self.linear_k = nn.Linear(n_feat, n_feat)
+        self.linear_v = nn.Linear(n_feat, n_feat)
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.flash_attn = flash_attn
+        if self.flash_attn and not IMPORT_FLASH:
+            raise RuntimeError(
+                f"flash_attn_func was imported with err {IMPORT_FLASH_ERR}. "
+                "Please install flash_attn or use --no_flash flag. "
+                "If you have already done this, "
+                "--force-reinstall flag might be useful"
+            )
+
+    def forward_qkv(
+        self, query: Tensor, key: Tensor, value: Tensor
+    ) -> Tuple[Tensor, Tensor, Tensor]:
+        """
+        Projects the inputs into queries, keys, and values for multi-head attention.
+        """
+        b = query.size(0)
+        q = self.linear_q(query).view(b, -1, self.h, self.d_k)
+        k = self.linear_k(key).view(b, -1, self.h, self.d_k)
+        v = self.linear_v(value).view(b, -1, self.h, self.d_k)
+        if self.flash_attn:
+            return q, k, v
+        return q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+    def forward_attention(
+        self, value: Tensor, scores: Tensor, mask: Optional[Tensor]
+    ) -> Tensor:
+        """
+        Computes the scaled dot-product attention given the projected values and scores.
+        """
+        b = value.size(0)
+        if mask is not None:
+            mask = mask.unsqueeze(1)
+            scores = scores.masked_fill(mask, -10000.0)
+            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+        else:
+            attn = torch.softmax(scores, dim=-1)
+        x = torch.matmul(attn, value)
+        x = x.transpose(1, 2).reshape(b, -1, self.h * self.d_k)
+        return self.linear_out(x)
+
+
+class RelPositionMultiHeadAttention(MultiHeadAttention):
+    """
+    Relative Position Multi-Head Attention module.
+    """
+
+    def __init__(self, n_head: int, n_feat: int):
+        super().__init__(n_head, n_feat)
+        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+        self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
+        self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
+
+    def rel_shift(self, x: Tensor) -> Tensor:
+        b, h, qlen, pos_len = x.size()
+        x = torch.nn.functional.pad(x, pad=(1, 0))
+        x = x.view(b, h, -1, qlen)
+        return x[:, :, 1:].view(b, h, qlen, pos_len)
+
+    def forward(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        pos_emb: Tensor,
+        mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        q, k, v = self.forward_qkv(query, key, value)
+        q = q.transpose(1, 2)
+        p = self.linear_pos(pos_emb)
+        p = p.view(pos_emb.shape[0], -1, self.h, self.d_k).transpose(1, 2)
+        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+        matrix_bd = self.rel_shift(matrix_bd)
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+        matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
+        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, mask)
+
+
+class RotaryPositionMultiHeadAttention(MultiHeadAttention):
+    """
+    Rotary Position Multi-Head Attention module.
+    """
+
+    def forward(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        pos_emb: List[Tensor],
+        mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        b, t, _ = value.size()
+        query = query.transpose(0, 1).view(t, b, self.h, self.d_k)
+        key = key.transpose(0, 1).view(t, b, self.h, self.d_k)
+        value = value.transpose(0, 1).view(t, b, self.h, self.d_k)
+
+        cos, sin = pos_emb
+        query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0)
+
+        q, k, v = self.forward_qkv(
+            query.view(t, b, self.h * self.d_k).transpose(0, 1),
+            key.view(t, b, self.h * self.d_k).transpose(0, 1),
+            value.view(t, b, self.h * self.d_k).transpose(0, 1),
+        )
+
+        # if not self.flash_attn:
+        scores = torch.matmul(q, k.transpose(-2, -1) / math.sqrt(self.d_k))
+        out = self.forward_attention(v, scores, mask)
+        # else:
+        #     if mask is None:
+        #         scores = flash_attn_func(q, k, v)
+        #     else:
+        #         scores = apply_masked_flash_attn(q, k, v, mask, self.h, self.d_k)
+
+        #     scores = scores.view(b, -1, self.h * self.d_k)
+        #     out = self.linear_out(scores)
+
+        return out
+
+
+class PositionalEncoding(nn.Module, ABC):
+    """
+    Base class of Positional Encodings.
+    """
+
+    def __init__(self, dim: int, base: int):
+        super().__init__()
+        self.dim = dim
+        self.base = base
+
+    @abstractmethod
+    def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
+        pass
+
+    def extend_pe(self, length: int, device: torch.device):
+        """
+        Extends the positional encoding buffer to process longer sequences.
+        """
+        pe = self.create_pe(length, device)
+        if pe is None:
+            return
+        if hasattr(self, "pe"):
+            self.pe = pe
+        else:
+            self.register_buffer("pe", pe, persistent=False)
+
+
+class RelPositionalEmbedding(PositionalEncoding):
+    """
+    Relative Positional Embedding module.
+    """
+
+    def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
+        """
+        Creates the relative positional encoding matrix.
+        """
+        if hasattr(self, "pe") and self.pe.shape[1] >= 2 * length - 1:
+            return None
+        positions = torch.arange(length - 1, -length, -1, device=device).unsqueeze(1)
+        pos_length = positions.size(0)
+        pe = torch.zeros(pos_length, self.dim, device=positions.device)
+        div_term = torch.exp(
+            torch.arange(0, self.dim, 2, device=pe.device)
+            * -(math.log(10000.0) / self.dim)
+        )
+        pe[:, 0::2] = torch.sin(positions * div_term)
+        pe[:, 1::2] = torch.cos(positions * div_term)
+        return pe.unsqueeze(0)
+
+    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
+        input_len = x.size(1)
+        center_pos = self.pe.size(1) // 2 + 1
+        start_pos = center_pos - input_len
+        end_pos = center_pos + input_len - 1
+        return x, self.pe[:, start_pos:end_pos]
+
+
+class RotaryPositionalEmbedding(PositionalEncoding):
+    """
+    Rotary Positional Embedding module.
+    """
+
+    def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
+        """
+        Creates or extends the rotary positional encoding matrix.
+        """
+        if hasattr(self, "pe") and self.pe.size(0) >= 2 * length:
+            return None
+        positions = torch.arange(0, length, dtype=torch.float32, device=device)
+        inv_freq = 1.0 / (
+            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
+        )
+        t = torch.arange(length, device=positions.device).type_as(inv_freq)
+        freqs = torch.einsum("i,j->ij", t, inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1).to(positions.device)
+        return torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]])
+
+    def forward(self, x: torch.Tensor) -> Tuple[Tensor, List[Tensor]]:
+        cos_emb = self.pe[0 : x.shape[1]]
+        half_pe = self.pe.shape[0] // 2
+        sin_emb = self.pe[half_pe : half_pe + x.shape[1]]
+        return x, [cos_emb, sin_emb]
+
+
+class ConformerConvolution(nn.Module):
+    """
+    Conformer Convolution module.
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        kernel_size: int,
+    ):
+        super().__init__()
+        assert (kernel_size - 1) % 2 == 0
+        self.pointwise_conv1 = nn.Conv1d(d_model, d_model * 2, kernel_size=1)
+        self.depthwise_conv = nn.Conv1d(
+            in_channels=d_model,
+            out_channels=d_model,
+            kernel_size=kernel_size,
+            padding=(kernel_size - 1) // 2,
+            groups=d_model,
+            bias=True,
+        )
+        self.batch_norm = nn.BatchNorm1d(d_model)
+        self.activation = nn.SiLU()
+        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
+
+    def forward(self, x: Tensor, pad_mask: Optional[Tensor] = None) -> Tensor:
+        x = x.transpose(1, 2)
+        x = self.pointwise_conv1(x)
+        x = nn.functional.glu(x, dim=1)
+        if pad_mask is not None:
+            x = x.masked_fill(pad_mask.unsqueeze(1), 0.0)
+        x = self.depthwise_conv(x)
+        x = self.batch_norm(x)
+        x = self.activation(x)
+        x = self.pointwise_conv2(x)
+        return x.transpose(1, 2)
+
+
+class ConformerFeedForward(nn.Module):
+    """
+    Conformer Feed Forward module.
+    """
+
+    def __init__(self, d_model: int, d_ff: int, use_bias=True):
+        super().__init__()
+        self.linear1 = nn.Linear(d_model, d_ff, bias=use_bias)
+        self.activation = nn.SiLU()
+        self.linear2 = nn.Linear(d_ff, d_model, bias=use_bias)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.linear2(self.activation(self.linear1(x)))
+
+
+class ConformerLayer(nn.Module):
+    """
+    Conformer Layer module.
+    This module combines several submodules including feed forward networks,
+    depthwise separable convolution, and multi-head self-attention
+    to form a single Conformer block.
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        d_ff: int,
+        self_attention_model: str,
+        n_heads: int = 16,
+        conv_kernel_size: int = 31,
+        flash_attn: bool = False,
+    ):
+        super().__init__()
+        self.fc_factor = 0.5
+        self.norm_feed_forward1 = nn.LayerNorm(d_model)
+        self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff)
+        self.norm_conv = nn.LayerNorm(d_model)
+        self.conv = ConformerConvolution(
+            d_model=d_model,
+            kernel_size=conv_kernel_size,
+        )
+        self.norm_self_att = nn.LayerNorm(d_model)
+        if self_attention_model == "rotary":
+            self.self_attn: nn.Module = RotaryPositionMultiHeadAttention(
+                n_head=n_heads,
+                n_feat=d_model,
+                flash_attn=flash_attn,
+            )
+        else:
+            assert not flash_attn, "Not supported flash_attn for rel_pos"
+            self.self_attn = RelPositionMultiHeadAttention(
+                n_head=n_heads,
+                n_feat=d_model,
+            )
+        self.norm_feed_forward2 = nn.LayerNorm(d_model)
+        self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff)
+        self.norm_out = nn.LayerNorm(d_model)
+
+    def forward(
+        self,
+        x: Tensor,
+        pos_emb: Union[Tensor, List[Tensor]],
+        att_mask: Optional[Tensor] = None,
+        pad_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        residual = x
+        x = self.norm_feed_forward1(x)
+        x = self.feed_forward1(x)
+        residual = residual + x * self.fc_factor
+
+        x = self.norm_self_att(residual)
+        x = self.self_attn(x, x, x, pos_emb, mask=att_mask)
+        residual = residual + x
+
+        x = self.norm_conv(residual)
+        x = self.conv(x, pad_mask=pad_mask)
+        residual = residual + x
+
+        x = self.norm_feed_forward2(residual)
+        x = self.feed_forward2(x)
+        residual = residual + x * self.fc_factor
+
+        x = self.norm_out(residual)
+        return x
+
+
+class ConformerEncoder(nn.Module):
+    """
+    Conformer Encoder module.
+    This module encapsulates the entire Conformer encoder architecture,
+    consisting of a StridingSubsampling layer, positional embeddings, and
+    a stack of Conformer Layers.
+    It serves as the main component responsible for processing speech features.
+    """
+
+    def __init__(
+        self,
+        feat_in: int = 64,
+        n_layers: int = 16,
+        d_model: int = 768,
+        subsampling_factor: int = 4,
+        ff_expansion_factor: int = 4,
+        self_attention_model: str = "rotary",
+        n_heads: int = 16,
+        pos_emb_max_len: int = 5000,
+        conv_kernel_size: int = 31,
+        flash_attn: bool = False,
+    ):
+        super().__init__()
+        self.feat_in = feat_in
+        assert self_attention_model in [
+            "rotary",
+            "rel_pos",
+        ], f"Not supported attn = {self_attention_model}"
+
+        self.pre_encode = StridingSubsampling(
+            subsampling_factor=subsampling_factor,
+            feat_in=feat_in,
+            feat_out=d_model,
+            conv_channels=d_model,
+        )
+
+        if self_attention_model == "rotary":
+            self.pos_enc: nn.Module = RotaryPositionalEmbedding(
+                d_model // n_heads, pos_emb_max_len
+            )
+        else:
+            self.pos_enc = RelPositionalEmbedding(d_model, pos_emb_max_len)
+
+        self.layers = nn.ModuleList()
+        for _ in range(n_layers):
+            layer = ConformerLayer(
+                d_model=d_model,
+                d_ff=d_model * ff_expansion_factor,
+                self_attention_model=self_attention_model,
+                n_heads=n_heads,
+                conv_kernel_size=conv_kernel_size,
+                flash_attn=flash_attn,
+            )
+            self.layers.append(layer)
+
+        self.pos_enc.extend_pe(pos_emb_max_len, next(self.parameters()).device)
+
+    def input_example(
+        self,
+        batch_size: int = 1,
+        seqlen: int = 200,
+    ):
+        device = next(self.parameters()).device
+        features = torch.zeros(batch_size, self.feat_in, seqlen)
+        feature_lengths = torch.full([batch_size], features.shape[-1])
+        return features.float().to(device), feature_lengths.to(device)
+
+    def input_names(self):
+        return ["audio_signal", "length"]
+
+    def output_names(self):
+        return ["encoded", "encoded_len"]
+
+    def dynamic_axes(self):
+        return {
+            "audio_signal": {0: "batch_size", 2: "seq_len"},
+            "length": {0: "batch_size"},
+            "encoded": {0: "batch_size", 1: "seq_len"},
+            "encoded_len": {0: "batch_size"},
+        }
+
+    def forward(self, audio_signal: Tensor, length: Tensor) -> Tuple[Tensor, Tensor]:
+        audio_signal, length = self.pre_encode(
+            x=audio_signal.transpose(1, 2), lengths=length
+        )
+
+        max_len = audio_signal.size(1)
+        audio_signal, pos_emb = self.pos_enc(x=audio_signal)
+
+        pad_mask = torch.arange(0, max_len, device=audio_signal.device).expand(
+            length.size(0), -1
+        ) < length.unsqueeze(-1)
+
+        att_mask = None
+        if audio_signal.shape[0] > 1:
+            att_mask = pad_mask.unsqueeze(1).repeat([1, max_len, 1])
+            att_mask = torch.logical_and(att_mask, att_mask.transpose(1, 2))
+            att_mask = ~att_mask
+
+        pad_mask = ~pad_mask
+
+        for layer in self.layers:
+            audio_signal = layer(
+                x=audio_signal,
+                pos_emb=pos_emb,
+                att_mask=att_mask,
+                pad_mask=pad_mask,
+            )
+
+        return audio_signal.transpose(1, 2), length
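Note: a small standalone sketch (not part of the upload) of the StridingSubsampling length arithmetic above. Each Conv2d stage uses kernel_size=3, stride=2, padding=1, so subsampling_factor=4 applies two stages and maps 3001 mel frames (a 30 s chunk at a 10 ms hop) to 751 encoder frames.

import math

def subsampled_length(n_frames: int, subsampling_factor: int = 4) -> int:
    # Mirrors StridingSubsampling.calc_output_length for kernel_size=3, stride=2, padding=1.
    kernel_size, stride, padding = 3, 2, 1
    add_pad = 2 * padding - kernel_size  # -1
    for _ in range(int(math.log(subsampling_factor, 2))):
        n_frames = math.floor((n_frames + add_pad) / stride + 1)
    return n_frames

print(subsampled_length(3001))  # 751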
gigaam_transformers.py ADDED
@@ -0,0 +1,447 @@
+from typing import Dict, List, Optional, Union, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torchaudio
+from .encoder import ConformerEncoder
+from torch import Tensor
+from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor
+from transformers.configuration_utils import PretrainedConfig
+from transformers.feature_extraction_sequence_utils import \
+    SequenceFeatureExtractor
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.modeling_outputs import CausalLMOutput, Seq2SeqLMOutput
+from transformers.modeling_utils import PreTrainedModel
+
+
+class GigaAMCTC(nn.Module):
+    """
+    GigaAM-CTC model
+    """
+
+    def __init__(self, config_encoder, config_head):
+        super().__init__()
+        self.encoder = ConformerEncoder(**config_encoder)
+        self.head = CTCHead(**config_head)
+
+    def forward(self, input_features: Tensor, input_lengths: Tensor) -> Tensor:
+        encoded, encoded_lengths = self.encoder(input_features, input_lengths)
+        logits = self.head(encoded)
+        return logits, encoded_lengths
+
+
+class GigaAMRNNT(nn.Module):
+    """
+    GigaAM-RNNT model
+    """
+
+    def __init__(self, config_encoder, config_head):
+        super().__init__()
+        self.encoder = ConformerEncoder(**config_encoder)
+        self.head = RNNTHead(**config_head)
+
+    def forward(self, input_features: Tensor, input_lengths: Tensor, targets: Tensor, target_lengths: Tensor) -> Tensor:
+        encoded, encoded_lengths = self.encoder(input_features, input_lengths)
+        # During training, loss must be computed, so decoder forward is necessary
+        decoder_out, target_lengths, states = self.head.decoder(targets=targets, target_length=target_lengths)
+        joint = self.head.joint(encoder_outputs=encoded, decoder_outputs=decoder_out)
+        # loss = self.loss(
+        #     log_probs=joint, targets=targets, input_lengths=encoded_lengths, target_lengths=target_lengths
+        # )
+
+        return joint, encoded_lengths
+
+
+class CTCHead(nn.Module):
+    """
+    CTC Head module for Connectionist Temporal Classification.
+    """
+
+    def __init__(self, feat_in: int, num_classes: int):
+        super().__init__()
+        self.decoder_layers = nn.Sequential(
+            nn.Conv1d(feat_in, num_classes, kernel_size=1)
+        )
+
+    def forward(self, encoder_output: Tensor) -> Tensor:
+        # B x C x T
+        return self.decoder_layers(encoder_output)
+
+
+class RNNTJoint(nn.Module):
+    """
+    RNN-Transducer Joint Network Module.
+    This module combines the outputs of the encoder and the prediction network using
+    a linear transformation followed by ReLU activation and another linear projection.
+    """
+
+    def __init__(
+        self, enc_hidden: int, pred_hidden: int, joint_hidden: int, num_classes: int
+    ):
+        super().__init__()
+        self.enc_hidden = enc_hidden
+        self.pred_hidden = pred_hidden
+        self.pred = nn.Linear(pred_hidden, joint_hidden)
+        self.enc = nn.Linear(enc_hidden, joint_hidden)
+        self.joint_net = nn.Sequential(nn.ReLU(), nn.Linear(joint_hidden, num_classes))
+
+    def joint(self, encoder_out: Tensor, decoder_out: Tensor) -> Tensor:
+        """
+        Combine the encoder and prediction network outputs into a joint representation.
+        """
+        enc = self.enc(encoder_out).unsqueeze(2)
+        pred = self.pred(decoder_out).unsqueeze(1)
+        return self.joint_net(enc + pred)
+
+    def input_example(self):
+        device = next(self.parameters()).device
+        enc = torch.zeros(1, self.enc_hidden, 1)
+        dec = torch.zeros(1, self.pred_hidden, 1)
+        return enc.float().to(device), dec.float().to(device)
+
+    def input_names(self):
+        return ["enc", "dec"]
+
+    def output_names(self):
+        return ["joint"]
+
+    def forward(self, enc: Tensor, dec: Tensor) -> Tensor:
+        return self.joint(enc.transpose(1, 2), dec.transpose(1, 2))
+
+
+class RNNTDecoder(nn.Module):
+    """
+    RNN-Transducer Decoder Module.
+    This module handles the prediction network part of the RNN-Transducer architecture.
+    """
+
+    def __init__(self, pred_hidden: int, pred_rnn_layers: int, num_classes: int):
+        super().__init__()
+        self.blank_id = num_classes - 1
+        self.pred_hidden = pred_hidden
+        self.embed = nn.Embedding(num_classes, pred_hidden, padding_idx=self.blank_id)
+        self.lstm = nn.LSTM(pred_hidden, pred_hidden, pred_rnn_layers)
+
+    def predict(
+        self,
+        x: Optional[Tensor],
+        state: Optional[Tensor],
+        batch_size: int = 1,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Make predictions based on the current input and previous states.
+        If no input is provided, use zeros as the initial input.
+        """
+        if x is not None:
+            emb: Tensor = self.embed(x)
+        else:
+            emb = torch.zeros(
+                (batch_size, 1, self.pred_hidden), device=next(self.parameters()).device
+            )
+        g, hid = self.lstm(emb.transpose(0, 1), state)
+        return g.transpose(0, 1), hid
+
+    def input_example(self):
+        device = next(self.parameters()).device
+        label = torch.tensor([[0]]).to(device)
+        hidden_h = torch.zeros(1, 1, self.pred_hidden).to(device)
+        hidden_c = torch.zeros(1, 1, self.pred_hidden).to(device)
+        return label, hidden_h, hidden_c
+
+    def input_names(self):
+        return ["x", "h", "c"]
+
+    def output_names(self):
+        return ["dec", "h", "c"]
+
+    def forward(self, x: Tensor, h: Tensor, c: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+        """
+        ONNX-specific forward with x, state = (h, c) -> x, h, c.
+        """
+        emb = self.embed(x)
+        g, (h, c) = self.lstm(emb.transpose(0, 1), (h, c))
+        return g.transpose(0, 1), h, c
+
+
+class RNNTHead(nn.Module):
+    """
+    RNN-Transducer Head Module.
+    This module combines the decoder and joint network components of the RNN-Transducer architecture.
+    """
+
+    def __init__(self, decoder: Dict[str, int], joint: Dict[str, int]):
+        super().__init__()
+        self.decoder = RNNTDecoder(**decoder)
+        self.joint = RNNTJoint(**joint)
+
+
+class GigaAMFeatureExtractor(SequenceFeatureExtractor):
+    """
+    Feature extractor for GigaAM.
+    """
+    model_input_names = ["input_features"]
+
+    def __init__(
+        self,
+        feature_size=64,
+        sampling_rate=16000,
+        padding_value=0.0,
+        chunk_length=30.0,
+        **kwargs,
+    ):
+        super().__init__(
+            feature_size=feature_size,
+            sampling_rate=sampling_rate,
+            padding_value=padding_value,
+            chunk_length=chunk_length,
+            **kwargs,
+        )
+        self.hop_length = sampling_rate // 100
+        self.n_samples = chunk_length * sampling_rate
+        self.featurizer = torchaudio.transforms.MelSpectrogram(
+            sample_rate=sampling_rate,
+            n_fft=sampling_rate // 40,
+            win_length=sampling_rate // 40,
+            hop_length=self.hop_length,
+            n_mels=feature_size,
+        )
+
+    def to_dict(self) -> Dict[str, Union[str, int, Dict]]:
+        dictionary = super().to_dict()
+
+        if "featurizer" in dictionary:
+            del dictionary["featurizer"]
+        dictionary["hop_length"] = self.hop_length
+        dictionary["n_samples"] = self.n_samples
+        return dictionary
+
+    def out_len(self, input_lengths: Tensor) -> Tensor:
+        """
+        Calculates the output length after the feature extraction process.
+        """
+        return input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()
+
+    def __call__(
+        self,
+        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
+        sampling_rate: Optional[int] = None,
+        padding: str = "max_length",
+        **kwargs,
+    ):
+        is_batched_numpy = (
+            isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+        )
+        if is_batched_numpy and len(raw_speech.shape) > 2:
+            raise ValueError(
+                f"Only mono-channel audio is supported for input to {self}"
+            )
+        is_batched = is_batched_numpy or (
+            isinstance(raw_speech, (list, tuple))
+            and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
+        )
+
+        if is_batched:
+            raw_speech = [
+                np.asarray([speech], dtype=np.float32).T for speech in raw_speech
+            ]
+        elif not is_batched and not isinstance(raw_speech, np.ndarray):
+            raw_speech = np.asarray(raw_speech, dtype=np.float32)
+        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(
+            np.float64
+        ):
+            raw_speech = raw_speech.astype(np.float32)
+
+        # always return batch
+        if not is_batched:
+            raw_speech = [np.asarray([raw_speech]).T]
+
+        input_lengths = torch.tensor([len(speech) for speech in raw_speech])
+
+        batched_speech = BatchFeature({"input_features": raw_speech})
+
+        padded_inputs = self.pad(
+            batched_speech,
+            padding=padding,
+            max_length=self.n_samples,
+            truncation=False,
+            return_tensors="pt",
+        )
+
+        input_features = padded_inputs["input_features"].transpose(1, 2)
+        input_features = self.featurizer(input_features).squeeze(1)
+        input_features = torch.log(input_features.clamp_(1e-9, 1e9))
+        input_lengths = self.out_len(input_lengths)
+
+        return BatchFeature({"input_features": input_features, "input_lengths": input_lengths}, tensor_type="pt")
+
+
+class GigaAMCTCTokenizer(Wav2Vec2CTCTokenizer):
+    """
+    Char tokenizer for GigaAM-CTC model.
+    """
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="[BLANK]",
+        pad_token="[BLANK]",
+        bos_token=None,
+        eos_token=None,
+        word_delimiter_token=" ",
+        **kwargs,
+    ):
+        super().__init__(
+            vocab_file=vocab_file,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            word_delimiter_token=word_delimiter_token,
+            **kwargs,
+        )
+
+
+class GigaAMProcessor(Wav2Vec2Processor):
+    feature_extractor_class = "GigaAMFeatureExtractor"
+    tokenizer_class = "GigaAMCTCTokenizer"
+
+    def __init__(self, feature_extractor, tokenizer):
+        # super().__init__(feature_extractor, tokenizer)
+        self.feature_extractor = feature_extractor
+        self.tokenizer = tokenizer
+        self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        feature_extractor = GigaAMFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        tokenizer = GigaAMCTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+
+class GigaAMConfig(PretrainedConfig):
+    model_type = "gigaam-ctc"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+class GigaAMCTCHF(PreTrainedModel):
+    """
+    GigaAM-CTC model for transformers
+    """
+    config_class = GigaAMConfig
+    base_model_prefix = "gigaamctc"
+    main_input_name = "input_features"
+
+    def __init__(self, config: GigaAMConfig):
+        super().__init__(config)
+        self.model = GigaAMCTC(config.encoder, config.head)
+
+    def forward(self, input_features, input_lengths, labels=None, **kwargs):
+
+        # B x C x T
+        logits, encoded_lengths = self.model(input_features, input_lengths)
+        # B x C x T -> B x T x C -> T x B x C
+        log_probs = torch.log_softmax(
+            logits.transpose(1, 2), dim=-1, dtype=torch.float32
+        ).transpose(0, 1)
+
+        loss = None
+        if labels is not None:
+            labels_mask = labels >= 0
+            target_lengths = labels_mask.sum(-1)
+            flattened_targets = labels.masked_select(labels_mask)
+
+            loss = nn.functional.ctc_loss(
+                log_probs,
+                flattened_targets,
+                encoded_lengths,
+                target_lengths,
+                blank=self.config.blank_id,
+                zero_infinity=True,
+            )
+
+        return CausalLMOutput(loss=loss, logits=logits.transpose(1, 2))
+
+
+class GigaAMRNNTHF(PreTrainedModel):
+    """
+    GigaAM-RNNT model for transformers
+    """
+    config_class = GigaAMConfig
+    base_model_prefix = "gigaamrnnt"
+    main_input_name = "input_features"
+
+    def __init__(self, config: GigaAMConfig):
+        super().__init__(config)
+        self.model = GigaAMRNNT(config.encoder, config.head)
+
+    def forward(self, input_features, input_lengths, labels=None, **kwargs):
+
+        # B x C x T
+        encoder_out, encoded_lengths = self.model.encoder(input_features, input_lengths)
+        encoder_out = encoder_out.transpose(1, 2)
+        batch_size = encoder_out.shape[0]
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(torch.int32)
+            labels_mask = labels >= 0
+            target_lengths = labels_mask.sum(-1).to(torch.int32)
+
+            hidden_states = torch.zeros((self.config.head["decoder"]["pred_rnn_layers"], batch_size, self.model.head.decoder.pred_hidden), device=encoder_out.device)
+            hidden_c = torch.zeros((self.config.head["decoder"]["pred_rnn_layers"], batch_size, self.model.head.decoder.pred_hidden), device=encoder_out.device)
+            plus_one_dim = self.config.blank_id * torch.ones((batch_size, 1), dtype=torch.int32, device=encoder_out.device)
+
+            decoder_out, h, c = self.model.head.decoder(torch.cat((plus_one_dim, labels), dim=1), hidden_states, hidden_c)
+
+            joint = self.model.head.joint.joint(encoder_out, decoder_out)
+            loss = torchaudio.functional.rnnt_loss(
+                logits=joint,
+                targets=labels,
+                logit_lengths=encoded_lengths,
+                target_lengths=target_lengths,
+                blank=self.config.blank_id,
+            )
+
+        return Seq2SeqLMOutput(loss=loss, logits=encoder_out.transpose(1, 2))
+
+    def _greedy_decode(self, x: Tensor, seqlen: Tensor) -> str:
+        """
+        Internal helper function for performing greedy decoding on a single sequence.
+        """
+        hyp: List[int] = []
+        dec_state: Optional[Tensor] = None
+        last_label: Optional[Tensor] = None
+        for t in range(seqlen):
+            f = x[t, :, :].unsqueeze(1)
+            not_blank = True
+            new_symbols = 0
+            while not_blank and new_symbols < self.config.max_symbols:
+                g, hidden = self.model.head.decoder.predict(last_label, dec_state)
+                k = self.model.head.joint.joint(f, g)[0, 0, 0, :].argmax(0).item()
+                if k == self.config.blank_id:
+                    not_blank = False
+                else:
+                    hyp.append(k)
+                    dec_state = hidden
+                    last_label = torch.tensor([[hyp[-1]]]).to(x.device)
+                    new_symbols += 1
+
+        return torch.tensor([hyp], dtype=torch.int32)
+
+    @torch.inference_mode()
+    def generate(self, input_features: Tensor, input_lengths: Tensor, **kwargs) -> torch.Tensor:
+        """
+        Decode the output of an RNN-T model into a list of hypotheses.
+        """
+        encoder_out, encoded_lengths = self.model.encoder(input_features, input_lengths)
+        encoder_out = encoder_out.transpose(1, 2)
+        b = encoder_out.shape[0]
+        preds = []
+        for i in range(b):
+            inseq = encoder_out[i, :, :].unsqueeze(1)
+            preds.append(self._greedy_decode(inseq, encoded_lengths[i]))
+        return torch.cat(preds, dim=1)
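Note: a rough end-to-end usage sketch (not part of the upload). The repo id and the dummy waveform are placeholders, and trust_remote_code=True is assumed so that the auto_map entries in config.json resolve to the classes above. The RNNT greedy search is exposed through generate(); group_tokens=False keeps repeated characters, since the output has no CTC repeats to collapse.

import numpy as np
import torch
from transformers import AutoModel, AutoProcessor

repo_id = "your-namespace/gigaam-rnnt"  # placeholder repo id
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)   # -> GigaAMProcessor
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()    # -> GigaAMRNNTHF

waveform = np.zeros(16000, dtype=np.float32)    # stand-in for 1 s of 16 kHz mono audio
inputs = processor.feature_extractor(waveform)  # log-mel "input_features" + "input_lengths"
with torch.inference_mode():
    token_ids = model.generate(inputs["input_features"], inputs["input_lengths"])
text = processor.tokenizer.decode(token_ids[0], group_tokens=False)
print(text)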
preprocessor_config.json ADDED
@@ -0,0 +1,18 @@
+{
+  "chunk_length": 30,
+  "feature_extractor_type": "GigaAMFeatureExtractor",
+  "feature_extractor_class": "GigaAMFeatureExtractor",
+  "feature_size": 64,
+  "hop_length": 160,
+  "n_samples": 480000,
+  "padding_side": "right",
+  "padding_value": 0.0,
+  "return_attention_mask": true,
+  "sampling_rate": 16000,
+  "auto_map": {
+    "AutoFeatureExtractor": "gigaam_transformers.GigaAMFeatureExtractor",
+    "AutoProcessor": "gigaam_transformers.GigaAMProcessor"
+  },
+  "processor_class": "GigaAMProcessor",
+  "model_type": "gigaam-rnnt"
+}
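Note: a quick check of the numbers above (illustrative, not part of the upload); they follow directly from the GigaAMFeatureExtractor defaults and its out_len() formula.

sampling_rate, chunk_length = 16000, 30
hop_length = sampling_rate // 100         # 160 samples = 10 ms
n_samples = chunk_length * sampling_rate  # 480000
print(n_samples // hop_length + 1)        # 3001 log-mel frames per padded 30 s chunk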
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:16888332b279f5296b411f5bbe5385f20448d261b2c9a6425e1a83026efb2018
+size 935306114
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
+{
+  "pad_token": "[BLANK]",
+  "unk_token": "[BLANK]"
+}
+
tokenizer_config.json ADDED
@@ -0,0 +1,27 @@
+{
+  "added_tokens_decoder": {
+    "33": {
+      "content": "[BLANK]",
+      "lstrip": true,
+      "normalized": false,
+      "rstrip": true,
+      "single_word": false,
+      "special": false
+    }
+  },
+  "bos_token": null,
+  "clean_up_tokenization_spaces": false,
+  "do_lower_case": false,
+  "eos_token": null,
+  "model_max_length": 1000,
+  "pad_token": "[BLANK]",
+  "replace_word_delimiter_char": " ",
+  "target_lang": null,
+  "tokenizer_class": "GigaAMTokenizer",
+  "unk_token": "[BLANK]",
+  "word_delimiter_token": " ",
+  "auto_map": {
+    "AutoTokenizer": ["gigaam_transformers.GigaAMTokenizer", null]
+  }
+}
+
vocab.json ADDED
@@ -0,0 +1,37 @@
+{
+  " ": 0,
+  "[BLANK]": 33,
+  "а": 1,
+  "б": 2,
+  "в": 3,
+  "г": 4,
+  "д": 5,
+  "е": 6,
+  "ж": 7,
+  "з": 8,
+  "и": 9,
+  "й": 10,
+  "к": 11,
+  "л": 12,
+  "м": 13,
+  "н": 14,
+  "о": 15,
+  "п": 16,
+  "р": 17,
+  "с": 18,
+  "т": 19,
+  "у": 20,
+  "ф": 21,
+  "х": 22,
+  "ц": 23,
+  "ч": 24,
+  "ш": 25,
+  "щ": 26,
+  "ъ": 27,
+  "ы": 28,
+  "ь": 29,
+  "э": 30,
+  "ю": 31,
+  "я": 32
+}
+