Upload gigaam_transformers.py
gigaam_transformers.py (ADDED: +235 -0)
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchaudio
from gigaam.encoder import ConformerEncoder
from torch import Tensor
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor
from transformers.configuration_utils import PretrainedConfig
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import CausalLMOutput
from transformers.modeling_utils import PreTrainedModel

class GigaAMCTC(nn.Module):
    """
    GigaAM-CTC model: a Conformer encoder followed by a CTC head.
    """

    def __init__(self, config_encoder, config_head):
        super().__init__()
        self.encoder = ConformerEncoder(**config_encoder)
        self.head = CTCHead(**config_head)

    def forward(self, input_features: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
        encoded, encoded_lengths = self.encoder(input_features, input_lengths)
        logits = self.head(encoded)
        return logits, encoded_lengths


class CTCHead(nn.Module):
    """
    CTC Head module for Connectionist Temporal Classification.
    """

    def __init__(self, feat_in: int, num_classes: int):
        super().__init__()
        self.decoder_layers = nn.Sequential(
            nn.Conv1d(feat_in, num_classes, kernel_size=1)
        )

    def forward(self, encoder_output: Tensor) -> Tensor:
        # encoder_output: B x C x T -> logits: B x num_classes x T
        return self.decoder_layers(encoder_output)

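# Hedged sketch (not part of the uploaded file): shape flow through CTCHead.
# The encoder width (768) and vocabulary size (34) are illustrative assumptions,
# not values taken from a real GigaAM config.
def _example_ctc_head_shapes() -> None:
    head = CTCHead(feat_in=768, num_classes=34)
    encoded = torch.randn(2, 768, 120)  # B x C x T, as produced by the encoder
    logits = head(encoded)              # B x num_classes x T
    assert logits.shape == (2, 34, 120)
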
class GigaAMFeatureExtractor(SequenceFeatureExtractor):
    """
    Feature extractor for GigaAM: log-Mel spectrograms computed with torchaudio.
    """

    model_input_names = ["input_features"]

    def __init__(
        self,
        feature_size=64,
        sampling_rate=16000,
        padding_value=0.0,
        chunk_length=30.0,
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            chunk_length=chunk_length,
            **kwargs,
        )
        self.hop_length = sampling_rate // 100
        # Cast to int so padding to `max_length` receives an integer sample count.
        self.n_samples = int(chunk_length * sampling_rate)
        self.featurizer = torchaudio.transforms.MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=sampling_rate // 40,
            win_length=sampling_rate // 40,
            hop_length=self.hop_length,
            n_mels=feature_size,
        )

    def to_dict(self) -> Dict[str, Union[str, int, Dict]]:
        dictionary = super().to_dict()

        if "featurizer" in dictionary:
            del dictionary["featurizer"]
        dictionary["hop_length"] = self.hop_length
        dictionary["n_samples"] = self.n_samples
        return dictionary

    def out_len(self, input_lengths: Tensor) -> Tensor:
        """
        Calculates the output length after the feature extraction process.
        """
        return input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        sampling_rate: Optional[int] = None,
        padding: str = "max_length",
        **kwargs,
    ):
        is_batched_numpy = (
            isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        )
        if is_batched_numpy and len(raw_speech.shape) > 2:
            raise ValueError(
                f"Only mono-channel audio is supported for input to {self}"
            )
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple))
            and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            raw_speech = [
                np.asarray([speech], dtype=np.float32).T for speech in raw_speech
            ]
        elif not is_batched and not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(
            np.float64
        ):
            raw_speech = raw_speech.astype(np.float32)

        # always return batch
        if not is_batched:
            raw_speech = [np.asarray([raw_speech]).T]

        input_lengths = torch.tensor([len(speech) for speech in raw_speech])

        batched_speech = BatchFeature({"input_features": raw_speech})

        padded_inputs = self.pad(
            batched_speech,
            padding=padding,
            max_length=self.n_samples,
            truncation=False,
            return_tensors="pt",
        )

        # B x T x 1 -> B x 1 x T -> log-Mel features: B x n_mels x frames
        input_features = padded_inputs["input_features"].transpose(1, 2)
        input_features = self.featurizer(input_features).squeeze(1)
        input_features = torch.log(input_features.clamp_(1e-9, 1e9))
        input_lengths = self.out_len(input_lengths)

        return BatchFeature(
            {"input_features": input_features, "input_lengths": input_lengths},
            tensor_type="pt",
        )

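# Hedged sketch (not part of the uploaded file): running GigaAMFeatureExtractor
# on one second of synthetic 16 kHz mono audio. With the default settings the
# output is a log-Mel spectrogram padded to the 30-second chunk length.
def _example_feature_extraction() -> None:
    extractor = GigaAMFeatureExtractor(feature_size=64, sampling_rate=16000)
    waveform = np.zeros(16000, dtype=np.float32)  # 1 s of silence, mono
    batch = extractor(waveform, sampling_rate=16000)
    # input_features: B x n_mels x frames; input_lengths: valid frames per item
    print(batch["input_features"].shape, batch["input_lengths"])
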
class GigaAMCTCTokenizer(Wav2Vec2CTCTokenizer):
    """
    Char tokenizer for the GigaAM-CTC model.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="[BLANK]",
        pad_token="[BLANK]",
        bos_token=None,
        eos_token=None,
        word_delimiter_token=" ",
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            word_delimiter_token=word_delimiter_token,
            **kwargs,
        )

class GigaAMProcessor(Wav2Vec2Processor):
    feature_extractor_class = "GigaAMFeatureExtractor"
    tokenizer_class = "GigaAMCTCTokenizer"

    def __init__(self, feature_extractor, tokenizer):
        # The parent __init__ is deliberately skipped; attributes are set directly:
        # super().__init__(feature_extractor, tokenizer)
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.current_processor = self.feature_extractor
        self._in_target_context_manager = False

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        feature_extractor = GigaAMFeatureExtractor.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        tokenizer = GigaAMCTCTokenizer.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )

        return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)

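# Hedged sketch (not part of the uploaded file): assembling a processor from
# local artifacts. "vocab.json" is a placeholder; a character vocabulary that
# includes the [BLANK] token is assumed.
def _example_build_processor() -> GigaAMProcessor:
    tokenizer = GigaAMCTCTokenizer("vocab.json")
    feature_extractor = GigaAMFeatureExtractor()
    return GigaAMProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
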
class GigaAMConfig(PretrainedConfig):
    """
    Configuration for GigaAM models; all fields are passed through **kwargs.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

class GigaAMCTCHF(PreTrainedModel):
    """
    GigaAM-CTC model for transformers
    """

    config_class = GigaAMConfig
    base_model_prefix = "gigaamctc"
    main_input_name = "input_features"

    def __init__(self, config: GigaAMConfig):
        super().__init__(config)
        self.model = GigaAMCTC(config.encoder, config.head)

    def forward(self, input_features, input_lengths, labels=None, **kwargs):
        # input_features: B x C x T
        logits, encoded_lengths = self.model(input_features, input_lengths)
        # B x C x T -> B x T x C -> T x B x C (ctc_loss expects time-major log-probs)
        log_probs = torch.log_softmax(
            logits.transpose(1, 2), dim=-1, dtype=torch.float32
        ).transpose(0, 1)

        loss = None
        if labels is not None:
            # Positions with label < 0 are padding and are excluded from the loss.
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            loss = nn.functional.ctc_loss(
                log_probs,
                flattened_targets,
                encoded_lengths,
                target_lengths,
                blank=self.config.blank_id,
                zero_infinity=True,
            )

        return CausalLMOutput(loss=loss, logits=logits.transpose(1, 2))
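

# Hedged end-to-end sketch (not part of the uploaded file): greedy CTC
# transcription with the classes above. The repository id and audio file are
# placeholders; a checkpoint exported with this config layout and a 16 kHz
# mono recording are assumed.
if __name__ == "__main__":
    repo_id = "your-namespace/gigaam-ctc-hf"  # hypothetical repository id
    processor = GigaAMProcessor.from_pretrained(repo_id)
    model = GigaAMCTCHF.from_pretrained(repo_id)
    model.eval()

    waveform, sample_rate = torchaudio.load("sample.wav")  # hypothetical file
    inputs = processor.feature_extractor(
        waveform.squeeze(0).numpy(), sampling_rate=sample_rate
    )

    with torch.no_grad():
        outputs = model(**inputs)

    # outputs.logits: B x T x num_classes -> greedy ids -> CTC-collapsed text
    predicted_ids = outputs.logits.argmax(dim=-1)
    print(processor.batch_decode(predicted_ids))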