from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchaudio
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor
from transformers.configuration_utils import PretrainedConfig
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import CausalLMOutput, Seq2SeqLMOutput
from transformers.modeling_utils import PreTrainedModel

from .encoder import ConformerEncoder


class GigaAMCTC(nn.Module):
    """
    GigaAM-CTC model: Conformer encoder followed by a CTC head.
    """

    def __init__(self, config_encoder, config_head):
        super().__init__()
        self.encoder = ConformerEncoder(**config_encoder)
        self.head = CTCHead(**config_head)

    def forward(self, input_features: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
        encoded, encoded_lengths = self.encoder(input_features, input_lengths)
        logits = self.head(encoded)
        return logits, encoded_lengths


class GigaAMRNNT(nn.Module):
    """
    GigaAM-RNNT model: Conformer encoder with an RNN-T prediction and joint network.
    """

    def __init__(self, config_encoder, config_head):
        super().__init__()
        self.encoder = ConformerEncoder(**config_encoder)
        self.head = RNNTHead(**config_head)

    def forward(
        self,
        input_features: Tensor,
        input_lengths: Tensor,
        targets: Tensor,
        target_lengths: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        encoded, encoded_lengths = self.encoder(input_features, input_lengths)

        # Run the prediction network over the targets with a zero-initialized LSTM state.
        decoder_out, _ = self.head.decoder.predict(targets, state=None)

        # Encoder output is (batch, features, time); the joint network expects (batch, time, features).
        joint = self.head.joint.joint(
            encoder_out=encoded.transpose(1, 2), decoder_out=decoder_out
        )
        return joint, encoded_lengths


class CTCHead(nn.Module):
    """
    CTC Head module for Connectionist Temporal Classification.
    """

    def __init__(self, feat_in: int, num_classes: int):
        super().__init__()
        self.decoder_layers = nn.Sequential(
            nn.Conv1d(feat_in, num_classes, kernel_size=1)
        )

    def forward(self, encoder_output: Tensor) -> Tensor:
        return self.decoder_layers(encoder_output)
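

# A minimal illustrative sketch (not part of the original module): the CTC head
# consumes encoder features laid out as (batch, feat_in, time) and returns
# per-frame logits of shape (batch, num_classes, time). All sizes below are
# made up.
def _example_ctc_head_shapes() -> torch.Size:
    head = CTCHead(feat_in=256, num_classes=34)
    feats = torch.randn(2, 256, 50)  # (batch, feat_in, time)
    logits = head(feats)             # -> (2, 34, 50)
    return logits.shape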


class RNNTJoint(nn.Module):
    """
    RNN-Transducer Joint Network Module.

    This module combines the outputs of the encoder and the prediction network using
    a linear transformation followed by ReLU activation and another linear projection.
    """

    def __init__(
        self, enc_hidden: int, pred_hidden: int, joint_hidden: int, num_classes: int
    ):
        super().__init__()
        self.enc_hidden = enc_hidden
        self.pred_hidden = pred_hidden
        self.pred = nn.Linear(pred_hidden, joint_hidden)
        self.enc = nn.Linear(enc_hidden, joint_hidden)
        self.joint_net = nn.Sequential(nn.ReLU(), nn.Linear(joint_hidden, num_classes))

    def joint(self, encoder_out: Tensor, decoder_out: Tensor) -> Tensor:
        """
        Combine the encoder and prediction network outputs into a joint representation.
        """
        enc = self.enc(encoder_out).unsqueeze(2)
        pred = self.pred(decoder_out).unsqueeze(1)
        return self.joint_net(enc + pred)

    def input_example(self):
        device = next(self.parameters()).device
        enc = torch.zeros(1, self.enc_hidden, 1)
        dec = torch.zeros(1, self.pred_hidden, 1)
        return enc.float().to(device), dec.float().to(device)

    def input_names(self):
        return ["enc", "dec"]

    def output_names(self):
        return ["joint"]

    def forward(self, enc: Tensor, dec: Tensor) -> Tensor:
        return self.joint(enc.transpose(1, 2), dec.transpose(1, 2))
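

# A minimal illustrative sketch (not part of the original module): how the joint
# network combines the two streams. RNNTJoint.joint projects both inputs to
# joint_hidden, unsqueezes them to (B, T, 1, H) and (B, 1, U, H), and the
# broadcast sum yields a (B, T, U, num_classes) logit lattice. All sizes below
# are made up.
def _example_joint_lattice() -> torch.Size:
    joint = RNNTJoint(enc_hidden=256, pred_hidden=320, joint_hidden=320, num_classes=34)
    enc_out = torch.randn(2, 50, 256)   # (batch, time, enc_hidden)
    pred_out = torch.randn(2, 10, 320)  # (batch, target_len, pred_hidden)
    lattice = joint.joint(enc_out, pred_out)  # -> (2, 50, 10, 34)
    return lattice.shape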


class RNNTDecoder(nn.Module):
    """
    RNN-Transducer Decoder Module.

    This module handles the prediction network part of the RNN-Transducer architecture.
    """

    def __init__(self, pred_hidden: int, pred_rnn_layers: int, num_classes: int):
        super().__init__()
        self.blank_id = num_classes - 1
        self.pred_hidden = pred_hidden
        self.embed = nn.Embedding(num_classes, pred_hidden, padding_idx=self.blank_id)
        self.lstm = nn.LSTM(pred_hidden, pred_hidden, pred_rnn_layers)

    def predict(
        self,
        x: Optional[Tensor],
        state: Optional[Tensor],
        batch_size: int = 1,
    ) -> Tuple[Tensor, Tensor]:
        """
        Make predictions based on the current input and previous states.
        If no input is provided, use zeros as the initial input.
        """
        if x is not None:
            emb: Tensor = self.embed(x)
        else:
            emb = torch.zeros(
                (batch_size, 1, self.pred_hidden), device=next(self.parameters()).device
            )
        g, hid = self.lstm(emb.transpose(0, 1), state)
        return g.transpose(0, 1), hid

    def input_example(self):
        device = next(self.parameters()).device
        label = torch.tensor([[0]]).to(device)
        hidden_h = torch.zeros(1, 1, self.pred_hidden).to(device)
        hidden_c = torch.zeros(1, 1, self.pred_hidden).to(device)
        return label, hidden_h, hidden_c

    def input_names(self):
        return ["x", "h", "c"]

    def output_names(self):
        return ["dec", "h", "c"]

    def forward(self, x: Tensor, h: Tensor, c: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        ONNX-specific forward with x, state = (h, c) -> x, h, c.
        """
        emb = self.embed(x)
        g, (h, c) = self.lstm(emb.transpose(0, 1), (h, c))
        return g.transpose(0, 1), h, c
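

# A minimal illustrative sketch (not part of the original module): a single
# prediction-network step as used in greedy decoding. With x=None,
# RNNTDecoder.predict starts from a zero embedding (the step before any label
# has been emitted); afterwards the last emitted label and the LSTM state are
# fed back in. All sizes below are made up.
def _example_decoder_step() -> Tuple[torch.Size, torch.Size]:
    decoder = RNNTDecoder(pred_hidden=320, pred_rnn_layers=1, num_classes=34)
    g0, state = decoder.predict(None, state=None, batch_size=1)  # -> (1, 1, 320)
    last_label = torch.tensor([[7]])  # id of the previously emitted token
    g1, state = decoder.predict(last_label, state=state)         # -> (1, 1, 320)
    return g0.shape, g1.shape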


class RNNTHead(nn.Module):
    """
    RNN-Transducer Head Module.

    This module combines the decoder and joint network components of the
    RNN-Transducer architecture.
    """

    def __init__(self, decoder: Dict[str, int], joint: Dict[str, int]):
        super().__init__()
        self.decoder = RNNTDecoder(**decoder)
        self.joint = RNNTJoint(**joint)


class GigaAMFeatureExtractor(SequenceFeatureExtractor):
    """
    Feature extractor for GigaAM: log-Mel spectrogram features.
    """

    model_input_names = ["input_features"]

    def __init__(
        self,
        feature_size=64,
        sampling_rate=16000,
        padding_value=0.0,
        chunk_length=30.0,
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            chunk_length=chunk_length,
            **kwargs,
        )
        self.hop_length = sampling_rate // 100
        self.n_samples = int(chunk_length * sampling_rate)
        self.featurizer = torchaudio.transforms.MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=sampling_rate // 40,
            win_length=sampling_rate // 40,
            hop_length=self.hop_length,
            n_mels=feature_size,
        )

    def to_dict(self) -> Dict[str, Union[str, int, Dict]]:
        dictionary = super().to_dict()

        # The Mel featurizer is a torch module and is not JSON-serializable.
        if "featurizer" in dictionary:
            del dictionary["featurizer"]
        dictionary["hop_length"] = self.hop_length
        dictionary["n_samples"] = self.n_samples
        return dictionary

    def out_len(self, input_lengths: Tensor) -> Tensor:
        """
        Calculates the number of feature frames produced for each input length.
        """
        return input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        sampling_rate: Optional[int] = None,
        padding: str = "max_length",
        **kwargs,
    ):
        is_batched_numpy = (
            isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        )
        if is_batched_numpy and len(raw_speech.shape) > 2:
            raise ValueError(
                f"Only mono-channel audio is supported for input to {self}"
            )
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple))
            and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            raw_speech = [
                np.asarray([speech], dtype=np.float32).T for speech in raw_speech
            ]
        elif not is_batched and not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(
            np.float64
        ):
            raw_speech = raw_speech.astype(np.float32)

        if not is_batched:
            raw_speech = [np.asarray([raw_speech]).T]

        input_lengths = torch.tensor([len(speech) for speech in raw_speech])

        batched_speech = BatchFeature({"input_features": raw_speech})

        padded_inputs = self.pad(
            batched_speech,
            padding=padding,
            max_length=self.n_samples,
            truncation=False,
            return_tensors="pt",
        )

        input_features = padded_inputs["input_features"].transpose(1, 2)
        input_features = self.featurizer(input_features).squeeze(1)
        input_features = torch.log(input_features.clamp_(1e-9, 1e9))
        input_lengths = self.out_len(input_lengths)

        return BatchFeature(
            {"input_features": input_features, "input_lengths": input_lengths},
            tensor_type="pt",
        )
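

# A minimal illustrative sketch (not part of the original module): extracting
# log-Mel features from one second of silence. With the defaults above (16 kHz
# audio, hop length 160, 64 Mel bins), a 16000-sample waveform yields 101 valid
# frames, and the features are padded out to the 30 s chunk (3001 frames).
def _example_feature_extraction() -> BatchFeature:
    extractor = GigaAMFeatureExtractor()
    waveform = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz
    features = extractor(waveform, sampling_rate=16000)
    # features["input_features"]: (1, 64, 3001); features["input_lengths"]: tensor([101])
    return features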


class GigaAMTokenizer(Wav2Vec2CTCTokenizer):
    """
    Char tokenizer for GigaAM model.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="[BLANK]",
        pad_token="[BLANK]",
        bos_token=None,
        eos_token=None,
        word_delimiter_token=" ",
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            word_delimiter_token=word_delimiter_token,
            **kwargs,
        )


class GigaAMProcessor(Wav2Vec2Processor):
    feature_extractor_class = "GigaAMFeatureExtractor"
    tokenizer_class = "GigaAMTokenizer"

    def __init__(self, feature_extractor, tokenizer):
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.current_processor = self.feature_extractor
        self._in_target_context_manager = False

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        feature_extractor = GigaAMFeatureExtractor.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        tokenizer = GigaAMTokenizer.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )

        return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)


class GigaAMConfig(PretrainedConfig):
    model_type = "gigaam"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class GigaAMCTCHF(PreTrainedModel):
    """
    GigaAM-CTC model for transformers
    """

    config_class = GigaAMConfig
    base_model_prefix = "gigaamctc"
    main_input_name = "input_features"

    def __init__(self, config: GigaAMConfig):
        super().__init__(config)
        self.model = GigaAMCTC(config.encoder, config.head)

    def forward(self, input_features, input_lengths, labels=None, **kwargs):
        logits, encoded_lengths = self.model(input_features, input_lengths)

        # (batch, classes, time) -> (time, batch, classes) log-probabilities, as required by ctc_loss.
        log_probs = torch.log_softmax(
            logits.transpose(1, 2), dim=-1, dtype=torch.float32
        ).transpose(0, 1)

        loss = None
        if labels is not None:
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            loss = nn.functional.ctc_loss(
                log_probs,
                flattened_targets,
                encoded_lengths,
                target_lengths,
                blank=self.config.blank_id,
                zero_infinity=True,
            )

        return CausalLMOutput(loss=loss, logits=logits.transpose(1, 2))
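

# A minimal illustrative sketch (not part of the original module): end-to-end CTC
# transcription. "path/to/checkpoint" is a placeholder for a local directory or
# Hub repo that holds weights, config, vocab and preprocessor files exported for
# these classes; substitute a real one.
def _example_ctc_transcription(audio: np.ndarray, checkpoint: str = "path/to/checkpoint") -> str:
    processor = GigaAMProcessor.from_pretrained(checkpoint)
    model = GigaAMCTCHF.from_pretrained(checkpoint)
    inputs = processor(audio, sampling_rate=16000)
    with torch.inference_mode():
        outputs = model(inputs["input_features"], inputs["input_lengths"])
    pred_ids = outputs.logits.argmax(dim=-1)  # (batch, time) token ids per frame
    return processor.batch_decode(pred_ids)[0]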


class GigaAMRNNTHF(PreTrainedModel):
    """
    GigaAM-RNNT model for transformers
    """

    config_class = GigaAMConfig
    base_model_prefix = "gigaamrnnt"
    main_input_name = "input_features"

    def __init__(self, config: GigaAMConfig):
        super().__init__(config)
        self.model = GigaAMRNNT(config.encoder, config.head)

    def forward(self, input_features, input_lengths, labels=None, **kwargs):
        encoder_out, encoded_lengths = self.model.encoder(input_features, input_lengths)
        encoder_out = encoder_out.transpose(1, 2)
        batch_size = encoder_out.shape[0]

        loss = None
        if labels is not None:
            labels = labels.to(torch.int32)
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1).to(torch.int32)

            pred_rnn_layers = self.config.head["decoder"]["pred_rnn_layers"]
            pred_hidden = self.model.head.decoder.pred_hidden
            hidden_h = torch.zeros(
                (pred_rnn_layers, batch_size, pred_hidden), device=encoder_out.device
            )
            hidden_c = torch.zeros(
                (pred_rnn_layers, batch_size, pred_hidden), device=encoder_out.device
            )
            # Prepend the blank token as the start symbol and replace label padding with blank.
            blank_start = self.config.blank_id * torch.ones(
                (batch_size, 1), dtype=torch.int32, device=encoder_out.device
            )
            labels[labels < 0] = self.config.blank_id

            decoder_out, h, c = self.model.head.decoder(
                torch.cat((blank_start, labels), dim=1), hidden_h, hidden_c
            )

            joint = self.model.head.joint.joint(encoder_out, decoder_out)
            loss = torchaudio.functional.rnnt_loss(
                logits=joint,
                targets=labels,
                logit_lengths=encoded_lengths.to(torch.int32),
                target_lengths=target_lengths,
                blank=self.config.blank_id,
            )

        return Seq2SeqLMOutput(loss=loss, logits=encoder_out.transpose(1, 2))

    def _greedy_decode(self, x: Tensor, seqlen: Tensor) -> Tensor:
        """
        Internal helper for greedy decoding of a single sequence.
        Returns the hypothesis as a 1-D tensor of token ids.
        """
        hyp: List[int] = []
        dec_state: Optional[Tuple[Tensor, Tensor]] = None
        last_label: Optional[Tensor] = None
        for t in range(seqlen):
            f = x[t, :, :].unsqueeze(1)
            not_blank = True
            new_symbols = 0
            while not_blank and new_symbols < self.config.max_symbols:
                g, hidden = self.model.head.decoder.predict(last_label, dec_state)
                k = self.model.head.joint.joint(f, g)[0, 0, 0, :].argmax(0).item()
                if k == self.config.blank_id:
                    not_blank = False
                else:
                    hyp.append(k)
                    dec_state = hidden
                    last_label = torch.tensor([[hyp[-1]]]).to(x.device)
                    new_symbols += 1

        return torch.tensor(hyp, dtype=torch.int32)

    @torch.inference_mode()
    def generate(self, input_features: Tensor, input_lengths: Tensor, **kwargs) -> torch.Tensor:
        """
        Decode the output of an RNN-T model into a batch of greedy hypotheses,
        padded with the blank id to a common length.
        """
        encoder_out, encoded_lengths = self.model.encoder(input_features, input_lengths)
        encoder_out = encoder_out.transpose(1, 2)
        batch_size = encoder_out.shape[0]
        preds = []
        for i in range(batch_size):
            inseq = encoder_out[i, :, :].unsqueeze(1)
            preds.append(self._greedy_decode(inseq, encoded_lengths[i]))
        return pad_sequence(preds, batch_first=True, padding_value=self.config.blank_id)
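

# A minimal illustrative sketch (not part of the original module): greedy RNN-T
# transcription via generate(). "path/to/checkpoint" is a placeholder; the
# group_tokens=False flag asks the CTC-style tokenizer not to collapse repeated
# characters, since RNN-T hypotheses may legitimately contain repeats.
def _example_rnnt_transcription(audio: np.ndarray, checkpoint: str = "path/to/checkpoint") -> str:
    processor = GigaAMProcessor.from_pretrained(checkpoint)
    model = GigaAMRNNTHF.from_pretrained(checkpoint)
    inputs = processor(audio, sampling_rate=16000)
    pred_ids = model.generate(inputs["input_features"], inputs["input_lengths"])
    return processor.batch_decode(pred_ids, group_tokens=False)[0]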