import random, os, copy
from typing import Dict, Iterator, List, Tuple, Union
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.classification import MulticlassAccuracy
import torch.distributed as dist

from .modules.utils import make_pad_mask, generate_partial_autoregressive_mask
from .modules.embedding import SinePositionalEmbedding, TokenEmbedding, SinePositionalEmbedding_progress
from .modules.transformer import (
    AdaptiveLayerNorm,
    LayerNorm,
    TransformerDecoderLayer,
    TransformerDecoder,
    TransformerEncoder,
    TransformerEncoderLayer,
)


def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, min_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if min_p < 1.0:
        probs = F.softmax(logits, dim=-1)
        indices_to_remove = probs < min_p
        if not torch.any(indices_to_remove.sum(-1) == logits.size(-1)):
            logits[indices_to_remove] = filter_value
            top_k = 0
            top_p = 1.0
        # else will use other types of sampling, or no filtering

    # If top_k is a single integer
    if isinstance(top_k, int) and top_k > 0:
        # Safety check to ensure we don't ask for more than available
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # Remove all tokens with a probability less than the last token of the top-k
        threshold = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        indices_to_remove = logits < threshold
        logits[indices_to_remove] = filter_value
    # If top_k is a list, assume it has the same length as the batch dimension (logits.size(0))
    elif isinstance(top_k, list):
        # Ensure the length matches the first dimension
        assert len(top_k) == logits.size(0), \
            f"top_k list length ({len(top_k)}) must match logits.size(0) ({logits.size(0)})"
        for i in range(logits.size(0)):
            k_i = top_k[i]
            if k_i > 0:
                # Safety check
                k_i = min(max(k_i, min_tokens_to_keep), logits.size(-1))
                row_threshold = torch.topk(logits[i], k_i, dim=-1)[0][-1]
                indices_to_remove_i = logits[i] < row_threshold
                logits[i, indices_to_remove_i] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1
        ].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the sorted mask back to the original index order and apply it,
        # otherwise the nucleus filtering computed above has no effect on `logits`
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits
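
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original model): how the filtering above
# is typically combined with multinomial sampling. `_demo_top_k_top_p` is a
# hypothetical helper for illustration only; nothing in the model calls it.
# ---------------------------------------------------------------------------
def _demo_top_k_top_p():
    # fake logits for a batch of 2 rows over a 10-token vocabulary
    torch.manual_seed(0)
    logits = torch.randn(2, 10)
    # keep only the 5 most likely tokens per row, then apply nucleus (top-p) filtering
    filtered = top_k_top_p_filtering(logits.clone(), top_k=5, top_p=0.9)
    # filtered-out entries are -inf, so they get zero probability after softmax
    probs = F.softmax(filtered, dim=-1)
    # sample one token id per row from the remaining probability mass
    return torch.multinomial(probs, num_samples=1)  # shape [2, 1]
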

def topk_sampling(logits, top_k=10, top_p=1.0, min_p=1.0, temperature=1.0):
    # temperature: (`optional`) float
    #     The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
    # top_k: (`optional`) int
    #     The number of highest probability vocabulary tokens to keep for top-k filtering. Between 1 and infinity. Default to 50.
    # top_p: (`optional`) float
    #     The cumulative probability of the highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.

    # Temperature (higher temperature => more likely to sample low probability tokens)
    if temperature != 1.0:
        logits = logits / temperature
    # Top-p/top-k filtering
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p, min_p=min_p)
    # Sample
    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    return token


class VoiceStar(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        assert self.args.enc_dec ^ self.args.dec, f"self.args.enc_dec: {self.args.enc_dec}, self.args.dec: {self.args.dec}"
        if not getattr(self.args, "special_first", False):
            self.args.special_first = 0
        if not getattr(self.args, "n_special", False):
            self.args.n_special = 3
        self.args.eos = getattr(self.args, "eos", -1)
        self.eog = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eog, dtype=torch.long), requires_grad=False)  # [K 1]
        if self.args.eos > 0:
            assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
            self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False)  # [K 1]
        if type(self.args.audio_vocab_size) == str:
            self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
        if type(self.args.audio_vocab_size) == list:  # otherwise they are all lists
            assert self.args.special_first

        self.n_text_tokens = self.args.text_vocab_size + 1
        assert self.args.text_pad_token == self.args.text_vocab_size, f"self.args.text_vocab_size: {self.args.text_vocab_size}, self.args.text_pad_token: {self.args.text_pad_token}"

        if self.args.special_first and type(self.args.audio_vocab_size) == list:
            self.n_audio_tokens = [tok + self.args.n_special for tok in self.args.audio_vocab_size]  # special tokens: empty token, EOG token, audio pad token
            assert self.args.empty_token == 0, self.args.empty_token
            assert self.args.eog == 1, self.args.eog
            assert self.args.audio_pad_token == 2, self.args.audio_pad_token
        else:
            self.n_audio_tokens = [self.args.audio_vocab_size + self.args.n_special] * self.args.n_codebooks  # special tokens: empty token, EOG token, audio pad token
            assert self.args.audio_vocab_size == self.args.empty_token, self.args.empty_token
            assert self.args.eog == self.args.audio_vocab_size + 1, self.args.eog
            assert self.args.audio_pad_token == self.args.audio_vocab_size + 2, self.args.audio_pad_token

        self.text_embedding = TokenEmbedding(
            dim_model=self.args.d_model,
            vocab_size=self.n_text_tokens,
            dropout=self.args.text_embedding_dropout
        )

        self.audio_embedding = nn.ModuleList(
            [
                TokenEmbedding(
                    dim_model=self.args.audio_embedding_dim,
                    vocab_size=self.n_audio_tokens[k],
                    dropout=self.args.audio_embedding_dropout
                ) for k in range(self.args.n_codebooks)
            ]
        )

        rope_base = getattr(self.args, "rope_base", None)
        use_sinusoidal = getattr(self.args, "use_sinusoidal", False)
        use_sinusoidal_progress = getattr(self.args, "use_sinusoidal_progress", False)
        logging.info(f"rope_base: {rope_base}, use_sinusoidal: {use_sinusoidal}")
        if use_sinusoidal:
            self.text_positional_embedding = SinePositionalEmbedding(
                self.args.d_model,
                dropout=self.args.text_positional_embedding_dropout,
                scale=False,
                alpha=True,  # learnable scaler, scale the volume of positional embedding
            )
            self.audio_positional_embedding = SinePositionalEmbedding(
                self.args.d_model,
                dropout=self.args.audio_positional_embedding_dropout,
                scale=False,
                alpha=True,  # learnable scaler, scale the volume of positional embedding
            )
        elif use_sinusoidal_progress:
            self.text_positional_embedding = SinePositionalEmbedding_progress(
                self.args.d_model,
                dropout=self.args.text_positional_embedding_dropout,
                scale=False,
                alpha=True,  # learnable scaler, scale the volume of positional embedding
                args=self.args
            )
            self.audio_positional_embedding = SinePositionalEmbedding_progress(
                self.args.d_model,
                dropout=self.args.audio_positional_embedding_dropout,
                scale=False,
                alpha=True,  # learnable scaler, scale the volume of positional embedding
                args=self.args
            )
        else:
            class NoOp:
                def __init__(self):
                    pass
                def __call__(self, *args, **kwargs):
                    return args[0]
            self.text_positional_embedding = NoOp()
            self.audio_positional_embedding = NoOp()

        if self.args.enc_dec:
            enc_layer = TransformerEncoderLayer(
                d_model=self.args.d_model,
                nhead=self.args.nhead,
                dim_feedforward=self.args.d_model * 4,
                dropout=self.args.trm_dropout,
                batch_first=True,
                norm_first=True,
                layer_norm_cls=LayerNorm
            )  # use the pre-norm arch
            self.encoder = TransformerEncoder(
                encoder_layer=enc_layer,
                num_layers=self.args.num_encoder_layers,
                norm=LayerNorm(self.args.d_model),
                rope_base=self.args.rope_base,
                d_model=self.args.d_model,
                nhead=self.args.nhead,
                args=self.args
            )  # use the pre-norm arch
            dec_layer = TransformerDecoderLayer(
                d_model=self.args.d_model,
                nhead=self.args.nhead,
                dim_feedforward=self.args.d_model * 4,
                dropout=self.args.trm_dropout,
                batch_first=True,
                norm_first=True,
                layer_norm_cls=LayerNorm
            )
            self.decoder = TransformerDecoder(
                decoder_layer=dec_layer,
                num_layers=self.args.num_decoder_layers,
                norm=LayerNorm(self.args.d_model),
                rope_base=self.args.rope_base,
                d_model=self.args.d_model,
                nhead=self.args.nhead,
                args=self.args
            )  # NOTE: for this one I use the torch.nn native implementation, as it's not implemented in .modules
        else:
            dec_layer = TransformerEncoderLayer(
                self.args.d_model,
                self.args.nhead,
                dim_feedforward=self.args.d_model * 4,
                dropout=self.args.trm_dropout,
                batch_first=True,
                norm_first=True,
                layer_norm_cls=LayerNorm
            )
            self.decoder = TransformerEncoder(
                dec_layer,
                num_layers=self.args.num_decoder_layers,
                norm=LayerNorm(self.args.d_model),
            )

        if type(self.args.audio_vocab_size) == int:
            self.predict_layer = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Linear(self.args.d_model, self.args.audio_vocab_size // 2),
                        nn.GELU(),
                        nn.Linear(self.args.audio_vocab_size // 2, self.n_audio_tokens[k])
                    ) for k in range(self.args.n_codebooks)
                ]
            )
        else:
            self.predict_layer = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Linear(self.args.d_model, self.args.d_model // 2),
                        nn.GELU(),
                        nn.Linear(self.args.d_model // 2, self.n_audio_tokens[k])
                    ) for k in range(self.args.n_codebooks)
                ]
            )

        self.accuracy_metrics = nn.ModuleList(
            [
                MulticlassAccuracy(
                    self.n_audio_tokens[k],
                    top_k=10,
                    average="micro",
                    multidim_average="global",
                    ignore_index=None,
                ) for k in range(self.args.n_codebooks)
            ]
        )

        if self.args.eog_weight != 1:
            raise NotImplementedError("now have different vocab_size for different codebooks, therefore currently don't support eog_weight")
            self.class_weight = nn.Parameter(torch.ones(self.n_audio_tokens), requires_grad=False)
            self.class_weight.data[self.args.eog] = self.args.eog_weight

    def dec_forward(
        self,
        x_input,
        x_lens,
        x_attention_mask,
        x_padding_mask,
        y_input,
        new_y_lens,
        y_attention_mask,
        y_padding_mask,
        need_weights=False,
        past=None,
        last_3_tokens=False
    ):
        x_attn_mask = F.pad(
            x_attention_mask,
            (0, new_y_lens.max()),
            value=True,
        )  # x attn to all x, doesn't attn to any y, this follows figure 3 of the VALL-E paper
        y_attn_mask = F.pad(
            y_attention_mask,
            (x_lens.max(), 0),  # y is padded at the front
            value=False,
        )  # y attn to all x, for y itself use lower triangle mask to ensure autoregressive
        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)

        # merge key padding and attention masks
        bsz, src_len = x_input.shape[0], x_lens.max() + new_y_lens.max()
        xy_padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1)
        _xy_padding_mask = (
            xy_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, self.args.nhead, -1, -1)
            .reshape(bsz * self.args.nhead, 1, src_len)
        )
        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)

        new_attn_mask = torch.zeros_like(xy_attn_mask)
        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
        xy_attn_mask = new_attn_mask

        xy_input = torch.cat([x_input, y_input], dim=1)

        if need_weights:
            raise NotImplementedError("not implemented yet")
            out, layer_attn_weights = self.decoder((xy_input, None), mask=xy_attn_mask, need_weights=True)
            return layer_attn_weights

        if past == None:  # do not use kvcache
            out, _ = self.decoder((xy_input, None), mask=xy_attn_mask)
            return out[:, x_lens.max():], None
        else:  # use kvcache
            if past.ndim > 3:  # uses kvcache, only need to pass the last tokens, this doesn't work with multi-span speech editing yet
                if last_3_tokens:
                    xy_input = xy_input[:, -3:]
                    xy_attn_mask = xy_attn_mask[:, -3:]
                else:
                    xy_input = xy_input[:, -1:]
                    xy_attn_mask = xy_attn_mask[:, -1:]

            out, present = self.decoder((xy_input, None), mask=xy_attn_mask, past=past)
            if isinstance(out, tuple):  # get rid of stage_embedding
                out = out[0]

            if out.shape[1] > x_lens.max():  # the first pass, not kvcache yet
                return out[:, x_lens.max():], present
            else:  # used kvcache
                return out, present

    def enc_dec_forward(
        self,
        xa,
        x_attention_mask,
        x_padding_mask,
        y_input,
        new_y_lens,
        y_attention_mask,
        y_padding_mask,
        tgt_y_lens=None,
        need_weights=False,
        past=None,
        last_3_tokens=False
    ):
        assert not need_weights
        if past != None and past.ndim > 3:
            y_input = y_input[:, -1:]
            y_attention_mask = y_attention_mask[-1:]
        yhat, present = self.decoder(
            tgt=y_input,
            memory=xa,
            tgt_mask=y_attention_mask,
            tgt_key_padding_mask=y_padding_mask,
            memory_key_padding_mask=x_padding_mask,
            query_lens=tgt_y_lens,
            past=past
        )
        return yhat, present

    def forward(self, batch, calc_loss=False):
        """
        Args:
            x:
                A 2-D tensor of shape (N, S).
            x_lens:
                A 1-D tensor of shape (N,). It contains the number of tokens in `x` before padding.
            y:
                A 3-D tensor of shape (N, K, T), where K is the number of codebooks.
            y_lens:
                A 1-D tensor of shape (N,). It contains the number of tokens in `y` before padding.
""" x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"] if len(x) == 0: return None x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x y = y[...,:y_lens.max()] assert x.ndim == 2, x.shape assert x_lens.ndim == 1, x_lens.shape assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape assert y_lens.ndim == 1, y_lens.shape x_padding_mask = make_pad_mask(x_lens).to(x.device) x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x_padding_mask.device) x_input = self.text_embedding(x) x_input = self.text_positional_embedding(x_input, x_lens) y_with_eos = [torch.cat([item[:, :y_lens[i]], self.eos], dim=-1) for i, item in enumerate(y)] targets = y_with_eos # apply delayed stacking on y shifted_y = [] patterns = [] new_y_lens = [] if getattr(self, "empty_tokens", None) == None: self.empty_tokens = torch.full((self.args.n_codebooks, self.args.n_codebooks), self.args.empty_token, dtype=torch.long).to(y.device) # [K, K] for i in range(len(y)): tmp = torch.cat([y_with_eos[i], self.empty_tokens], dim=-1) # [K, T+n_codebooks] for ii in range(self.args.n_codebooks): tmp[ii] = torch.roll(tmp[ii], shifts=ii+1, dims=0) shifted_y.append(tmp.transpose(1,0)) # [K, T+n_codebooks] -> [T+n_codebooks, K] new_y_lens.append(y_with_eos[i].shape[1] + self.empty_tokens.shape[1]) new_y_lens = torch.LongTensor(new_y_lens).to(y.device) cated_y = torch.nn.utils.rnn.pad_sequence(shifted_y, batch_first=False, padding_value=self.args.audio_pad_token) assert cated_y.shape == torch.Size([max(new_y_lens), len(y), self.args.n_codebooks]), cated_y.shape cated_y = cated_y.permute(2,0,1) # [T,B,K]->[K,T,B] stacked_embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, T, B, D] assert stacked_embedded_y.shape[0] == self.args.n_codebooks and stacked_embedded_y.shape[2] == len(y) and stacked_embedded_y.shape[-1] == self.args.d_model, stacked_embedded_y.shape embedded_y = stacked_embedded_y.sum(dim=0) # [K,T,B,D]->[T,B,D] embedded_y = embedded_y.transpose(1,0) # [T,B,D]->[B,T,D] assert embedded_y.shape[1:] == torch.Size([max(new_y_lens), self.args.d_model]), embedded_y.shape y_input = self.audio_positional_embedding(embedded_y, new_y_lens) y_padding_mask = make_pad_mask(new_y_lens).to(y.device) y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y_padding_mask.device) if self.args.dec: y_out = self.dec_forward( x_input, x_lens, x_attention_mask, x_padding_mask, y_input, new_y_lens, y_attention_mask, y_padding_mask ) else: xa = self.encoder(src=x_input, src_key_padding_mask=x_padding_mask) y_out = self.enc_dec_forward( xa, x_attention_mask, x_padding_mask, y_input, new_y_lens, y_attention_mask, y_padding_mask ) y_out = y_out[0] # no kv-caching during training assert y_out.shape == y_input.shape, f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}" # [B S D] logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card] assert logits.shape[1] == self.args.n_codebooks and logits.shape[3] == self.n_audio_tokens[0], logits.shape logits_use = [logit[:, :new_y_lens[i]] for i, logit in enumerate(logits)] # each of shape [K, T, card] logits_final = [] for i, logit in enumerate(logits_use): logit_copy = logit.clone() for ii in range(self.args.n_codebooks): logit_copy[ii] = torch.roll(logit_copy[ii], shifts=-ii, dims=0) 
            logit = logit_copy[:, :-self.args.n_codebooks]  # [K, T, card] -> [K, T-n_codebooks, card]
            logits_final.append(logit)

        if self.args.no_loss_on_prefix:
            assert "y_sep_token_position" in batch, f"y_sep_token_position should be in batch, but it's not"
            logit_temp = []
            target_temp = []
            for jj, (logit, target) in enumerate(zip(logits_final, targets)):  # TODO already taken into consideration in depth transformer
                logit_temp.append(logit[:, batch['y_sep_token_position'][jj]:])
                target_temp.append(target[:, batch['y_sep_token_position'][jj]:])
            logits_final = logit_temp
            targets = target_temp
        logits = torch.cat(logits_final, dim=1)  # [K, T1+T2+T3+..., card]
        targets = torch.cat(targets, dim=1)  # [K, T1+T2+T3+...]
        assert targets.shape[:2] == logits.shape[:2], f"{targets.shape}, {logits.shape}"

        loss = []
        ntokens = []
        top10acc = []
        for k, (logit, target) in enumerate(zip(logits, targets)):
            # even though the loss and top10acc are calculated in a loop (looping through n_codebooks),
            # validation still takes a lot of memory, need to optimize this a little more
            loss.append(F.cross_entropy(
                logit, target,
                reduction='mean',
                weight=self.class_weight.data if self.args.eog_weight != 1 else None,
                ignore_index=self.args.y_sep_token if self.args.y_sep_token != None else -100
            ))  # ignore audio sep token as it's unpredictable (like the random early stop bug that happened in 2023)
            # NOTE have to ignore the sep token in the loss calculation
            top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
            ntokens.append(len(logit))

        all_ntokens = sum(ntokens)
        if self.args.codebook_weight != None:
            codebook_weight = eval(self.args.codebook_weight) if isinstance(self.args.codebook_weight, str) else self.args.codebook_weight
        else:
            codebook_weight = [1.] * self.args.n_codebooks
        perplexity_by_codebook = [torch.exp(l) for l in loss]
        loss = sum([l * nt * cw for l, nt, cw in zip(loss, ntokens, codebook_weight)])

        top10acc_by_codebook = [t10a * nt for t10a, nt in zip(top10acc, ntokens)]
        top10acc = sum(top10acc_by_codebook)

        ntokens = torch.tensor(all_ntokens).to(logits.device)

        ret = {
            "loss": loss,
            "perplexity_by_codebook": perplexity_by_codebook,
            "top10acc": top10acc,
            "top10acc_by_codebook": top10acc_by_codebook,
            "effective_ntoken": ntokens,
        }

        return ret

    def inference_tts(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        tgt_y_lens: torch.Tensor,  # new arg: target length of y
        top_k: Union[int, list[int]] = -100,
        top_p: float = 1.0,
        min_p: float = 1.0,
        temperature: float = 1.0,
        stop_repetition: int = 3,
        kvcache: int = 1,
        silence_tokens: list[int] = [],
        multi_trial: list[int] = [],
        *kargs
    ) -> torch.Tensor:
        """
        Different from the non-kvcache inference, this implementation uses kvcache, which should give a significant speed up.
        Args:
            x:
                A 2-D tensor of shape (1, L).
            x_lens:
                A 1-D tensor of shape (1,). It contains the number of tokens in `x` before padding.
            y:
                A 3-D tensor of shape (1, T, K).
            tgt_y_lens:
                *new arg* this specifies the target length of y
            top_k: (`optional`) int
                The number of highest probability tokens to keep for top-k filtering. Default to -100.
            top_p: (`optional`) float
                For nucleus sampling
            min_p: (`optional`) float
                For min_p filtered sampling
            temperature: (`optional`) float
                The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
            multi_trial: (`optional`) list[int]
                If not empty, it will be [n_trials, beam_size, trial_interval]. At the beginning of every trial_interval
                we duplicate the current sample beam_size times, and at the end of every trial_interval we keep the
                sample with the highest log likelihood and throw away the rest.
        """
        eog_inference = self.args.eos if self.args.eos > 0 else self.args.eog
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        if self.args.special_first:
            y = y + int(self.args.n_special)
        y = y.transpose(2, 1)  # [1,T,K] -> [1,K,T]
        assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape  # there is no padding

        # make x attention mask and x_input
        x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device)
        # x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
        x_input = self.text_embedding(x)
        x_input = self.text_positional_embedding(x_input, x_lens)

        y_len = y.shape[2]
        y_lens = torch.LongTensor([y_len]).to(y.device)

        # rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario
        rearranged_y = [[y[0]]]
        assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape

        # shift y to create the delayed pattern
        if getattr(self, "empty_tokens", None) == None:
            self.empty_tokens = torch.full((self.args.n_codebooks, self.args.n_codebooks), self.args.empty_token, dtype=torch.long).to(y.device)  # [K, K]
        temp = rearranged_y[0][0]
        assert temp.ndim == 2 and temp.shape[0] == self.args.n_codebooks, temp.shape
        temp = torch.cat([temp, self.empty_tokens], dim=-1)  # [K, T+n_codebooks]
        for ii in range(self.args.n_codebooks):
            temp[ii] = torch.roll(temp[ii], shifts=ii+1, dims=0)
        shifted_y = [[temp]]

        # below is different from forward or inference
        # where we cut this shifted part
        shifted_y[0][0] = shifted_y[0][0][:, :-(self.args.n_codebooks-1)]
        assert not (shifted_y[0][0][self.args.n_codebooks:] == self.args.empty_token).any() and not (shifted_y[0][0][self.args.n_codebooks:] == self.args.eog).any(), shifted_y[0][0]

        # next section in inference is to insert a mask at the intersection of each tensor in a sample, but we don't need to do that
        # next section is to concatenate tensors of each sample into one tensor, which we also don't need
        cated_y = shifted_y[0][0].unsqueeze(-1)  # [K,S] -> [K,S,B]
        new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device)
        assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1))
        assert not (cated_y == self.args.audio_pad_token).any(), cated_y

        # replace tokens in y with the embeddings, and sum codebooks up
        embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0)  # [K, S, B, D]
        assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
        assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
        embedded_y = embedded_y.sum(dim=0)  # [K,S,B,D] -> [S,B,D]
        embedded_y = embedded_y.transpose(1, 0)  # [S,B,D] -> [B,S,D]

        # positional embedding
        y_input = self.audio_positional_embedding(embedded_y, tgt_y_lens)

        # make attention mask and padding mask
        y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)

        x_padding_mask = torch.full((1, x_lens[0]), False).to(x.device)
        y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device)

        # entering the generation stage
        # starting from line 708
        codebook_eog = [False] * self.args.n_codebooks
        generated = []  # doesn't contain any empty token, contains eog
        cur_generated = []
        # say 0 is empty, 4 is eog
        # tensor([[ 1,  2,  3,  4,  0,  0],
        #         [ 0,  1,  2,  3,  4,  0],
        #         [ 0,  0,  1,  2,  3,  4]])
        num_gen = []
        cur_num_gen = 0
        ##################### silence repetition handling #####################
        ##################### silence repetition handling #####################
        # silence_tokens = [1388,1898,131] # [1388, 2045, 2041, 1996]
        # silence_tokens = []
        consec_silence_count = 0
        prev_token = None
        ##################### silence repetition handling #####################
        ##################### silence repetition handling #####################

        def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, min_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen):
            if n_eog == 0:
                logits_adjust = logits
                for jj in range(1, self.args.n_codebooks):
                    logits_adjust[jj][eog_inference] = -10000
                    logits_adjust[jj][self.args.empty_token] = -10000
                if cur_num_gen <= self.args.encodec_sr // 5:  # this shouldn't happen, but just in case the model stopped too early
                    logits_adjust[0][eog_inference] = -10000
                ##################### silence repetition handling #####################
                if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
                    if logits_adjust[0, prev_token] < 0:
                        logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (consec_silence_count - (stop_repetition - 1))
                    else:
                        logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (consec_silence_count - (stop_repetition - 1))
                ##################### silence repetition handling #####################
                samples = topk_sampling(
                    logits_adjust, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature
                )  # [K, 1]
                assert samples.shape == torch.Size((self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}"
                if cur_num_gen < self.args.n_codebooks - 1:
                    for jj in range(1, self.args.n_codebooks - cur_num_gen):
                        samples[-jj, 0] = self.args.empty_token

                if (
                    samples[0, 0] == eog_inference or torch.argmax(logits[0], dim=-1) == eog_inference or y_input.shape[1] > x_lens[0] * (self.args.encodec_sr // 4)
                ) or self.args.rope_base is not None and not self.args.decoder_regular_rope and self.args.progress_no_multiple and cur_num_gen > (tgt_y_lens[0] + self.args.encodec_sr * getattr(self.args, "extra_cutoff", 5)):
                    # last one condition in the first bracket means y is already too long, shouldn't happen, but put it here
                    # the second bracket means we are using progress-monitoring RoPE, but the model is generating excessively long sequence (5 seconds more than specified), in which case we terminate the generation
                    samples[0, 0] = eog_inference
                    codebook_eog[0] = True
                ##################### silence repetition handling #####################
                if samples[0, 0] in silence_tokens and samples[0, 0] == prev_token:
                    consec_silence_count += 1
                else:
                    consec_silence_count = 0
                prev_token = samples[0, 0]
                ##################### silence repetition handling #####################
                return samples, codebook_eog, prev_token, consec_silence_count
            else:
                assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
                logits_adjust = logits
                for jj in range(n_eog + 1, self.args.n_codebooks):
                    logits_adjust[jj][eog_inference] = -10000
                    logits_adjust[jj][self.args.empty_token] = -10000
                samples = topk_sampling(
                    logits_adjust, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature
                )  # [K, 1]
                for jj in range(n_eog):
                    samples[jj, 0] = self.args.empty_token
                samples[n_eog, 0] = eog_inference
                codebook_eog[n_eog] = True
                return samples, codebook_eog, prev_token, consec_silence_count
        # prepare the cache placeholder
        # n_layers, 2, bsz, num_heads, src_len, head_dim, 2 means [key, value]
        past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
        if self.args.enc_dec:
            xa = self.encoder(src=x_input, src_key_padding_mask=x_padding_mask)
        while True:
            if self.args.dec:
                y_out, present = self.dec_forward(
                    x_input,
                    x_lens,
                    x_attention_mask,
                    x_padding_mask,
                    y_input,
                    new_y_lens,
                    y_attention_mask,
                    y_padding_mask,
                    past=past
                )
            else:
                y_out, present = self.enc_dec_forward(
                    xa,
                    x_attention_mask,
                    x_padding_mask,
                    y_input,
                    new_y_lens,
                    y_attention_mask,
                    y_padding_mask,
                    tgt_y_lens=tgt_y_lens,
                    past=past
                )
            if past != None:
                past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype)

            y_out = y_out[:, -1:]  # only take the last token
            logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1)  # [B K S card], B==S==1, so [1 K 1 card]
            logits = logits.squeeze(0).squeeze(1)  # [K card]
            assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"

            n_eog = sum(codebook_eog)
            assert n_eog < self.args.n_codebooks
            if self.args.eos > 0:  # if we are using end-of-sentence token (which is used by default), eog shouldn't be used here, as there is no masked spans
                for jj in range(self.args.n_codebooks):
                    logits[jj][self.args.eog] = -10000.

            samples, codebook_eog, prev_token, consec_silence_count = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, min_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen)
            # samples.shape is [K,1]
            # get samples_emb
            samples_emb = torch.stack([self.audio_embedding[k](samples[k]) for k in range(self.args.n_codebooks)], dim=0)  # [K,1,D]
            samples_emb = samples_emb.sum(dim=0, keepdim=True)  # [1,1,D]

            cur_num_gen += 1
            cur_generated.append(samples.squeeze(-1))  # [K,1] -> [K]

            if sum(codebook_eog) == self.args.n_codebooks:  # generation for the current span is done
                codebook_eog = [False] * self.args.n_codebooks
                num_gen.append(cur_num_gen)
                cur_num_gen = 0
                generated.append(cur_generated)
                cur_generated = []
                break
            else:
                assert samples_emb.shape == torch.Size((1, 1, self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"

            embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
            new_y_lens = torch.LongTensor([embedded_y.shape[1]]).to(y.device)
            y_input = self.audio_positional_embedding(embedded_y, tgt_y_lens)  # [B T D]
            # make attention mask and padding mask
            y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
            y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device)

        assert len(generated) == 1, f"len(generated): {len(generated)}"

        # revert the pattern
        flatten_gen = []
        for l, orig_span in enumerate(generated):
            span = torch.stack(orig_span, dim=0)  # [T, K]
            span = span.transpose(1, 0)  # [K, T]
            assert span.shape[0] == self.args.n_codebooks, span.shape
            unshifted_span = []
            for j, s in enumerate(span):
                start_from = j
                end_at = - (self.args.n_codebooks - start_from)
                unshifted_span.append(s[start_from:end_at])
            unshifted_span = torch.stack(unshifted_span, dim=0)

            assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
            flatten_gen.append(unshifted_span)
        assert len(flatten_gen) == 1, len(flatten_gen)

        # combine
        res = [y[0], flatten_gen[0]]
        res = torch.cat(res, dim=1).unsqueeze(0)  # [K, new_t] -> [1, K, new_T]
        expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen])
        assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}"

        if self.args.special_first:
            res = res - int(self.args.n_special)
            flatten_gen = [item - int(self.args.n_special) for item in flatten_gen]  # flatten_gen is a list of tensors, so shift each tensor

        return res, flatten_gen[0].unsqueeze(0)
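

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original model): the "delayed stacking"
# pattern that `forward` and `inference_tts` build with torch.roll over the
# codebook dimension, and how it is undone. `_demo_delay_pattern` is a
# hypothetical helper for illustration only; nothing in the model calls it.
# ---------------------------------------------------------------------------
def _demo_delay_pattern(n_codebooks: int = 3, n_frames: int = 4, empty_token: int = 0):
    # a toy [K, T] code matrix; frame t of codebook k holds the value 10*(k+1)+t
    y = torch.stack([torch.arange(n_frames) + 10 * (k + 1) for k in range(n_codebooks)], dim=0)

    # apply the delay: append K "empty" columns, then roll row k right by k+1
    # (as in `forward`), so each codebook lags the previous one by one frame
    tmp = torch.cat([y, torch.full((n_codebooks, n_codebooks), empty_token, dtype=y.dtype)], dim=-1)
    for k in range(n_codebooks):
        tmp[k] = torch.roll(tmp[k], shifts=k + 1, dims=0)

    # undo the delay: roll each row back and drop the K padding frames
    undone = tmp.clone()
    for k in range(n_codebooks):
        undone[k] = torch.roll(undone[k], shifts=-(k + 1), dims=0)
    undone = undone[:, :n_frames]

    assert torch.equal(undone, y)  # the round trip recovers the original codes
    return tmp  # the delayed [K, T + K] matrix, with `empty_token` in the shifted-in slots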