gsaon committed
Commit faa27e8 · verified · 1 Parent(s): dc71e31

Upload processing_granite_speech.py

Files changed (1):
  1. processing_granite_speech.py +96 -1
processing_granite_speech.py CHANGED
@@ -33,8 +33,103 @@ logger = logging.get_logger(__name__)
 # 🚨🚨🚨 HACK 🚨🚨🚨
 # This is needed to avoid custom registration issues for now,
 # since we have a custom subclass for the feature extractor as well.
+import math
+from typing import List, Optional
+
+from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from transformers.utils import is_torch_available, is_torchaudio_available, logging
+
+if is_torch_available():
+    import torch
+
+if is_torchaudio_available():
+    import torchaudio
+
+
+class GraniteSpeechFeatureExtractor(FeatureExtractionMixin):
+    model_input_names = ["input_features"]
+
+    def __init__(
+        self,
+        sampling_rate=16000,
+        n_fft=512,
+        win_length=400,
+        hop_length=160,
+        n_mels=80,
+        projector_window_size=15,
+        projector_downsample_rate=5,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.melspec_kwargs = {
+            "sample_rate": sampling_rate,
+            "n_fft": n_fft,
+            "win_length": win_length,
+            "hop_length": hop_length,
+            "n_mels": n_mels,
+        }
+        # HACK - for now, lazily initialize the mel spectrogram transform;
+        # the feature extractor mixin explodes otherwise because
+        # it tries to log the feature extractor, and the melspectrogram
+        # transform isn't json serializable...
+        self.melspec = None
+        self.projector_window_size = projector_window_size
+        self.projector_downsample_rate = projector_downsample_rate
+
+    def _ensure_melspec_transform_is_initialized(self):
+        if self.melspec is None:
+            self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)
+
+    def __call__(
+        self,
+        x: torch.Tensor,
+        device: Optional[str] = "cpu",
+    ) -> BatchFeature:
+        # TODO there is probably a better way to do both of these things...
+        self._ensure_melspec_transform_is_initialized()
+        if device is not None:
+            melspec = self.melspec.to(device)
+            x = x.to(device)
+        else:
+            melspec = self.melspec
+
+        B, _ = x.shape
+        with torch.no_grad():
+            mel = melspec(x.float())
+            logmel = mel.transpose(-1, -2).clip_(min=1e-10).log10_()
+            mx = logmel.amax(dim=(-2, -1), keepdim=True)
+            logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1)
+            if logmel.shape[1] % 2 == 1:
+                logmel = logmel[:, :-1]  # remove last frame if odd
+            x = logmel.reshape(B, -1, 2 * logmel.shape[-1])  # stacking and skipping by 2
+
+        if x.device != "cpu":
+            return x.detach().cpu()
+        return x
+
+    def _get_num_audio_features(self, audio_lengths: List[int]) -> List[int]:
+        """
+        Gets the variable-length number of features
+        (i.e., projector output) for the sequences being considered.
+        """
+        hop_length = self.melspec_kwargs["hop_length"]
+        effective_window_size = self.projector_window_size // self.projector_downsample_rate
+
+        projector_lengths = []
+        for raw_length in audio_lengths:
+            # mel sequence length computation
+            mel_length = raw_length // hop_length + 1
+            # encoder frame takes two mel features
+            encoder_length = mel_length // 2
+            nblocks = math.ceil(encoder_length / self.projector_window_size)
+            # projector output length
+            projector_length = nblocks * effective_window_size
+            projector_lengths.append(projector_length)
+
+        return projector_lengths
+
+
 import transformers
-from .feature_extraction_granite_speech import GraniteSpeechFeatureExtractor
 transformers.GraniteSpeechFeatureExtractor = GraniteSpeechFeatureExtractor
 # The above code is the only change in the modeling code from the following
 # commit on Alex's fork: 397e03a4d76c5f3d8a651e47ade9f27c635e1617
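
The monkey-patch on the last lines exposes the custom subclass as a top-level attribute of the transformers namespace, so any code that resolves GraniteSpeechFeatureExtractor by attribute lookup on the transformers module finds it. A minimal sanity check, as a sketch (the module name processing_granite_speech assumes the uploaded file is importable from the working directory):

import transformers
import processing_granite_speech

# Importing the module above runs the monkey-patch, so the attribute now
# resolves to the custom subclass defined in this file.
assert transformers.GraniteSpeechFeatureExtractor is processing_granite_speech.GraniteSpeechFeatureExtractor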
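
For context, a minimal usage sketch of the extractor added in this diff, assuming torch and torchaudio are installed; the file name audio.wav is hypothetical and stands in for a 16 kHz mono recording:

import torchaudio
from processing_granite_speech import GraniteSpeechFeatureExtractor

feature_extractor = GraniteSpeechFeatureExtractor()

# torchaudio.load returns (channels, num_samples); a mono file yields a
# (1, num_samples) tensor, matching the (batch, num_samples) input that
# __call__ expects.
wav, sr = torchaudio.load("audio.wav")
assert sr == feature_extractor.melspec_kwargs["sample_rate"]

# Log-mel features with pairs of frames stacked: shape (1, num_frames // 2, 2 * n_mels).
features = feature_extractor(wav, device="cpu")

# Number of projector outputs this clip will produce downstream.
num_features = feature_extractor._get_num_audio_features([wav.shape[-1]])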
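
As a worked example of _get_num_audio_features with the defaults above (my arithmetic, not part of the commit): a 10-second clip at 16 kHz has raw_length = 160000, giving mel_length = 160000 // 160 + 1 = 1001 frames, encoder_length = 1001 // 2 = 500, nblocks = ceil(500 / 15) = 34, and effective_window_size = 15 // 5 = 3, so the projector yields 34 * 3 = 102 audio features.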