waveletdeboshir committed
Commit 9a323c7 · verified · 1 Parent(s): 6214cc3

Upload gigaam_transformers.py

Files changed (1): gigaam_transformers.py (+235, -0)
gigaam_transformers.py ADDED
@@ -0,0 +1,235 @@
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchaudio
from gigaam.encoder import ConformerEncoder
from torch import Tensor
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor
from transformers.configuration_utils import PretrainedConfig
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import CausalLMOutput
from transformers.modeling_utils import PreTrainedModel


class GigaAMCTC(nn.Module):
    """
    GigaAM-CTC model: a Conformer encoder followed by a CTC head.
    """

    def __init__(self, config_encoder, config_head):
        super().__init__()
        self.encoder = ConformerEncoder(**config_encoder)
        self.head = CTCHead(**config_head)

    def forward(self, input_features: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
        encoded, encoded_lengths = self.encoder(input_features, input_lengths)
        logits = self.head(encoded)
        return logits, encoded_lengths
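

# Shape flow (sketch, following the B x C x T comments used below): the
# log-mel input (B, n_mels, T) is encoded to (B, d_enc, T') by the
# ConformerEncoder and projected to per-frame logits (B, num_classes, T')
# by the CTC head; T' is the subsampled frame count the encoder returns
# as encoded_lengths.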


class CTCHead(nn.Module):
    """
    CTC head for Connectionist Temporal Classification: a pointwise
    (kernel_size=1) Conv1d, i.e. a per-frame linear projection of encoder
    features onto class logits.
    """

    def __init__(self, feat_in: int, num_classes: int):
        super().__init__()
        self.decoder_layers = nn.Sequential(
            nn.Conv1d(feat_in, num_classes, kernel_size=1)
        )

    def forward(self, encoder_output: Tensor) -> Tensor:
        # encoder_output: B x C x T
        return self.decoder_layers(encoder_output)


class GigaAMFeatureExtractor(SequenceFeatureExtractor):
    """
    Feature extractor for GigaAM: pads/batches raw waveforms and converts
    them to log-mel spectrograms.
    """

    model_input_names = ["input_features"]

    def __init__(
        self,
        feature_size=64,
        sampling_rate=16000,
        padding_value=0.0,
        chunk_length=30.0,
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            chunk_length=chunk_length,
            **kwargs,
        )
        # 10 ms hop and 25 ms window: 160 and 400 samples at 16 kHz.
        self.hop_length = sampling_rate // 100
        self.n_samples = int(chunk_length * sampling_rate)
        self.featurizer = torchaudio.transforms.MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=sampling_rate // 40,
            win_length=sampling_rate // 40,
            hop_length=self.hop_length,
            n_mels=feature_size,
        )

    def to_dict(self) -> Dict[str, Union[str, int, Dict]]:
        dictionary = super().to_dict()

        # The torchaudio transform is not JSON-serializable; store its
        # derived parameters instead.
        if "featurizer" in dictionary:
            del dictionary["featurizer"]
        dictionary["hop_length"] = self.hop_length
        dictionary["n_samples"] = self.n_samples
        return dictionary

    def out_len(self, input_lengths: Tensor) -> Tensor:
        """
        Calculates the number of output frames after feature extraction.
        """
        return input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        sampling_rate: Optional[int] = None,
        padding: str = "max_length",
        **kwargs,
    ):
        is_batched_numpy = (
            isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        )
        if is_batched_numpy and len(raw_speech.shape) > 2:
            raise ValueError(
                f"Only mono-channel audio is supported for input to {self}"
            )
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple))
            and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            raw_speech = [
                np.asarray([speech], dtype=np.float32).T for speech in raw_speech
            ]
        elif not is_batched and not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(
            np.float64
        ):
            raw_speech = raw_speech.astype(np.float32)

        # always return a batch
        if not is_batched:
            raw_speech = [np.asarray([raw_speech]).T]

        input_lengths = torch.tensor([len(speech) for speech in raw_speech])

        batched_speech = BatchFeature({"input_features": raw_speech})

        padded_inputs = self.pad(
            batched_speech,
            padding=padding,
            max_length=self.n_samples,
            truncation=False,
            return_tensors="pt",
        )

        input_features = padded_inputs["input_features"].transpose(1, 2)
        input_features = self.featurizer(input_features).squeeze(1)
        # Clamp before the log to avoid -inf on padded (all-zero) frames.
        input_features = torch.log(input_features.clamp_(1e-9, 1e9))
        input_lengths = self.out_len(input_lengths)

        return BatchFeature(
            {"input_features": input_features, "input_lengths": input_lengths},
            tensor_type="pt",
        )
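

# Example (sketch): with the defaults above, 1 s of 16 kHz mono audio
# (16000 samples) is padded to the 30 s maximum (480000 samples) and
# featurized to input_features of shape (1, 64, 3001); input_lengths
# records the floor(16000 / 160) + 1 = 101 valid frames.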


class GigaAMCTCTokenizer(Wav2Vec2CTCTokenizer):
    """
    Character-level tokenizer for the GigaAM-CTC model.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="[BLANK]",
        pad_token="[BLANK]",
        bos_token=None,
        eos_token=None,
        word_delimiter_token=" ",
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            word_delimiter_token=word_delimiter_token,
            **kwargs,
        )
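

# Note: "[BLANK]" doubles as the pad and CTC blank token. Wav2Vec2CTCTokenizer
# decoding collapses repeated tokens and strips the pad token, which is
# exactly the greedy CTC collapse rule.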


class GigaAMProcessor(Wav2Vec2Processor):
    feature_extractor_class = "GigaAMFeatureExtractor"
    tokenizer_class = "GigaAMCTCTokenizer"

    def __init__(self, feature_extractor, tokenizer):
        # Skip Wav2Vec2Processor.__init__: its class checks would fail for
        # custom classes not registered with transformers. Set the same
        # attributes it would normally set.
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.current_processor = self.feature_extractor
        self._in_target_context_manager = False

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        feature_extractor = GigaAMFeatureExtractor.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        tokenizer = GigaAMCTCTokenizer.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )

        return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)


class GigaAMConfig(PretrainedConfig):
    """Configuration container for GigaAM models."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
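

# Note: GigaAMConfig simply forwards **kwargs to PretrainedConfig; the fields
# the model relies on (config.encoder, config.head, config.blank_id) are
# therefore expected to come from the checkpoint's config.json.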


class GigaAMCTCHF(PreTrainedModel):
    """
    GigaAM-CTC model for transformers.
    """

    config_class = GigaAMConfig
    base_model_prefix = "gigaamctc"
    main_input_name = "input_features"

    def __init__(self, config: GigaAMConfig):
        super().__init__(config)
        self.model = GigaAMCTC(config.encoder, config.head)

    def forward(self, input_features, input_lengths, labels=None, **kwargs):
        # logits: B x C x T
        logits, encoded_lengths = self.model(input_features, input_lengths)
        # B x C x T -> B x T x C -> T x B x C, the layout ctc_loss expects
        log_probs = torch.log_softmax(
            logits.transpose(1, 2), dim=-1, dtype=torch.float32
        ).transpose(0, 1)

        loss = None
        if labels is not None:
            # Label padding is encoded as values < 0; mask it out and
            # flatten the targets as required by nn.functional.ctc_loss.
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            loss = nn.functional.ctc_loss(
                log_probs,
                flattened_targets,
                encoded_lengths,
                target_lengths,
                blank=self.config.blank_id,
                zero_infinity=True,
            )

        return CausalLMOutput(loss=loss, logits=logits.transpose(1, 2))
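
A minimal end-to-end sketch of how the pieces above fit together (not part of the commit; the checkpoint path and audio file are placeholders):

import torch
import torchaudio

# Hypothetical checkpoint: any repo/dir with the expected config, weights
# and vocab files would do.
repo = "path/to/gigaam-ctc-checkpoint"

processor = GigaAMProcessor.from_pretrained(repo)
model = GigaAMCTCHF.from_pretrained(repo)
model.eval()

waveform, sr = torchaudio.load("audio.wav")  # mono 16 kHz audio assumed
inputs = processor(waveform.squeeze(0).numpy(), sampling_rate=sr)

with torch.no_grad():
    logits = model(**inputs).logits       # B x T x C
predicted_ids = logits.argmax(dim=-1)     # greedy CTC path
print(processor.batch_decode(predicted_ids))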