|
"""Hugging Face `transformers` integration of the GigaAM-CTC model (Conformer encoder + CTC head)."""

from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchaudio
from torch import Tensor
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor
from transformers.configuration_utils import PretrainedConfig
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import CausalLMOutput
from transformers.modeling_utils import PreTrainedModel

from .encoder import ConformerEncoder
|
|
class GigaAMCTC(nn.Module):
    """
    GigaAM-CTC model: a Conformer encoder followed by a CTC head.
    """

    def __init__(self, config_encoder, config_head):
        super().__init__()
        self.encoder = ConformerEncoder(**config_encoder)
        self.head = CTCHead(**config_head)

    def forward(self, input_features: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
        # Encode log-mel features, then project every frame onto the vocabulary.
        encoded, encoded_lengths = self.encoder(input_features, input_lengths)
        logits = self.head(encoded)
        return logits, encoded_lengths
|
|
class CTCHead(nn.Module):
    """
    CTC head for Connectionist Temporal Classification: a pointwise (kernel size 1)
    convolution that maps encoder features to per-frame class logits.
    """

    def __init__(self, feat_in: int, num_classes: int):
        super().__init__()
        self.decoder_layers = nn.Sequential(
            nn.Conv1d(feat_in, num_classes, kernel_size=1)
        )

    def forward(self, encoder_output: Tensor) -> Tensor:
        # (batch, feat_in, time) -> (batch, num_classes, time)
        return self.decoder_layers(encoder_output)
|
|
class GigaAMFeatureExtractor(SequenceFeatureExtractor):
    """
    Feature extractor for GigaAM: pads each utterance to `chunk_length` seconds and
    converts the raw waveform into log-mel spectrogram features.
    """

    model_input_names = ["input_features"]

    def __init__(
        self,
        feature_size=64,
        sampling_rate=16000,
        padding_value=0.0,
        chunk_length=30.0,
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            chunk_length=chunk_length,
            **kwargs,
        )
        # 10 ms hop and 25 ms analysis window (160 and 400 samples at 16 kHz).
        self.hop_length = sampling_rate // 100
        self.n_samples = int(chunk_length * sampling_rate)
        self.featurizer = torchaudio.transforms.MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=sampling_rate // 40,
            win_length=sampling_rate // 40,
            hop_length=self.hop_length,
            n_mels=feature_size,
        )
|
    def to_dict(self) -> Dict[str, Union[str, int, Dict]]:
        dictionary = super().to_dict()

        # Drop the torchaudio MelSpectrogram module (not JSON-serializable) and
        # record the derived scalar parameters explicitly.
        if "featurizer" in dictionary:
            del dictionary["featurizer"]
        dictionary["hop_length"] = self.hop_length
        dictionary["n_samples"] = self.n_samples
        return dictionary
|
    def out_len(self, input_lengths: Tensor) -> Tensor:
        """
        Calculates the number of spectrogram frames produced for each input length
        in samples: floor(length / hop_length) + 1. For example, 16000 samples with
        hop_length=160 yield 101 frames.
        """
        return input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()
|
    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        sampling_rate: Optional[int] = None,
        padding: str = "max_length",
        **kwargs,
    ) -> BatchFeature:
        is_batched_numpy = (
            isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        )
        if is_batched_numpy and len(raw_speech.shape) > 2:
            raise ValueError(
                f"Only mono-channel audio is supported for input to {self}"
            )
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple))
            and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
        )

        # Normalize every accepted input format to a list of float32 arrays
        # of shape (num_samples, 1).
        if is_batched:
            raw_speech = [
                np.asarray([speech], dtype=np.float32).T for speech in raw_speech
            ]
        elif not is_batched and not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(
            np.float64
        ):
            raw_speech = raw_speech.astype(np.float32)

        if not is_batched:
            raw_speech = [np.asarray([raw_speech]).T]

        input_lengths = torch.tensor([len(speech) for speech in raw_speech])

        batched_speech = BatchFeature({"input_features": raw_speech})

        padded_inputs = self.pad(
            batched_speech,
            padding=padding,
            max_length=self.n_samples,
            truncation=False,
            return_tensors="pt",
        )

        # (batch, num_samples, 1) -> (batch, 1, num_samples) -> log-mel (batch, n_mels, num_frames)
        input_features = padded_inputs["input_features"].transpose(1, 2)
        input_features = self.featurizer(input_features).squeeze(1)
        input_features = torch.log(input_features.clamp_(1e-9, 1e9))
        input_lengths = self.out_len(input_lengths)

        return BatchFeature(
            {"input_features": input_features, "input_lengths": input_lengths},
            tensor_type="pt",
        )
|
|
class GigaAMCTCTokenizer(Wav2Vec2CTCTokenizer):
    """
    Character-level tokenizer for the GigaAM-CTC model.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="[BLANK]",
        pad_token="[BLANK]",
        bos_token=None,
        eos_token=None,
        word_delimiter_token=" ",
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            word_delimiter_token=word_delimiter_token,
            **kwargs,
        )
|
|
class GigaAMProcessor(Wav2Vec2Processor):
    feature_extractor_class = "GigaAMFeatureExtractor"
    tokenizer_class = "GigaAMCTCTokenizer"

    def __init__(self, feature_extractor, tokenizer):
        # Assign the components directly instead of calling super().__init__(),
        # which expects feature extractor/tokenizer classes importable from `transformers`.
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.current_processor = self.feature_extractor
        self._in_target_context_manager = False

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        feature_extractor = GigaAMFeatureExtractor.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        tokenizer = GigaAMCTCTokenizer.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )

        return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
|
class GigaAMConfig(PretrainedConfig):
    """
    Configuration for GigaAM-CTC. The model below expects `encoder` (ConformerEncoder
    kwargs), `head` (CTCHead kwargs) and `blank_id` entries to be present.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
|
|
class GigaAMCTCHF(PreTrainedModel):
    """
    GigaAM-CTC model wrapped for the `transformers` API.
    """

    config_class = GigaAMConfig
    base_model_prefix = "gigaamctc"
    main_input_name = "input_features"

    def __init__(self, config: GigaAMConfig):
        super().__init__(config)
        self.model = GigaAMCTC(config.encoder, config.head)

    def forward(
        self,
        input_features: Tensor,
        input_lengths: Tensor,
        labels: Optional[Tensor] = None,
        **kwargs,
    ) -> CausalLMOutput:
        # logits: (batch, num_classes, time); encoded_lengths: valid frames per sample.
        logits, encoded_lengths = self.model(input_features, input_lengths)

        # ctc_loss expects log-probabilities of shape (time, batch, num_classes).
        log_probs = torch.log_softmax(
            logits.transpose(1, 2), dim=-1, dtype=torch.float32
        ).transpose(0, 1)

        loss = None
        if labels is not None:
            # Padded label positions are negative (e.g. -100) and are masked out.
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            loss = nn.functional.ctc_loss(
                log_probs,
                flattened_targets,
                encoded_lengths,
                target_lengths,
                blank=self.config.blank_id,
                zero_infinity=True,
            )

        return CausalLMOutput(loss=loss, logits=logits.transpose(1, 2))
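

# A minimal end-to-end usage sketch of the classes above, with greedy CTC decoding.
# The checkpoint path below is a placeholder, not a real repository id: any directory
# or Hub repo that stores a GigaAMConfig, the feature extractor/tokenizer files, and
# weights in this format should work.
if __name__ == "__main__":
    checkpoint = "path/to/gigaam-ctc-checkpoint"  # placeholder

    processor = GigaAMProcessor.from_pretrained(checkpoint)
    model = GigaAMCTCHF.from_pretrained(checkpoint)
    model.eval()

    # One second of silence as dummy 16 kHz mono audio.
    audio = np.zeros(16000, dtype=np.float32)
    inputs = processor(audio, sampling_rate=16000)

    with torch.no_grad():
        outputs = model(**inputs)

    # Greedy decoding: argmax per frame; the CTC tokenizer collapses repeats and blanks.
    predicted_ids = outputs.logits.argmax(dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    print(transcription)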
|
|