# coding=utf-8 # Copyright 2025 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Processor class for Speech Granite. """ from collections.abc import Sequence from typing import List, Union import numpy as np import torch from transformers.feature_extraction_utils import BatchFeature from transformers.processing_utils import ProcessorMixin from transformers.tokenization_utils import PreTokenizedInput, TextInput from transformers.utils import logging logger = logging.get_logger(__name__) # 🚨🚨🚨 HACK 🚨🚨🚨 # This is needed to avoid custom registration issues for now, # since we have a custom subclass for the feature extractor as well. import math from typing import List, Optional from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin from transformers.utils import is_torch_available, is_torchaudio_available, logging if is_torch_available(): import torch if is_torchaudio_available(): import torchaudio class GraniteSpeechFeatureExtractor(FeatureExtractionMixin): model_input_names = ["input_features"] def __init__( self, sampling_rate=16000, n_fft=512, win_length=400, hop_length=160, n_mels=80, projector_window_size=15, projector_downsample_rate=5, **kwargs, ): super().__init__(**kwargs) self.melspec_kwargs = { "sample_rate": sampling_rate, "n_fft": n_fft, "win_length": win_length, "hop_length": hop_length, "n_mels": n_mels, } # HACK - for now, lazily initialize the mel spectrogram transform; # the feature extractor mixin explodes otherwise because # it tries to log the feature extractor, and the melspectrogram # transform isn't json serializable... self.melspec = None self.projector_window_size = projector_window_size self.projector_downsample_rate = projector_downsample_rate def _ensure_melspec_transform_is_initialized(self): if self.melspec is None: self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs) def __call__( self, x: torch.Tensor, device: Optional[str] = "cpu", ) -> BatchFeature: # TODO there is probably a better way to do both of these things... self._ensure_melspec_transform_is_initialized() if device is not None: melspec = self.melspec.to(device) x = x.to(device) else: melspec = self.melspec B, _ = x.shape with torch.no_grad(): mel = melspec(x.float()) logmel = mel.transpose(-1, -2).clip_(min=1e-10).log10_() mx = logmel.amax(dim=(-2, -1), keepdim=True) logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1) if logmel.shape[1] % 2 == 1: logmel = logmel[:, :-1] # remove last frame if odd x = logmel.reshape(B, -1, 2 * logmel.shape[-1]) # stacking and skipping by 2 if x.device != "cpu": return x.detach().cpu() return x def _get_num_audio_features(self, audio_lengths: List[int]) -> List[int]: """ Gets the (variable length) variable length number of features (i.e., projector output) for the sequences being considered. """ hop_length = self.melspec_kwargs["hop_length"] effective_window_size = self.projector_window_size // self.projector_downsample_rate projector_lengths = [] for raw_length in audio_lengths: # mel sequence length computation mel_length = raw_length // hop_length + 1 # encoder frame takes two mel features encoder_length = mel_length // 2 nblocks = math.ceil(encoder_length / self.projector_window_size) # projector output length projector_length = nblocks * effective_window_size projector_lengths.append(projector_length) return projector_lengths import transformers transformers.GraniteSpeechFeatureExtractor = GraniteSpeechFeatureExtractor # The above code is the only change in the modeling code from the following # commit on Alex's fork: 397e03a4d76c5f3d8a651e47ade9f27c635e1617 class GraniteSpeechProcessor(ProcessorMixin): attributes = ["feature_extractor", "tokenizer"] valid_kwargs = ["audio_token"] feature_extractor_class = "GraniteSpeechFeatureExtractor" tokenizer_class = "AutoTokenizer" def __init__( self, feature_extractor, tokenizer, audio_token="<|audio|>", ): self.audio_token = tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token super().__init__(feature_extractor, tokenizer) def __call__( self, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], audios: Union[torch.Tensor, List[torch.Tensor]] = None, device: str = "cpu", **kwargs, ) -> BatchFeature: speech_inputs = {} text_inputs = {} text = self._get_validated_text(text) expected_num_audios = sum(t.count(self.audio_token) for t in text) if audios is not None: audios, audio_lengths = self._get_validated_audios(audios) if any(text.count(self.audio_token) != 1 for text in text): raise ValueError("Only one audio sample is currently supported per input") if len(audio_lengths) != expected_num_audios: raise ValueError("Text/Audio mismatch. The number of audios and audio tokens do not match") # Calculate Mel features & the number of placeholders we will need speech_inputs["input_features"] = self.feature_extractor( audios, device=device, ) num_audio_features = self.feature_extractor._get_num_audio_features(audio_lengths) speech_inputs["input_features_mask"] = torch.arange(max(num_audio_features)).view(1, -1) <= torch.tensor( num_audio_features ).view(-1, 1) # duplicate the audio placeholders to match the feature dims text = self._expand_audio_placeholders(text, num_audio_features) else: assert expected_num_audios == 0, "No audio is provided, expecting no audio tokens" text_inputs = self.tokenizer(text, padding=True, **kwargs) return BatchFeature(data={**text_inputs, **speech_inputs}) def _expand_audio_placeholders(self, text: list[str], num_audio_features: List[int]): """ Expands audio placeholders in the formatted text to match the number of features of the corresponding embeddings; we can use the resulting text to conveniently mask the audio features into the text embeddings. """ prompt_strings = [] num_replaced = 0 for sample in text: while self.audio_token in sample: sample = sample.replace( self.audio_token, "" * num_audio_features[num_replaced], 1, ) num_replaced += 1 prompt_strings.append(sample) prompt_strings = [sample.replace("", self.audio_token) for sample in prompt_strings] return prompt_strings ##### Validation def _get_validated_text(self, text: Union[str, list]) -> List[str]: if isinstance(text, str): return [text] elif isinstance(text, list) and isinstance(text[0], str): return text raise TypeError("Invalid text provided! Text should be a string or list of strings.") def _get_validated_audios(self, audios): # Coerce to PyTorch tensors if we have numpy arrays, since # currently we have a dependency on torch/torchaudio anyway if isinstance(audios, np.ndarray): audios = torch.from_numpy(audios) elif isinstance(audios, Sequence) and isinstance(audios[0], np.ndarray): audios = [torch.from_numpy(arr) for arr in audios] if isinstance(audios, torch.Tensor): if audios.ndim == 1: audios = audios.unsqueeze(0) if not torch.is_floating_point(audios): raise ValueError("Invalid audio provided. Audio should be a floating point between 0 and 1") if audios.shape[0] > 1: logger.warning("Audio samples are already collated; assuming they all have the same length") lengths = [audios.shape[-1]] * audios.shape[0] return audios, lengths elif isinstance(audios, Sequence) and isinstance(audios[0], torch.Tensor): if not torch.is_floating_point(audios[0]): raise ValueError("Invalid audio provided. Audio should be a floating point between 0 and 1") lengths = [audio.shape[-1] for audio in audios] padding = [max(lengths) - length for length in lengths] # ensure all audios have a batch dimension: audios = [audio.view(1, -1) for audio in audios] padded = [torch.nn.functional.pad(audio, (0, pad)) for audio, pad in zip(audios, padding)] audios = torch.cat(padded, dim=0) return audios, lengths raise TypeError("Invalid audio provided. Audio should be a one or more torch tensors or numpy arrays") __all__ = ["GraniteSpeechProcessor"]