Upload processing_granite_speech.py
processing_granite_speech.py  +96 -1
@@ -33,8 +33,103 @@ logger = logging.get_logger(__name__)
 # 🚨🚨🚨 HACK 🚨🚨🚨
 # This is needed to avoid custom registration issues for now,
 # since we have a custom subclass for the feature extractor as well.
+import math
+from typing import List, Optional
+
+from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from transformers.utils import is_torch_available, is_torchaudio_available, logging
+
+if is_torch_available():
+    import torch
+
+if is_torchaudio_available():
+    import torchaudio
+
+
+class GraniteSpeechFeatureExtractor(FeatureExtractionMixin):
+    model_input_names = ["input_features"]
+
+    def __init__(
+        self,
+        sampling_rate=16000,
+        n_fft=512,
+        win_length=400,
+        hop_length=160,
+        n_mels=80,
+        projector_window_size=15,
+        projector_downsample_rate=5,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.melspec_kwargs = {
+            "sample_rate": sampling_rate,
+            "n_fft": n_fft,
+            "win_length": win_length,
+            "hop_length": hop_length,
+            "n_mels": n_mels,
+        }
+        # HACK - for now, lazily initialize the mel spectrogram transform;
+        # the feature extractor mixin explodes otherwise because
+        # it tries to log the feature extractor, and the melspectrogram
+        # transform isn't json serializable...
+        self.melspec = None
+        self.projector_window_size = projector_window_size
+        self.projector_downsample_rate = projector_downsample_rate
+
+    def _ensure_melspec_transform_is_initialized(self):
+        if self.melspec is None:
+            self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)
+
+    def __call__(
+        self,
+        x: torch.Tensor,
+        device: Optional[str] = "cpu",
+    ) -> BatchFeature:
+        # TODO there is probably a better way to do both of these things...
+        self._ensure_melspec_transform_is_initialized()
+        if device is not None:
+            melspec = self.melspec.to(device)
+            x = x.to(device)
+        else:
+            melspec = self.melspec
+
+        B, _ = x.shape
+        with torch.no_grad():
+            mel = melspec(x.float())
+            logmel = mel.transpose(-1, -2).clip_(min=1e-10).log10_()
+            mx = logmel.amax(dim=(-2, -1), keepdim=True)
+            logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1)
+            if logmel.shape[1] % 2 == 1:
+                logmel = logmel[:, :-1]  # remove last frame if odd
+            x = logmel.reshape(B, -1, 2 * logmel.shape[-1])  # stacking and skipping by 2
+
+        if x.device != "cpu":
+            return x.detach().cpu()
+        return x
+
+    def _get_num_audio_features(self, audio_lengths: List[int]) -> List[int]:
+        """
+        Gets the (variable length) number of features
+        (i.e., projector output) for the sequences being considered.
+        """
+        hop_length = self.melspec_kwargs["hop_length"]
+        effective_window_size = self.projector_window_size // self.projector_downsample_rate
+
+        projector_lengths = []
+        for raw_length in audio_lengths:
+            # mel sequence length computation
+            mel_length = raw_length // hop_length + 1
+            # encoder frame takes two mel features
+            encoder_length = mel_length // 2
+            nblocks = math.ceil(encoder_length / self.projector_window_size)
+            # projector output length
+            projector_length = nblocks * effective_window_size
+            projector_lengths.append(projector_length)
+
+        return projector_lengths
+
+
 import transformers
-from .feature_extraction_granite_speech import GraniteSpeechFeatureExtractor
 transformers.GraniteSpeechFeatureExtractor = GraniteSpeechFeatureExtractor
 # The above code is the only change in the modeling code from the following
 # commit on Alex's fork: 397e03a4d76c5f3d8a651e47ade9f27c635e1617
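A minimal usage sketch (illustrative, not part of the commit) to sanity-check the shapes and the length bookkeeping. It assumes torch and torchaudio are installed and that the class above is importable; the import path here is hypothetical, since the file ships as remote code on the Hub:

import torch

# Hypothetical import path; in practice the class is also exposed via the
# module-level assignment onto `transformers` at the end of the diff.
from processing_granite_speech import GraniteSpeechFeatureExtractor

extractor = GraniteSpeechFeatureExtractor()  # defaults: 16 kHz, hop_length=160, n_mels=80

# Two seconds of dummy 16 kHz audio, batch of one.
wav = torch.randn(1, 32000)
feats = extractor(wav)

# 32000 // 160 + 1 = 201 mel frames; the odd trailing frame is dropped and
# adjacent frames are stacked in pairs, giving (1, 100, 2 * 80) = (1, 100, 160).
print(feats.shape)

# Projector length bookkeeping: 201 // 2 = 100 encoder frames,
# ceil(100 / 15) = 7 blocks, 15 // 5 = 3 features per block -> [21].
print(extractor._get_num_audio_features([32000]))

Per the HACK comments in the diff, the assignment onto `transformers` stands in for proper custom-class registration while the feature extractor subclass lives alongside the processor; the sketch above sidesteps that by importing the class directly.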
|