File size: 15,992 Bytes
9b9aeff 8d2bbab 9b9aeff 5a4f890 9b9aeff 5a4f890 9b9aeff 5a4f890 9b9aeff 5a4f890 9b9aeff a3fc5de 9b9aeff e4cbc9d 9b9aeff 8d2bbab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 |
from typing import Dict, List, Optional, Union, Tuple
import numpy as np
import torch
import torch.nn as nn
import torchaudio
from .encoder import ConformerEncoder
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor
from transformers.configuration_utils import PretrainedConfig
from transformers.feature_extraction_sequence_utils import \
SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import CausalLMOutput, Seq2SeqLMOutput
from transformers.modeling_utils import PreTrainedModel
class GigaAMCTC(nn.Module):
    """
    GigaAM-CTC model: a ``ConformerEncoder`` followed by a ``CTCHead``.

    Args:
        config_encoder: keyword arguments unpacked into ``ConformerEncoder``.
        config_head: keyword arguments unpacked into ``CTCHead``.
    """

    def __init__(self, config_encoder, config_head):
        super().__init__()
        self.encoder = ConformerEncoder(**config_encoder)
        self.head = CTCHead(**config_head)

    def forward(
        self, input_features: Tensor, input_lengths: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Run the encoder and project its output with the CTC head.

        Args:
            input_features: features passed to the encoder.
            input_lengths: per-item lengths passed to the encoder.

        Returns:
            A ``(logits, encoded_lengths)`` pair: the head output and the
            lengths reported by the encoder.
        """
        encoded, encoded_lengths = self.encoder(input_features, input_lengths)
        return self.head(encoded), encoded_lengths
class GigaAMRNNT(nn.Module):
    """
    GigaAM-RNNT model: a ``ConformerEncoder`` followed by an ``RNNTHead``.

    Args:
        config_encoder: keyword arguments unpacked into ``ConformerEncoder``.
        config_head: keyword arguments unpacked into ``RNNTHead``.
    """

    def __init__(self, config_encoder, config_head):
        super().__init__()
        self.encoder = ConformerEncoder(**config_encoder)
        self.head = RNNTHead(**config_head)

    def forward(
        self,
        input_features: Tensor,
        input_lengths: Tensor,
        targets: Tensor,
        target_lengths: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Compute the joint-network output needed for an RNN-T loss.

        The prediction network is fed the targets prefixed with a blank
        start symbol and zero initial LSTM states, mirroring what
        ``GigaAMRNNTHF.forward`` in this file does.

        Args:
            input_features: features passed to the encoder.
            input_lengths: per-item lengths passed to the encoder.
            targets: integer label ids, one row per batch item. Negative
                entries are treated as padding and replaced by the blank id
                before embedding.
            target_lengths: accepted for interface compatibility; not used
                here (the loss itself is computed by the caller).

        Returns:
            A ``(joint, encoded_lengths)`` pair: the joint-network output and
            the lengths reported by the encoder.
        """
        encoded, encoded_lengths = self.encoder(input_features, input_lengths)

        decoder = self.head.decoder
        batch_size = targets.shape[0]

        # Out-of-place so the caller's tensor is left untouched.
        safe_targets = targets.masked_fill(targets < 0, decoder.blank_id)
        start_symbol = torch.full(
            (batch_size, 1),
            decoder.blank_id,
            dtype=targets.dtype,
            device=targets.device,
        )
        decoder_inputs = torch.cat((start_symbol, safe_targets), dim=1)

        # Zero initial (h, c) states, matching encoder dtype/device.
        state_shape = (decoder.lstm.num_layers, batch_size, decoder.pred_hidden)
        hidden_h = encoded.new_zeros(state_shape)
        hidden_c = encoded.new_zeros(state_shape)

        # RNNTDecoder.forward takes (x, h, c) positionally; it has no
        # targets/target_length keywords.
        decoder_out, _, _ = decoder(decoder_inputs, hidden_h, hidden_c)

        # NOTE(review): the encoder output is transposed before the joint the
        # same way GigaAMRNNTHF.forward does it -- confirm the encoder layout.
        joint = self.head.joint.joint(encoded.transpose(1, 2), decoder_out)
        return joint, encoded_lengths
class CTCHead(nn.Module):
    """
    Output head for Connectionist Temporal Classification.

    Maps ``feat_in`` input channels to ``num_classes`` output channels with a
    single 1-D convolution of kernel size 1.
    """

    def __init__(self, feat_in: int, num_classes: int):
        super().__init__()
        pointwise_projection = nn.Conv1d(feat_in, num_classes, kernel_size=1)
        # Kept inside a Sequential named ``decoder_layers`` so parameter
        # names stay ``decoder_layers.0.*``.
        self.decoder_layers = nn.Sequential(pointwise_projection)

    def forward(self, encoder_output: Tensor) -> Tensor:
        """Project encoder output (B x C x T) to class scores."""
        logits = self.decoder_layers(encoder_output)
        return logits
class RNNTJoint(nn.Module):
    """
    Joint network of the RNN-Transducer.

    Encoder and prediction-network outputs are each projected to a common
    ``joint_hidden`` size, summed with broadcasting, passed through ReLU and
    finally projected to ``num_classes`` scores.
    """

    def __init__(
        self, enc_hidden: int, pred_hidden: int, joint_hidden: int, num_classes: int
    ):
        super().__init__()
        self.enc_hidden = enc_hidden
        self.pred_hidden = pred_hidden
        # Sub-modules are created in the same order as before so that
        # parameter ordering and seeded initialisation are unchanged.
        self.pred = nn.Linear(pred_hidden, joint_hidden)
        self.enc = nn.Linear(enc_hidden, joint_hidden)
        output_projection = nn.Linear(joint_hidden, num_classes)
        self.joint_net = nn.Sequential(nn.ReLU(), output_projection)

    def joint(self, encoder_out: Tensor, decoder_out: Tensor) -> Tensor:
        """
        Fuse encoder and prediction-network outputs.

        A new axis is inserted at position 2 of the projected encoder output
        and at position 1 of the projected decoder output so the two
        broadcast against each other when added.
        """
        projected_enc = self.enc(encoder_out)[:, :, None]
        projected_pred = self.pred(decoder_out)[:, None]
        fused = projected_enc + projected_pred
        return self.joint_net(fused)

    def input_example(self):
        """Return float32 zero tensors ``(enc, dec)`` on the module's device."""
        device = next(self.parameters()).device
        shapes = ((1, self.enc_hidden, 1), (1, self.pred_hidden, 1))
        enc, dec = (torch.zeros(*shape).float().to(device) for shape in shapes)
        return enc, dec

    def input_names(self):
        """Names of the ``forward`` inputs."""
        return ["enc", "dec"]

    def output_names(self):
        """Names of the ``forward`` outputs."""
        return ["joint"]

    def forward(self, enc: Tensor, dec: Tensor) -> Tensor:
        """Swap axes 1 and 2 of both inputs, then apply :meth:`joint`."""
        enc_time_major = enc.transpose(1, 2)
        dec_time_major = dec.transpose(1, 2)
        return self.joint(enc_time_major, dec_time_major)
class RNNTDecoder(nn.Module):
    """
    Prediction network of the RNN-Transducer: an embedding followed by an LSTM.

    The blank symbol is the last class id (``num_classes - 1``) and is also
    used as the embedding's ``padding_idx``.

    Args:
        pred_hidden: embedding size and LSTM hidden size.
        pred_rnn_layers: number of stacked LSTM layers.
        num_classes: vocabulary size including the blank symbol.
    """

    def __init__(self, pred_hidden: int, pred_rnn_layers: int, num_classes: int):
        super().__init__()
        self.blank_id = num_classes - 1
        self.pred_hidden = pred_hidden
        self.embed = nn.Embedding(num_classes, pred_hidden, padding_idx=self.blank_id)
        self.lstm = nn.LSTM(pred_hidden, pred_hidden, pred_rnn_layers)

    def predict(
        self,
        x: Optional[Tensor],
        state: Optional[Tuple[Tensor, Tensor]],
        batch_size: int = 1,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        """
        Advance the prediction network.

        Args:
            x: label ids of shape (batch, steps), or ``None`` to feed an
                all-zero embedding for a single step.
            state: previous LSTM ``(h, c)`` state, or ``None`` to let the
                LSTM start from its default state.
            batch_size: batch size of the zero embedding; only used when
                ``x`` is ``None``.

        Returns:
            ``(output, state)``: the batch-first LSTM output and the new
            ``(h, c)`` state.
        """
        if x is not None:
            emb: Tensor = self.embed(x)
        else:
            # Match dtype and device of the weights so the call also works
            # after .half()/.double() or a device move.
            emb = self.embed.weight.new_zeros((batch_size, 1, self.pred_hidden))
        # The LSTM is sequence-first, hence the transposes around it.
        g, hid = self.lstm(emb.transpose(0, 1), state)
        return g.transpose(0, 1), hid

    def input_example(self):
        """
        Return example ``(label, h, c)`` inputs for :meth:`forward`.

        The state tensors have one slice per LSTM layer so the example is
        valid for any ``pred_rnn_layers``.
        """
        reference = self.embed.weight
        label = torch.tensor([[0]], device=reference.device)
        state_shape = (self.lstm.num_layers, 1, self.pred_hidden)
        hidden_h = reference.new_zeros(state_shape)
        hidden_c = reference.new_zeros(state_shape)
        return label, hidden_h, hidden_c

    def input_names(self):
        """Names of the ``forward`` inputs."""
        return ["x", "h", "c"]

    def output_names(self):
        """Names of the ``forward`` outputs."""
        return ["dec", "h", "c"]

    def forward(self, x: Tensor, h: Tensor, c: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        ONNX-specific forward with x, state = (h, c) -> x, h, c.
        """
        emb = self.embed(x)
        g, (h, c) = self.lstm(emb.transpose(0, 1), (h, c))
        return g.transpose(0, 1), h, c
class RNNTHead(nn.Module):
    """
    Container for the two RNN-Transducer head components.

    Holds the prediction network as ``decoder`` and the joint network as
    ``joint``; it defines no ``forward`` of its own.

    Args:
        decoder: keyword arguments unpacked into ``RNNTDecoder``.
        joint: keyword arguments unpacked into ``RNNTJoint``.
    """

    def __init__(self, decoder: Dict[str, int], joint: Dict[str, int]):
        super().__init__()
        prediction_network = RNNTDecoder(**decoder)
        joint_network = RNNTJoint(**joint)
        self.decoder = prediction_network
        self.joint = joint_network
class GigaAMFeatureExtractor(SequenceFeatureExtractor):
    """
    Feature extractor for GigaAM.

    Turns raw mono waveforms into log-mel spectrograms using
    ``torchaudio.transforms.MelSpectrogram`` with a window of
    ``sampling_rate // 40`` samples and a hop of ``sampling_rate // 100``
    samples.

    Args:
        feature_size: number of mel bins.
        sampling_rate: expected sampling rate of the input audio.
        padding_value: value used when padding waveforms.
        chunk_length: length in seconds that waveforms are padded to when
            ``padding="max_length"``.
    """

    model_input_names = ["input_features"]

    def __init__(
        self,
        feature_size=64,
        sampling_rate=16000,
        padding_value=0.0,
        chunk_length=30.0,
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            chunk_length=chunk_length,
            **kwargs,
        )
        self.hop_length = sampling_rate // 100
        # Must be an integer: it is handed to ``pad`` as ``max_length`` and
        # ends up as a pad width, which cannot be a float. With the default
        # ``chunk_length=30.0`` the plain product would be 480000.0.
        self.n_samples = int(round(chunk_length * sampling_rate))
        self.featurizer = torchaudio.transforms.MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=sampling_rate // 40,
            win_length=sampling_rate // 40,
            hop_length=self.hop_length,
            n_mels=feature_size,
        )

    def to_dict(self) -> Dict[str, Union[str, int, Dict]]:
        """
        Serialise the configuration.

        The ``featurizer`` module is dropped from the dictionary, and
        ``hop_length`` / ``n_samples`` are written explicitly.
        """
        dictionary = super().to_dict()
        dictionary.pop("featurizer", None)
        dictionary["hop_length"] = self.hop_length
        dictionary["n_samples"] = self.n_samples
        return dictionary

    def out_len(self, input_lengths: Tensor) -> Tensor:
        """
        Calculates the output length after the feature extraction process.

        Computed as ``floor(input_length / hop_length) + 1``.
        """
        return input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        sampling_rate: Optional[int] = None,
        padding: str = "max_length",
        **kwargs,
    ):
        """
        Convert one waveform or a batch of waveforms to log-mel features.

        Args:
            raw_speech: a single mono waveform (1-D array or list of floats)
                or a batch of them (2-D array, or list/tuple of arrays or
                lists).
            sampling_rate: accepted for API compatibility; it is not
                inspected or validated here.
            padding: padding strategy forwarded to ``self.pad``;
                ``max_length`` there is ``self.n_samples``.

        Returns:
            ``BatchFeature`` with ``input_features`` (log-mel spectrograms)
            and ``input_lengths`` (frame counts derived from the unpadded
            waveform lengths via :meth:`out_len`).

        Raises:
            ValueError: if a NumPy input has more than two dimensions.
        """
        is_batched_numpy = isinstance(raw_speech, np.ndarray) and raw_speech.ndim > 1
        if is_batched_numpy and raw_speech.ndim > 2:
            raise ValueError(
                f"Only mono-channel audio is supported for input to {self}"
            )
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple))
            and isinstance(raw_speech[0], (np.ndarray, tuple, list))
        )

        # Always work on a batch.
        if not is_batched:
            raw_speech = [raw_speech]

        # Every utterance becomes a float32 column of shape (num_samples, 1);
        # a single conversion covers lists, float64 arrays and other dtypes.
        raw_speech = [
            np.asarray([speech], dtype=np.float32).T for speech in raw_speech
        ]

        # Lengths are taken before padding so that ``input_lengths``
        # describes the real audio, not the padded buffer.
        input_lengths = torch.tensor([len(speech) for speech in raw_speech])

        batched_speech = BatchFeature({"input_features": raw_speech})
        padded_inputs = self.pad(
            batched_speech,
            padding=padding,
            max_length=self.n_samples,
            truncation=False,
            return_tensors="pt",
        )

        # presumably (batch, samples, 1) -> (batch, 1, samples) -- TODO confirm
        waveforms = padded_inputs["input_features"].transpose(1, 2)
        mel = self.featurizer(waveforms).squeeze(1)
        # Clamp before the log so silent frames do not produce -inf.
        log_mel = torch.log(mel.clamp_(1e-9, 1e9))

        return BatchFeature(
            {
                "input_features": log_mel,
                "input_lengths": self.out_len(input_lengths),
            },
            tensor_type="pt",
        )
class GigaAMTokenizer(Wav2Vec2CTCTokenizer):
    """
    Character-level tokenizer for the GigaAM model.

    A thin ``Wav2Vec2CTCTokenizer`` subclass that only changes defaults:
    ``"[BLANK]"`` is used for both the unknown and the padding token, no
    BOS/EOS tokens are set, and a plain space is the word delimiter.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="[BLANK]",
        pad_token="[BLANK]",
        bos_token=None,
        eos_token=None,
        word_delimiter_token=" ",
        **kwargs,
    ):
        special_tokens = dict(
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            word_delimiter_token=word_delimiter_token,
        )
        super().__init__(vocab_file=vocab_file, **special_tokens, **kwargs)
class GigaAMProcessor(Wav2Vec2Processor):
    """
    Processor bundling a ``GigaAMFeatureExtractor`` with a ``GigaAMTokenizer``.
    """

    feature_extractor_class = "GigaAMFeatureExtractor"
    tokenizer_class = "GigaAMTokenizer"

    def __init__(self, feature_extractor, tokenizer):
        # NOTE(review): the parent initializer is not invoked here (the call
        # was left commented out by the author); only the attributes below
        # are set. Confirm before re-enabling super().__init__.
        self._in_target_context_manager = False
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        # The feature extractor is the processor that is active by default.
        self.current_processor = feature_extractor

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load both components from the same location.

        The same ``kwargs`` are forwarded to the feature extractor and the
        tokenizer loaders; the feature extractor is loaded first.
        """
        components = {
            "feature_extractor": GigaAMFeatureExtractor.from_pretrained(
                pretrained_model_name_or_path, **kwargs
            ),
            "tokenizer": GigaAMTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs
            ),
        }
        return cls(**components)
class GigaAMConfig(PretrainedConfig):
    """
    Configuration for GigaAM models (``model_type = "gigaam"``).

    Declares no fields of its own; every keyword argument is handed to
    ``PretrainedConfig``. The model classes in this file read ``encoder``,
    ``head``, ``blank_id`` and ``max_symbols`` from the resulting config.
    """

    model_type = "gigaam"

    def __init__(self, **kwargs):
        # Pure pass-through to the base configuration class.
        super().__init__(**kwargs)
class GigaAMCTCHF(PreTrainedModel):
    """
    GigaAM-CTC model for transformers
    """

    config_class = GigaAMConfig
    base_model_prefix = "gigaamctc"
    main_input_name = "input_features"

    def __init__(self, config: GigaAMConfig):
        super().__init__(config)
        self.model = GigaAMCTC(config.encoder, config.head)

    def forward(self, input_features, input_lengths, labels=None, **kwargs):
        """
        Compute CTC logits and, when ``labels`` are given, the CTC loss.

        Args:
            input_features: features passed to the wrapped ``GigaAMCTC``.
            input_lengths: per-item lengths passed to the wrapped model.
            labels: optional label ids; entries below zero are treated as
                padding and excluded from the targets.
            **kwargs: accepted and ignored.

        Returns:
            ``CausalLMOutput`` whose ``logits`` are the model logits with
            axes 1 and 2 swapped (B x T x C) and whose ``loss`` is ``None``
            when no labels are supplied.
        """
        # B x C x T
        logits, encoded_lengths = self.model(input_features, input_lengths)

        loss = None
        if labels is not None:
            # Log-probabilities are only needed for the loss, so they are
            # not computed on label-free (inference) calls.
            # B x C x T -> B x T x C -> T x B x C
            log_probs = torch.log_softmax(
                logits.transpose(1, 2), dim=-1, dtype=torch.float32
            ).transpose(0, 1)

            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            loss = nn.functional.ctc_loss(
                log_probs,
                flattened_targets,
                encoded_lengths,
                target_lengths,
                blank=self.config.blank_id,
                zero_infinity=True,
            )

        return CausalLMOutput(loss=loss, logits=logits.transpose(1, 2))
class GigaAMRNNTHF(PreTrainedModel):
    """
    GigaAM-RNNT model for transformers
    """

    config_class = GigaAMConfig
    base_model_prefix = "gigaamrnnt"
    main_input_name = "input_features"

    def __init__(self, config: GigaAMConfig):
        super().__init__(config)
        self.model = GigaAMRNNT(config.encoder, config.head)

    def forward(self, input_features, input_lengths, labels=None, **kwargs):
        """
        Encode the input and, when ``labels`` are given, compute the RNN-T loss.

        Args:
            input_features: features passed to the encoder.
            input_lengths: per-item lengths passed to the encoder.
            labels: optional label ids; entries below zero are treated as
                padding. The caller's tensor is never modified.
            **kwargs: accepted and ignored.

        Returns:
            ``Seq2SeqLMOutput`` whose ``logits`` field carries the encoder
            output in the encoder's own layout and whose ``loss`` is
            ``None`` when no labels are supplied.
        """
        # B x C x T
        encoder_out, encoded_lengths = self.model.encoder(input_features, input_lengths)
        encoder_out = encoder_out.transpose(1, 2)

        loss = None
        if labels is not None:
            decoder = self.model.head.decoder
            blank_id = self.config.blank_id
            batch_size = encoder_out.shape[0]

            labels = labels.to(torch.int32)
            target_lengths = (labels >= 0).sum(-1).to(torch.int32)
            # Out-of-place replacement: ``.to`` returns the same tensor when
            # it is already int32, so an in-place write would leak into the
            # caller's labels.
            targets = labels.masked_fill(labels < 0, blank_id)

            # Zero initial LSTM states in the encoder's dtype and device.
            state_shape = (
                self.config.head["decoder"]["pred_rnn_layers"],
                batch_size,
                decoder.pred_hidden,
            )
            hidden_h = encoder_out.new_zeros(state_shape)
            hidden_c = encoder_out.new_zeros(state_shape)

            # The prediction network input is the target sequence prefixed
            # with a blank start symbol.
            start_symbol = torch.full(
                (batch_size, 1), blank_id, dtype=torch.int32, device=encoder_out.device
            )
            decoder_out, _, _ = decoder(
                torch.cat((start_symbol, targets), dim=1), hidden_h, hidden_c
            )
            joint = self.model.head.joint.joint(encoder_out, decoder_out)

            loss = torchaudio.functional.rnnt_loss(
                logits=joint,
                targets=targets,
                # rnnt_loss expects int32 lengths; no-op if already int32.
                logit_lengths=encoded_lengths.to(torch.int32),
                target_lengths=target_lengths,
                blank=blank_id,
            )

        return Seq2SeqLMOutput(loss=loss, logits=encoder_out.transpose(1, 2))

    def _greedy_decode(self, x: Tensor, seqlen: Tensor) -> Tensor:
        """
        Internal helper function for performing greedy decoding on a single sequence.

        For every encoder frame, symbols are emitted until the joint network
        predicts blank or ``config.max_symbols`` symbols have been produced
        for that frame.

        Args:
            x: encoder output of one utterance, indexed by frame on axis 0.
            seqlen: number of valid frames in ``x``.

        Returns:
            int32 tensor of shape (1, num_emitted_symbols) on the CPU.
        """
        decoder = self.model.head.decoder
        joint = self.model.head.joint
        blank_id = self.config.blank_id
        max_symbols = self.config.max_symbols

        hyp: List[int] = []
        dec_state: Optional[Tuple[Tensor, Tensor]] = None
        last_label: Optional[Tensor] = None

        for t in range(int(seqlen)):
            f = x[t, :, :].unsqueeze(1)
            for _ in range(max_symbols):
                g, hidden = decoder.predict(last_label, dec_state)
                k = joint.joint(f, g)[0, 0, 0, :].argmax(0).item()
                if k == blank_id:
                    # Blank: move on to the next frame; state is unchanged.
                    break
                hyp.append(k)
                dec_state = hidden
                last_label = torch.tensor([[k]], device=x.device)

        return torch.tensor([hyp], dtype=torch.int32)

    @torch.inference_mode()
    def generate(self, input_features: Tensor, input_lengths: Tensor, **kwargs) -> torch.Tensor:
        """
        Decode the output of an RNN-T model into a list of hypotheses.

        Each batch item is decoded greedily and independently.

        Returns:
            int32 tensor of shape (batch, 1, max_hypothesis_length); shorter
            hypotheses are right-padded with ``config.blank_id``. The
            singleton middle axis is kept for backward compatibility with
            the shape this method has always returned.
        """
        encoder_out, encoded_lengths = self.model.encoder(input_features, input_lengths)
        encoder_out = encoder_out.transpose(1, 2)

        # Drop the leading singleton axis so pad_sequence pads along the
        # symbol axis; padding the (1, U_i) tensors directly fails when
        # hypothesis lengths differ across the batch.
        hypotheses = [
            self._greedy_decode(frames.unsqueeze(1), length).squeeze(0)
            for frames, length in zip(encoder_out, encoded_lengths)
        ]
        padded = pad_sequence(
            hypotheses, batch_first=True, padding_value=self.config.blank_id
        )
        return padded.unsqueeze(1)
|