"""Processor class for Speech Granite."""
|
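# This module bundles two pieces: a torchaudio-based mel-spectrogram feature extractor
# (GraniteSpeechFeatureExtractor) and the multimodal processor (GraniteSpeechProcessor)
# that expands audio placeholder tokens in the text to match the extracted features.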
|
|
import math
from collections.abc import Sequence
from typing import List, Optional, Union

import numpy as np
import torch

from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils import PreTokenizedInput, TextInput
from transformers.utils import is_torchaudio_available, logging

if is_torchaudio_available():
    import torchaudio


logger = logging.get_logger(__name__)
|
|
|
|
|
class GraniteSpeechFeatureExtractor(FeatureExtractionMixin): |
|
model_input_names = ["input_features"] |
|
|
|
def __init__( |
|
self, |
|
sampling_rate=16000, |
|
n_fft=512, |
|
win_length=400, |
|
hop_length=160, |
|
n_mels=80, |
|
projector_window_size=15, |
|
projector_downsample_rate=5, |
|
**kwargs, |
|
): |
|
super().__init__(**kwargs) |
|
self.melspec_kwargs = { |
|
"sample_rate": sampling_rate, |
|
"n_fft": n_fft, |
|
"win_length": win_length, |
|
"hop_length": hop_length, |
|
"n_mels": n_mels, |
|
} |
|
|
        # The mel-spectrogram transform is created lazily (see
        # `_ensure_melspec_transform_is_initialized`), so torchaudio is only needed
        # once features are actually computed.
        self.melspec = None
|
self.projector_window_size = projector_window_size |
|
self.projector_downsample_rate = projector_downsample_rate |
|
|
|
def _ensure_melspec_transform_is_initialized(self): |
|
if self.melspec is None: |
|
self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs) |
|
|
|
def __call__( |
|
self, |
|
x: torch.Tensor, |
|
device: Optional[str] = "cpu", |
|
    ) -> torch.Tensor:
|
|
|
self._ensure_melspec_transform_is_initialized() |
|
if device is not None: |
|
melspec = self.melspec.to(device) |
|
x = x.to(device) |
|
else: |
|
melspec = self.melspec |
|
|
|
        B, _ = x.shape
        with torch.no_grad():
            mel = melspec(x.float())
            # Log-compress, clamp the dynamic range to 8 (log10 units) below the
            # per-example maximum, then rescale.
            logmel = mel.transpose(-1, -2).clip_(min=1e-10).log10_()
            mx = logmel.amax(dim=(-2, -1), keepdim=True)
            logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1)
            # Stack every two consecutive frames (dropping a trailing odd frame),
            # halving the time dimension and doubling the feature dimension.
            if logmel.shape[1] % 2 == 1:
                logmel = logmel[:, :-1]
            x = logmel.reshape(B, -1, 2 * logmel.shape[-1])
|
|
|
        # The mel transform may have run on an accelerator; always hand features back on CPU.
        if x.device.type != "cpu":
            return x.detach().cpu()
        return x
|
|
|
    def _get_num_audio_features(self, audio_lengths: List[int]) -> List[int]:
        """
        Gets the (variable) number of features (i.e., the projector output length)
        for each of the sequences being considered.
        """
|
hop_length = self.melspec_kwargs["hop_length"] |
|
effective_window_size = self.projector_window_size // self.projector_downsample_rate |
|
|
|
projector_lengths = [] |
|
        for raw_length in audio_lengths:
            # A centered STFT yields one frame per hop, plus one.
            mel_length = raw_length // hop_length + 1
            # The extractor stacks pairs of mel frames (see `__call__`), halving the
            # sequence length before it reaches the encoder/projector.
            encoder_length = mel_length // 2
            nblocks = math.ceil(encoder_length / self.projector_window_size)
            # Each block of `projector_window_size` frames is downsampled to
            # `effective_window_size` projector features.
            projector_length = nblocks * effective_window_size
            projector_lengths.append(projector_length)
|
|
|
return projector_lengths |
|
|
|
|
|
# Register the feature extractor on the top-level `transformers` namespace so that
# `ProcessorMixin` can resolve the `feature_extractor_class` string used below.
import transformers

transformers.GraniteSpeechFeatureExtractor = GraniteSpeechFeatureExtractor
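
# A minimal, illustrative sanity check (not part of the public API; the helper name is
# an invention of this sketch). It walks through the expected shapes and projector
# lengths for 2-second clips at 16 kHz with the default hyperparameters, assuming
# torchaudio is installed.
def _example_granite_speech_features():
    feature_extractor = GraniteSpeechFeatureExtractor()
    waveforms = torch.rand(2, 32000) * 2 - 1  # two fake 2-second clips in [-1, 1]

    # 32000 samples -> 201 mel frames -> 200 after trimming the odd frame
    # -> 100 stacked frames of 2 * 80 = 160 features.
    features = feature_extractor(waveforms)
    assert features.shape == (2, 100, 160)

    # 201 mel frames -> 100 encoder frames -> ceil(100 / 15) = 7 blocks
    # -> 7 * (15 // 5) = 21 projector features per clip.
    assert feature_extractor._get_num_audio_features([32000, 32000]) == [21, 21]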
|
|
|
|
|
|
|
class GraniteSpeechProcessor(ProcessorMixin): |
|
attributes = ["feature_extractor", "tokenizer"] |
|
valid_kwargs = ["audio_token"] |
|
|
|
feature_extractor_class = "GraniteSpeechFeatureExtractor" |
|
tokenizer_class = "AutoTokenizer" |
|
|
|
def __init__( |
|
self, |
|
feature_extractor, |
|
tokenizer, |
|
audio_token="<|audio|>", |
|
): |
|
self.audio_token = tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token |
|
super().__init__(feature_extractor, tokenizer) |
|
|
|
def __call__( |
|
self, |
|
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], |
|
        audios: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
|
device: str = "cpu", |
|
**kwargs, |
|
) -> BatchFeature: |
|
speech_inputs = {} |
|
text_inputs = {} |
|
|
|
text = self._get_validated_text(text) |
|
expected_num_audios = sum(t.count(self.audio_token) for t in text) |
|
|
|
if audios is not None: |
|
audios, audio_lengths = self._get_validated_audios(audios) |
|
            if any(sample.count(self.audio_token) != 1 for sample in text):
                raise ValueError("Only one audio sample is currently supported per input")
            if len(audio_lengths) != expected_num_audios:
                raise ValueError("Text/audio mismatch: the number of audios and audio tokens do not match")
|
|
|
|
|
speech_inputs["input_features"] = self.feature_extractor( |
|
audios, |
|
device=device, |
|
) |
|
            num_audio_features = self.feature_extractor._get_num_audio_features(audio_lengths)
            # Boolean mask marking which projector positions hold real (non-padding) audio features.
            speech_inputs["input_features_mask"] = torch.arange(max(num_audio_features)).view(1, -1) < torch.tensor(
                num_audio_features
            ).view(-1, 1)
|
|
|
|
|
text = self._expand_audio_placeholders(text, num_audio_features) |
|
        else:
            if expected_num_audios != 0:
                raise ValueError("No audio was provided, but the text contains audio placeholder tokens")
|
|
|
text_inputs = self.tokenizer(text, padding=True, **kwargs) |
|
return BatchFeature(data={**text_inputs, **speech_inputs}) |
|
|
|
    def _expand_audio_placeholders(self, text: List[str], num_audio_features: List[int]):
|
""" |
|
Expands audio placeholders in the formatted text to match the number of |
|
features of the corresponding embeddings; we can use the resulting text |
|
to conveniently mask the audio features into the text embeddings. |
|
""" |
|
prompt_strings = [] |
|
num_replaced = 0 |
|
        for sample in text:
            while self.audio_token in sample:
                # Expand via an intermediate placeholder so the freshly inserted tokens
                # are not matched again by this while loop.
                sample = sample.replace(
                    self.audio_token,
                    "<placeholder>" * num_audio_features[num_replaced],
                    1,
                )
                num_replaced += 1
            prompt_strings.append(sample)
|
|
|
prompt_strings = [sample.replace("<placeholder>", self.audio_token) for sample in prompt_strings] |
|
return prompt_strings |
|
|
|
|
|
def _get_validated_text(self, text: Union[str, list]) -> List[str]: |
|
if isinstance(text, str): |
|
return [text] |
|
elif isinstance(text, list) and isinstance(text[0], str): |
|
return text |
|
raise TypeError("Invalid text provided! Text should be a string or list of strings.") |
|
|
|
    def _get_validated_audios(self, audios):
        # Coerce numpy arrays (or lists of them) to torch tensors.
        if isinstance(audios, np.ndarray):
            audios = torch.from_numpy(audios)
        elif isinstance(audios, Sequence) and isinstance(audios[0], np.ndarray):
            audios = [torch.from_numpy(arr) for arr in audios]
|
|
|
if isinstance(audios, torch.Tensor): |
|
if audios.ndim == 1: |
|
audios = audios.unsqueeze(0) |
|
            if not torch.is_floating_point(audios):
                raise ValueError("Invalid audio provided. Audio should be a floating point tensor")
|
|
|
if audios.shape[0] > 1: |
|
logger.warning("Audio samples are already collated; assuming they all have the same length") |
|
lengths = [audios.shape[-1]] * audios.shape[0] |
|
return audios, lengths |
|
|
|
        elif isinstance(audios, Sequence) and isinstance(audios[0], torch.Tensor):
            if not torch.is_floating_point(audios[0]):
                raise ValueError("Invalid audio provided. Audio should be a floating point tensor")
            # Right-pad every clip to the longest length, then stack into a single batch.
            lengths = [audio.shape[-1] for audio in audios]
            padding = [max(lengths) - length for length in lengths]
            audios = [audio.view(1, -1) for audio in audios]
            padded = [torch.nn.functional.pad(audio, (0, pad)) for audio, pad in zip(audios, padding)]
            audios = torch.cat(padded, dim=0)
            return audios, lengths
|
|
|
        raise TypeError("Invalid audio provided. Audio should be one or more torch tensors or numpy arrays")
|
|
|
|
|
__all__ = ["GraniteSpeechFeatureExtractor", "GraniteSpeechProcessor"]
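
# Hedged usage sketch, only run when the module is executed directly: build the
# processor from the components defined above and prepare one batch with a single
# audio clip. The tokenizer checkpoint name below is a placeholder assumption, not
# something pinned down in this module.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-speech-3.3-8b")
    processor = GraniteSpeechProcessor(GraniteSpeechFeatureExtractor(), tokenizer)

    waveform = torch.rand(1, 32000) * 2 - 1  # fake 2-second clip at 16 kHz in [-1, 1]
    batch = processor(text="<|audio|>can you transcribe this?", audios=waveform)
    print({key: getattr(value, "shape", None) for key, value in batch.items()})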
|
|