""" |
|
We just merge all the required modules and functions into one python file. |
|
It is for easily use the pre-trained model to extract features. |
|
""" |
|
import math |
|
import numpy as np |
|
import logging |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn import Parameter |
|
from torch import Tensor |
|
from typing import Any, Dict, List, Tuple, Callable, Optional |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def module_name_fordropout(module_name: str) -> str: |
|
if module_name == "TransformerEncoderBase": |
|
return "TransformerEncoder" |
|
else: |
|
return module_name |
|
|
|
def utils_make_positions(tensor, padding_idx: int, onnx_trace: bool = False): |
|
"""Replace non-padding symbols with their position numbers. |
|
|
|
Position numbers begin at padding_idx+1. Padding symbols are ignored. |
|
""" |
|

    # Non-padding positions are numbered cumulatively starting at padding_idx + 1,
    # while padding positions all map back to padding_idx.
|
mask = tensor.ne(padding_idx).int() |
|
return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx |
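
# Worked example (illustrative, not from the original code): with padding_idx = 1 and
#   tensor = torch.tensor([[5, 6, 1], [7, 1, 1]])
# the non-padding mask is [[1, 1, 0], [1, 0, 0]], so the returned positions are
#   [[2, 3, 1], [2, 1, 1]]
# i.e. real tokens are numbered from padding_idx + 1 and padding stays at padding_idx.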
|
|
|
def utils_item(tensor): |
|
|
|
if torch.is_tensor(tensor) and tensor.device.type == "xla": |
|
return tensor.detach() |
|
if hasattr(tensor, "item"): |
|
return tensor.item() |
|
if hasattr(tensor, "__getitem__"): |
|
return tensor[0] |
|
return tensor |
|
|
|
def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs): |
|
""" |
|
Helper to wrap layers/modules in FSDP. This falls back to a no-op if |
|
fairscale is not available. |
|
|
|
Args: |
|
module (nn.Module): module to (maybe) wrap |
|
min_num_params (int, Optional): minimum number of layer params to wrap |
|
""" |
|
try: |
|
from fairscale.nn import wrap |
|
|
|
if min_num_params is not None: |
|
num_params = sum(p.numel() for p in module.parameters()) |
|
if num_params >= min_num_params: |
|
return wrap(module, **kwargs) |
|
else: |
|
return module |
|
else: |
|
return wrap(module, **kwargs) |
|
except ImportError: |
|
return module |
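
# Usage sketch (hedged): fairscale's wrap() only takes effect inside an
# enable_wrap(...) context that the surrounding training code is assumed to set up;
# outside of it, or without fairscale installed, the call returns the module unchanged.
#
#   layer = fsdp_wrap(some_layer, min_num_params=int(1e8))   # some_layer is hypothetical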
|
|
|
def quant_noise(module, p, block_size): |
|
""" |
|
Wraps modules and applies quantization noise to the weights for |
|
subsequent quantization with Iterative Product Quantization as |
|
described in "Training with Quantization Noise for Extreme Model Compression" |
|
|
|
Args: |
|
- module: nn.Module |
|
- p: amount of Quantization Noise |
|
- block_size: size of the blocks for subsequent quantization with iPQ |
|
|
|
Remarks: |
|
- Module weights must have the right sizes wrt the block size |
|
- Only Linear, Embedding and Conv2d modules are supported for the moment |
|
- For more detail on how to quantize by blocks with convolutional weights, |
|
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" |
|
- We implement the simplest form of noise here as stated in the paper |
|
which consists in randomly dropping blocks |
|
""" |
|
|
|
|
|
if p <= 0: |
|
return module |
|
|
|
|
|
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) |
|
|
|
|
|
is_conv = module.weight.ndim == 4 |
|
|
|
|
|
if not is_conv: |
|
assert ( |
|
module.weight.size(1) % block_size == 0 |
|
), "Input features must be a multiple of block sizes" |
|
|
|
|
|
else: |
|
|
|
if module.kernel_size == (1, 1): |
|
assert ( |
|
module.in_channels % block_size == 0 |
|
), "Input channels must be a multiple of block sizes" |
|
|
|
else: |
|
k = module.kernel_size[0] * module.kernel_size[1] |
|
assert k % block_size == 0, "Kernel size must be a multiple of block size" |
|
|
|
def _forward_pre_hook(mod, input): |
|
|
|
if mod.training: |
|
if not is_conv: |
|
|
|
weight = mod.weight |
|
in_features = weight.size(1) |
|
out_features = weight.size(0) |
|
|
|
|
|
mask = torch.zeros( |
|
in_features // block_size * out_features, device=weight.device |
|
) |
|
mask.bernoulli_(p) |
|
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) |
|
|
|
else: |
|
|
|
weight = mod.weight |
|
in_channels = mod.in_channels |
|
out_channels = mod.out_channels |
|
|
|
|
|
if mod.kernel_size == (1, 1): |
|
mask = torch.zeros( |
|
int(in_channels // block_size * out_channels), |
|
device=weight.device, |
|
) |
|
mask.bernoulli_(p) |
|
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) |
|
else: |
|
mask = torch.zeros( |
|
weight.size(0), weight.size(1), device=weight.device |
|
) |
|
mask.bernoulli_(p) |
|
mask = ( |
|
mask.unsqueeze(2) |
|
.unsqueeze(3) |
|
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) |
|
) |
|
|
|
|
|
            mask = mask.to(torch.bool)
|
s = 1 / (1 - p) |
|
mod.weight.data = s * weight.masked_fill(mask, 0) |
|
|
|
module.register_forward_pre_hook(_forward_pre_hook) |
|
return module |
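
# Usage sketch (hedged; hyperparameters are illustrative): apply 10% quantization
# noise in 8-wide blocks to a linear layer. The registered forward pre-hook only
# perturbs the weights while the module is in training mode.
#
#   proj = quant_noise(nn.Linear(512, 512, bias=False), p=0.1, block_size=8)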
|
|
|
def relu_squared(x: torch.Tensor): |
|
return F.relu(x).pow(2) |
|
|
|
def gelu(x: torch.Tensor) -> torch.Tensor: |
|
return torch.nn.functional.gelu(x.float()).type_as(x) |
|
|
|
def gelu_accurate(x): |
|
if not hasattr(gelu_accurate, "_a"): |
|
gelu_accurate._a = math.sqrt(2 / math.pi) |
|
return ( |
|
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) |
|
) |
|
|
|
def get_activation_fn(activation: str) -> Callable: |
|
"""Returns the activation function corresponding to `activation`""" |
|
if activation == "relu": |
|
return F.relu |
|
elif activation == "relu_squared": |
|
return relu_squared |
|
elif activation == "gelu": |
|
return gelu |
|
elif activation == "gelu_fast": |
|
        logger.warning(
|
"--activation-fn=gelu_fast has been renamed to gelu_accurate" |
|
) |
|
return gelu_accurate |
|
elif activation == "gelu_accurate": |
|
return gelu_accurate |
|
elif activation == "tanh": |
|
return torch.tanh |
|
elif activation == "linear": |
|
return lambda x: x |
|
elif activation == "swish": |
|
        return torch.nn.SiLU()  # instantiate so the returned activation is directly callable on tensors
|
else: |
|
raise RuntimeError("--activation-fn {} not supported".format(activation)) |
|
|
|
def softmax(x, dim: int, onnx_trace: bool = False): |
|
if onnx_trace: |
|
return F.softmax(x.float(), dim=dim) |
|
else: |
|
return F.softmax(x, dim=dim, dtype=torch.float32) |
|
|
|
def compute_mask_indices( |
|
shape: Tuple[int, int], |
|
padding_mask: Optional[torch.Tensor], |
|
mask_prob: float, |
|
mask_length: int, |
|
mask_type: str = "static", |
|
mask_other: float = 0.0, |
|
min_masks: int = 0, |
|
no_overlap: bool = False, |
|
min_space: int = 0, |
|
require_same_masks: bool = True, |
|
mask_dropout: float = 0.0, |
|
) -> np.ndarray: |
|
""" |
|
Computes random mask spans for a given shape |
|
|
|
Args: |
|
        shape: the shape for which to compute masks.
|
should be of size 2 where first element is batch size and 2nd is timesteps |
|
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements |
|
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by |
|
number of timesteps divided by length of mask span to mask approximately this percentage of all elements. |
|
however due to overlaps, the actual number will be smaller (unless no_overlap is True) |
|
mask_type: how to compute mask lengths |
|
static = fixed size |
|
uniform = sample from uniform distribution [mask_other, mask_length*2] |
|
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element |
|
            poisson = sample from a Poisson distribution with lambda = mask_length
|
min_masks: minimum number of masked spans |
|
        no_overlap: if true, switches to an alternative recursive algorithm that prevents spans from overlapping
|
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans |
|
        require_same_masks: if true, will randomly drop out masks until the same number of masks remains in each sample
|
mask_dropout: randomly dropout this percentage of masks in each example |
|
""" |
|
|
|
bsz, all_sz = shape |
|
mask = np.full((bsz, all_sz), False) |
|
|
|
all_num_mask = int( |
|
|
|
mask_prob * all_sz / float(mask_length) |
|
+ np.random.rand() |
|
) |
|
|
|
all_num_mask = max(min_masks, all_num_mask) |
|
|
|
mask_idcs = [] |
|
for i in range(bsz): |
|
if padding_mask is not None: |
|
sz = all_sz - padding_mask[i].long().sum().item() |
|
num_mask = int( |
|
|
|
mask_prob * sz / float(mask_length) |
|
+ np.random.rand() |
|
) |
|
num_mask = max(min_masks, num_mask) |
|
else: |
|
sz = all_sz |
|
num_mask = all_num_mask |
|
|
|
if mask_type == "static": |
|
lengths = np.full(num_mask, mask_length) |
|
elif mask_type == "uniform": |
|
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) |
|
elif mask_type == "normal": |
|
lengths = np.random.normal(mask_length, mask_other, size=num_mask) |
|
lengths = [max(1, int(round(x))) for x in lengths] |
|
elif mask_type == "poisson": |
|
lengths = np.random.poisson(mask_length, size=num_mask) |
|
lengths = [int(round(x)) for x in lengths] |
|
else: |
|
raise Exception("unknown mask selection " + mask_type) |
|
|
|
if sum(lengths) == 0: |
|
lengths[0] = min(mask_length, sz - 1) |
|
|
|
if no_overlap: |
|
mask_idc = [] |
|
|
|
def arrange(s, e, length, keep_length): |
|
span_start = np.random.randint(s, e - length) |
|
mask_idc.extend(span_start + i for i in range(length)) |
|
|
|
new_parts = [] |
|
if span_start - s - min_space >= keep_length: |
|
new_parts.append((s, span_start - min_space + 1)) |
|
if e - span_start - keep_length - min_space > keep_length: |
|
new_parts.append((span_start + length + min_space, e)) |
|
return new_parts |
|
|
|
parts = [(0, sz)] |
|
min_length = min(lengths) |
|
for length in sorted(lengths, reverse=True): |
|
lens = np.fromiter( |
|
(e - s if e - s >= length + min_space else 0 for s, e in parts), |
|
                    int,
|
) |
|
l_sum = np.sum(lens) |
|
if l_sum == 0: |
|
break |
|
probs = lens / np.sum(lens) |
|
c = np.random.choice(len(parts), p=probs) |
|
s, e = parts.pop(c) |
|
parts.extend(arrange(s, e, length, min_length)) |
|
mask_idc = np.asarray(mask_idc) |
|
else: |
|
min_len = min(lengths) |
|
if sz - min_len <= num_mask: |
|
min_len = sz - num_mask - 1 |
|
|
|
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) |
|
|
|
mask_idc = np.asarray( |
|
[ |
|
mask_idc[j] + offset |
|
for j in range(len(mask_idc)) |
|
for offset in range(lengths[j]) |
|
] |
|
) |
|
|
|
mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) |
|
|
|
min_len = min([len(m) for m in mask_idcs]) |
|
for i, mask_idc in enumerate(mask_idcs): |
|
if len(mask_idc) > min_len and require_same_masks: |
|
mask_idc = np.random.choice(mask_idc, min_len, replace=False) |
|
if mask_dropout > 0: |
|
num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int) |
|
mask_idc = np.random.choice( |
|
mask_idc, len(mask_idc) - num_holes, replace=False |
|
) |
|
|
|
mask[i, mask_idc] = True |
|
|
|
return mask |
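
# Usage sketch (hedged; numbers follow common wav2vec 2.0-style settings): mask roughly
# 65% of the timesteps of a (batch=2, frames=100) input using spans of length 10.
# The result is a boolean numpy array of shape (2, 100).
#
#   mask = compute_mask_indices((2, 100), padding_mask=None, mask_prob=0.65, mask_length=10)
#   mask_t = torch.from_numpy(mask)   # converted to a tensor before indexing model features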
|
|
|
def init_bert_params(module): |
|
""" |
|
Initialize the weights specific to the BERT Model. |
|
This overrides the default initializations depending on the specified arguments. |
|
1. If normal_init_linear_weights is set then weights of linear |
|
layer will be initialized using the normal distribution and |
|
       bias will be set to the specified value.
|
2. If normal_init_embed_weights is set then weights of embedding |
|
layer will be initialized using the normal distribution. |
|
3. If normal_init_proj_weights is set then weights of |
|
in_project_weight for MultiHeadAttention initialized using |
|
the normal distribution (to be validated). |
|
""" |
|
|
|
def normal_(data): |
|
|
|
|
|
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) |
|
|
|
if isinstance(module, nn.Linear): |
|
normal_(module.weight.data) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
if isinstance(module, nn.Embedding): |
|
normal_(module.weight.data) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
if isinstance(module, MultiheadAttention): |
|
normal_(module.q_proj.weight.data) |
|
normal_(module.k_proj.weight.data) |
|
normal_(module.v_proj.weight.data) |
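
# Usage sketch: intended to be applied recursively over a model, e.g.
#   model.apply(init_bert_params)
# as done in TransformerEncoder.__init__ further below.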
|
|
|
def pad_to_multiple(x, multiple, dim=-1, value=0): |
|
|
|
if x is None: |
|
return None, 0 |
|
tsz = x.size(dim) |
|
m = tsz / multiple |
|
remainder = math.ceil(m) * multiple - tsz |
|
if m.is_integer(): |
|
return x, 0 |
|
pad_offset = (0,) * (-1 - dim) * 2 |
|
|
|
return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder |
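
# Worked example (illustrative): padding a (B, T, C) tensor with T = 10 to a multiple
# of 8 along dim=-2 returns a tensor with T = 16 and remainder = 6; an input whose
# length is already a multiple of 8 is returned unchanged with remainder = 0.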
|
|
|
def is_xla_tensor(tensor): |
|
return torch.is_tensor(tensor) and tensor.device.type == "xla" |
|
|
|
def index_put(tensor, indices, value): |
|
if is_xla_tensor(tensor): |
|
for _ in range(indices.dim(), tensor.dim()): |
|
indices = indices.unsqueeze(-1) |
|
if indices.size(-1) < tensor.size(-1): |
|
indices = indices.expand_as(tensor) |
|
tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices) |
|
else: |
|
tensor[indices] = value |
|
return tensor |
|
|
|
def PositionalEmbedding( |
|
num_embeddings: int, |
|
embedding_dim: int, |
|
padding_idx: int, |
|
learned: bool = False, |
|
): |
|
if learned: |
|
|
|
|
|
|
|
|
|
if padding_idx is not None: |
|
num_embeddings = num_embeddings + padding_idx + 1 |
|
m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx) |
|
nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) |
|
if padding_idx is not None: |
|
nn.init.constant_(m.weight[padding_idx], 0) |
|
else: |
|
m = SinusoidalPositionalEmbedding( |
|
embedding_dim, |
|
padding_idx, |
|
init_size=num_embeddings + padding_idx + 1, |
|
) |
|
return m |
|
|
|
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): |
|
if torch.jit.is_scripting() or torch.jit.is_tracing(): |
|
export = True |
|
if not export and torch.cuda.is_available() and has_fused_layernorm: |
|
return FusedLayerNorm(normalized_shape, eps, elementwise_affine) |
|
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) |
|
|
|
|
|
class TransformerEncoderBase(nn.Module): |
|
""" |
|
Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer |
|
is a :class:`TransformerEncoderLayer`. |
|
|
|
Args: |
|
        cfg: encoder configuration (provides ``cfg.encoder.*``, dropout and quant-noise settings)

        dictionary: unused; kept for interface compatibility (pass ``None``)
|
embed_tokens (torch.nn.Embedding): input embedding |
|
""" |
|
|
|
def __init__(self, cfg, dictionary, embed_tokens, use_rel_pos_enc=False, scaling_for_att=1.0): |
|
        super().__init__()

        self.cfg = cfg
|
self.register_buffer("version", torch.Tensor([3])) |
|
|
|
self.dropout_module = FairseqDropout( |
|
cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__) |
|
) |
|
self.encoder_layerdrop = cfg.encoder.layerdrop |
|
|
|
embed_dim = embed_tokens.embedding_dim if embed_tokens is not None else cfg.encoder.embed_dim |
|
self.padding_idx = embed_tokens.padding_idx if embed_tokens is not None else 1 |
|
self.max_source_positions = cfg.max_source_positions |
|
|
|
self.embed_tokens = embed_tokens |
|
|
|
self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim) |
|
|
|
self.embed_positions = ( |
|
PositionalEmbedding( |
|
cfg.max_source_positions, |
|
embed_dim, |
|
self.padding_idx, |
|
learned=cfg.encoder.learned_pos, |
|
) |
|
if not cfg.no_token_positional_embeddings |
|
else None |
|
) |
|
if cfg.layernorm_embedding: |
|
self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export) |
|
else: |
|
self.layernorm_embedding = None |
|
|
|
if not cfg.adaptive_input and cfg.quant_noise.pq > 0: |
|
self.quant_noise = quant_noise( |
|
nn.Linear(embed_dim, embed_dim, bias=False), |
|
cfg.quant_noise.pq, |
|
cfg.quant_noise.pq_block_size, |
|
) |
|
else: |
|
self.quant_noise = None |
|
|
|
if self.encoder_layerdrop > 0.0: |
|
self.layers = LayerDropModuleList(p=self.encoder_layerdrop) |
|
else: |
|
self.layers = nn.ModuleList([]) |
|
self.use_rel_pos_enc = use_rel_pos_enc |
|
self.scaling_for_att = scaling_for_att |
|
self.layers.extend( |
|
[self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)] |
|
) |
|
self.num_layers = len(self.layers) |
|
|
|
if cfg.encoder.normalize_before: |
|
self.layer_norm = LayerNorm(embed_dim, export=cfg.export) |
|
else: |
|
self.layer_norm = None |
|
if self.use_rel_pos_enc: |
|
self.pos_emb = RelativePositionalEncoding(embed_dim // cfg.encoder.attention_heads, 160) |
|
|
|
def build_encoder_layer(self, cfg): |
|
layer = TransformerEncoderLayerBase(cfg, has_relative_attention_bias=self.use_rel_pos_enc, scaling_for_att=self.scaling_for_att) |
|
checkpoint = cfg.checkpoint_activations |
|
if checkpoint: |
|
raise ValueError("We don't support checkpoint_activations for now! Please set cfg.checkpoint_activations=False.") |
|
min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 |
|
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) |
|
return layer |
|
|
|
def forward_embedding( |
|
self, src_tokens, token_embedding: Optional[torch.Tensor] = None |
|
): |
|
|
|
if token_embedding is None: |
|
token_embedding = self.embed_tokens(src_tokens) |
|
x = embed = self.embed_scale * token_embedding |
|
if self.embed_positions is not None: |
|
x = embed + self.embed_positions(src_tokens) |
|
if self.layernorm_embedding is not None: |
|
x = self.layernorm_embedding(x) |
|
x = self.dropout_module(x) |
|
if self.quant_noise is not None: |
|
x = self.quant_noise(x) |
|
return x, embed |
|
|
|
def forward( |
|
self, |
|
src_tokens, |
|
src_lengths: Optional[torch.Tensor] = None, |
|
return_all_hiddens: bool = False, |
|
token_embeddings: Optional[torch.Tensor] = None, |
|
uniformity_layers: Optional[List[int]] = None, |
|
): |
|
""" |
|
Args: |
|
src_tokens (LongTensor): tokens in the source language of shape |
|
`(batch, src_len)` |
|
src_lengths (torch.LongTensor): lengths of each source sentence of |
|
shape `(batch)` |
|
return_all_hiddens (bool, optional): also return all of the |
|
intermediate hidden states (default: False). |
|
token_embeddings (torch.Tensor, optional): precomputed embeddings |
|
default `None` will recompute embeddings |
|
|
|
Returns: |
|
dict: |
|
- **encoder_out** (Tensor): the last encoder layer's output of |
|
shape `(src_len, batch, embed_dim)` |
|
- **encoder_padding_mask** (ByteTensor): the positions of |
|
padding elements of shape `(batch, src_len)` |
|
- **encoder_embedding** (Tensor): the (scaled) embedding lookup |
|
of shape `(batch, src_len, embed_dim)` |
|
- **encoder_states** (List[Tensor]): all intermediate |
|
hidden states of shape `(src_len, batch, embed_dim)`. |
|
Only populated if *return_all_hiddens* is True. |
|
""" |
|
return self.forward_scriptable( |
|
src_tokens, src_lengths, return_all_hiddens, token_embeddings, uniformity_layers |
|
) |
|
|
|
|
|
|
|
|
|
|
|
def forward_scriptable( |
|
self, |
|
src_tokens, |
|
src_lengths: Optional[torch.Tensor] = None, |
|
return_all_hiddens: bool = False, |
|
token_embeddings: Optional[torch.Tensor] = None, |
|
uniformity_layers: Optional[List[int]] = None, |
|
): |
|
""" |
|
Args: |
|
src_tokens (LongTensor): tokens in the source language of shape |
|
`(batch, src_len)` |
|
src_lengths (torch.LongTensor): lengths of each source sentence of |
|
shape `(batch)` |
|
return_all_hiddens (bool, optional): also return all of the |
|
intermediate hidden states (default: False). |
|
token_embeddings (torch.Tensor, optional): precomputed embeddings |
|
default `None` will recompute embeddings |
|
|
|
Returns: |
|
dict: |
|
- **encoder_out** (Tensor): the last encoder layer's output of |
|
shape `(src_len, batch, embed_dim)` |
|
- **encoder_padding_mask** (ByteTensor): the positions of |
|
padding elements of shape `(batch, src_len)` |
|
- **encoder_embedding** (Tensor): the (scaled) embedding lookup |
|
of shape `(batch, src_len, embed_dim)` |
|
- **encoder_states** (List[Tensor]): all intermediate |
|
hidden states of shape `(src_len, batch, embed_dim)`. |
|
Only populated if *return_all_hiddens* is True. |
|
""" |
|
|
|
encoder_padding_mask = src_tokens.eq(self.padding_idx) |
|
has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any() |
|
|
|
x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) |
|
|
|
|
|
if has_pads: |
|
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) |
|
|
|
|
|
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
|
if self.use_rel_pos_enc: |
|
x_len = x.shape[0] |
|
pos_seq = torch.arange(0, x_len).long().to(x.device) |
|
pos_seq = pos_seq[:, None] - pos_seq[None, :] |
|
pos_k, pos_v = self.pos_emb(pos_seq) |
|
else: |
|
pos_k = None |
|
|
|
encoder_states = [] |
|
uniformity_hiddens = [] |
|
|
|
if return_all_hiddens: |
|
encoder_states.append(x) |
|
|
|
if uniformity_layers is not None and 0 in uniformity_layers: |
|
x = F.normalize(x.float(), dim=-1).type_as(x) |
|
uniformity_hiddens.append(x) |
|
|
|
|
|
for i, layer in enumerate(self.layers): |
|
x = layer( |
|
x, encoder_padding_mask=encoder_padding_mask if has_pads else None, |
|
pos_bias=pos_k, |
|
) |
|
if uniformity_layers is not None and i+1 in uniformity_layers: |
|
x = F.normalize(x.float(), dim=-1).type_as(x) |
|
uniformity_hiddens.append(x) |
|
if return_all_hiddens: |
|
assert encoder_states is not None |
|
encoder_states.append(x) |
|
|
|
if self.layer_norm is not None: |
|
x = self.layer_norm(x) |
|
|
|
|
|
|
|
|
|
|
|
src_lengths = ( |
|
src_tokens.ne(self.padding_idx) |
|
.sum(dim=1, dtype=torch.int32) |
|
.reshape(-1, 1) |
|
.contiguous() |
|
) |
|
return { |
|
"encoder_out": [x], |
|
"encoder_padding_mask": [encoder_padding_mask], |
|
"encoder_embedding": [encoder_embedding], |
|
"encoder_states": encoder_states, |
|
"uniformity_hiddens": uniformity_hiddens, |
|
"src_tokens": [], |
|
"src_lengths": [src_lengths], |
|
} |
|
|
|
def forward_torchscript(self, net_input: Dict[str, Tensor]): |
|
"""A TorchScript-compatible version of forward. |
|
|
|
Encoders which use additional arguments may want to override |
|
this method for TorchScript compatibility. |
|
""" |
|
if torch.jit.is_scripting(): |
|
return self.forward( |
|
src_tokens=net_input["src_tokens"], |
|
src_lengths=net_input["src_lengths"], |
|
) |
|
else: |
|
return self.forward_non_torchscript(net_input) |
|
|
|
@torch.jit.unused |
|
def forward_non_torchscript(self, net_input: Dict[str, Tensor]): |
|
encoder_input = { |
|
k: v for k, v in net_input.items() if k != "prev_output_tokens" |
|
} |
|
return self.forward(**encoder_input) |
|
|
|
@torch.jit.export |
|
def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): |
|
""" |
|
Reorder encoder output according to *new_order*. |
|
|
|
Args: |
|
encoder_out: output from the ``forward()`` method |
|
new_order (LongTensor): desired order |
|
|
|
Returns: |
|
*encoder_out* rearranged according to *new_order* |
|
""" |
|
if len(encoder_out["encoder_out"]) == 0: |
|
new_encoder_out = [] |
|
else: |
|
new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] |
|
if len(encoder_out["encoder_padding_mask"]) == 0: |
|
new_encoder_padding_mask = [] |
|
else: |
|
new_encoder_padding_mask = [ |
|
encoder_out["encoder_padding_mask"][0].index_select(0, new_order) |
|
] |
|
if len(encoder_out["encoder_embedding"]) == 0: |
|
new_encoder_embedding = [] |
|
else: |
|
new_encoder_embedding = [ |
|
encoder_out["encoder_embedding"][0].index_select(0, new_order) |
|
] |
|
|
|
if len(encoder_out["src_tokens"]) == 0: |
|
src_tokens = [] |
|
else: |
|
src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)] |
|
|
|
if len(encoder_out["src_lengths"]) == 0: |
|
src_lengths = [] |
|
else: |
|
src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)] |
|
|
|
encoder_states = encoder_out["encoder_states"] |
|
if len(encoder_states) > 0: |
|
for idx, state in enumerate(encoder_states): |
|
encoder_states[idx] = state.index_select(1, new_order) |
|
|
|
return { |
|
"encoder_out": new_encoder_out, |
|
"encoder_padding_mask": new_encoder_padding_mask, |
|
"encoder_embedding": new_encoder_embedding, |
|
"encoder_states": encoder_states, |
|
"src_tokens": src_tokens, |
|
"src_lengths": src_lengths, |
|
} |
|
|
|
def max_positions(self): |
|
"""Maximum input length supported by the encoder.""" |
|
if self.embed_positions is None: |
|
return self.max_source_positions |
|
return min(self.max_source_positions, self.embed_positions.max_positions) |
|
|
|
def upgrade_state_dict_named(self, state_dict, name): |
|
"""Upgrade a (possibly old) state dict for new versions.""" |
|
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): |
|
weights_key = "{}.embed_positions.weights".format(name) |
|
if weights_key in state_dict: |
|
print("deleting {0}".format(weights_key)) |
|
del state_dict[weights_key] |
|
state_dict[ |
|
"{}.embed_positions._float_tensor".format(name) |
|
] = torch.FloatTensor(1) |
|
for i in range(self.num_layers): |
|
|
|
self.layers[i].upgrade_state_dict_named( |
|
state_dict, "{}.layers.{}".format(name, i) |
|
) |
|
|
|
version_key = "{}.version".format(name) |
|
if utils_item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: |
|
|
|
self.layer_norm = None |
|
self.normalize = False |
|
state_dict[version_key] = torch.Tensor([1]) |
|
return state_dict |
|
|
|
def set_num_updates(self, num_updates): |
|
"""State from trainer to pass along to model at every update.""" |
|
|
|
def _apply(m): |
|
if hasattr(m, "set_num_updates") and m != self: |
|
m.set_num_updates(num_updates) |
|
|
|
self.apply(_apply) |
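
# Usage sketch (hedged; cfg and embed_tokens come from the surrounding model code):
#   encoder = TransformerEncoderBase(cfg, dictionary=None, embed_tokens=embed_tokens)
#   out = encoder(src_tokens, return_all_hiddens=True)
#   features = out["encoder_out"][0]        # (src_len, batch, embed_dim)
#   states = out["encoder_states"]          # per-layer hidden states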
|
|
|
|
|
class TransformerEncoderLayerBase(nn.Module): |
|
"""Encoder layer block. |
|
|
|
In the original paper each operation (multi-head attention or FFN) is |
|
postprocessed with: `dropout -> add residual -> layernorm`. In the |
|
tensor2tensor code they suggest that learning is more robust when |
|
preprocessing each layer with layernorm and postprocessing with: |
|
`dropout -> add residual`. We default to the approach in the paper, but the |
|
tensor2tensor approach can be enabled by setting |
|
*cfg.encoder.normalize_before* to ``True``. |
|
|
|
Args: |
|
        cfg: encoder configuration (provides ``cfg.encoder.*``, dropout and quant-noise settings)
|
""" |
|
|
|
def __init__(self, cfg, has_relative_attention_bias=False, scaling_for_att=1.0): |
|
super().__init__() |
|
self.cfg = cfg |
|
self.embed_dim = cfg.encoder.embed_dim |
|
self.quant_noise = cfg.quant_noise.pq |
|
self.quant_noise_block_size = cfg.quant_noise.pq_block_size |
|
self.self_attn = self.build_self_attention(self.embed_dim, cfg, has_relative_attention_bias=has_relative_attention_bias, scaling_for_att=scaling_for_att) |
|
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) |
|
self.dropout_module = FairseqDropout( |
|
cfg.dropout, module_name=self.__class__.__name__ |
|
) |
|
self.activation_fn = get_activation_fn(activation=cfg.activation_fn) |
|
activation_dropout_p = cfg.activation_dropout |
|
        if activation_dropout_p == 0:
            # for backwards compatibility with models that set cfg.relu_dropout instead
            activation_dropout_p = cfg.relu_dropout or 0
|
self.activation_dropout_module = FairseqDropout( |
|
float(activation_dropout_p), module_name=self.__class__.__name__ |
|
) |
|
self.normalize_before = cfg.encoder.normalize_before |
|
self.fc1 = self.build_fc1( |
|
self.embed_dim, |
|
cfg.encoder.ffn_embed_dim, |
|
self.quant_noise, |
|
self.quant_noise_block_size, |
|
) |
|
self.fc2 = self.build_fc2( |
|
cfg.encoder.ffn_embed_dim, |
|
self.embed_dim, |
|
self.quant_noise, |
|
self.quant_noise_block_size, |
|
) |
|
|
|
self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) |
|
if has_relative_attention_bias: |
|
self.norm_k = LayerNorm(self.embed_dim // cfg.encoder.attention_heads) |
|
|
|
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): |
|
return quant_noise( |
|
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size |
|
) |
|
|
|
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): |
|
return quant_noise( |
|
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size |
|
) |
|
|
|
def build_self_attention(self, embed_dim, cfg, has_relative_attention_bias=False, scaling_for_att=1.0): |
|
return MultiheadAttention( |
|
embed_dim, |
|
cfg.encoder.attention_heads, |
|
dropout=cfg.attention_dropout, |
|
self_attention=True, |
|
q_noise=self.quant_noise, |
|
qn_block_size=self.quant_noise_block_size, |
|
has_relative_attention_bias=has_relative_attention_bias, |
|
scaling_for_att=scaling_for_att, |
|
) |
|
|
|
def residual_connection(self, x, residual): |
|
return residual + x |
|
|
|
def upgrade_state_dict_named(self, state_dict, name): |
|
""" |
|
Rename layer norm states from `...layer_norms.0.weight` to |
|
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to |
|
`...final_layer_norm.weight` |
|
""" |
|
layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"} |
|
for old, new in layer_norm_map.items(): |
|
for m in ("weight", "bias"): |
|
k = "{}.layer_norms.{}.{}".format(name, old, m) |
|
if k in state_dict: |
|
state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] |
|
del state_dict[k] |
|
|
|
def forward( |
|
self, |
|
x, |
|
encoder_padding_mask: Optional[Tensor], |
|
attn_mask: Optional[Tensor] = None, |
|
pos_bias=None, |
|
): |
|
""" |
|
Args: |
|
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` |
|
encoder_padding_mask (ByteTensor): binary ByteTensor of shape |
|
`(batch, seq_len)` where padding elements are indicated by ``1``. |
|
attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`, |
|
where `tgt_len` is the length of output and `src_len` is the |
|
length of input, though here both are equal to `seq_len`. |
|
`attn_mask[tgt_i, src_j] = 1` means that when calculating the |
|
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is |
|
useful for strided self-attention. |
|
|
|
Returns: |
|
encoded output of shape `(seq_len, batch, embed_dim)` |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
if attn_mask is not None: |
|
attn_mask = attn_mask.masked_fill( |
|
attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4 |
|
) |
|
|
|
residual = x |
|
if self.normalize_before: |
|
x = self.self_attn_layer_norm(x) |
|
if pos_bias is not None: |
|
pos_bias = self.norm_k(pos_bias) |
|
x, _ = self.self_attn( |
|
query=x, |
|
key=x, |
|
value=x, |
|
key_padding_mask=encoder_padding_mask, |
|
need_weights=False, |
|
attn_mask=attn_mask, |
|
position_bias=pos_bias, |
|
) |
|
x = self.dropout_module(x) |
|
x = self.residual_connection(x, residual) |
|
if not self.normalize_before: |
|
x = self.self_attn_layer_norm(x) |
|
|
|
residual = x |
|
if self.normalize_before: |
|
x = self.final_layer_norm(x) |
|
x = self.activation_fn(self.fc1(x)) |
|
x = self.activation_dropout_module(x) |
|
x = self.fc2(x) |
|
x = self.dropout_module(x) |
|
x = self.residual_connection(x, residual) |
|
if not self.normalize_before: |
|
x = self.final_layer_norm(x) |
|
return x |
|
|
|
|
|
class TransformerEncoder(nn.Module): |
|
""" |
|
wav2vec-style transformer encoder. |
|
""" |
|
def __init__(self, args): |
|
super().__init__() |
|
|
|
        self.args = args  # retained so that max_positions() below can read args.max_positions
        self.dropout = args.dropout
|
self.embedding_dim = args.encoder_embed_dim |
|
self.required_seq_len_multiple = args.required_seq_len_multiple |
|
|
|
self.pos_conv = nn.Conv1d( |
|
self.embedding_dim, |
|
self.embedding_dim, |
|
kernel_size=args.conv_pos, |
|
padding=args.conv_pos // 2, |
|
groups=args.conv_pos_groups, |
|
) |
|
dropout = 0 |
|
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) |
|
nn.init.normal_(self.pos_conv.weight, mean=0, std=std) |
|
nn.init.constant_(self.pos_conv.bias, 0) |
|
|
|
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) |
|
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) |
|
|
|
layers = [] |
|
self.use_rel_pos_enc = getattr(args, "use_rel_pos_enc", False) |
|
for _ in range(args.encoder_layers): |
|
layer = TransformerSentenceEncoderLayer( |
|
embedding_dim=self.embedding_dim, |
|
ffn_embedding_dim=args.encoder_ffn_embed_dim, |
|
num_attention_heads=args.encoder_attention_heads, |
|
dropout=self.dropout, |
|
attention_dropout=args.attention_dropout, |
|
activation_dropout=args.activation_dropout, |
|
activation_fn=args.activation_fn, |
|
layer_norm_first=args.layer_norm_first, |
|
has_relative_attention_bias=self.use_rel_pos_enc, |
|
scaling_for_att=getattr(args, "scaling_for_att", 1.0) |
|
) |
|
if args.checkpoint_activations: |
|
raise ValueError("We don't support checkpoint_activations for now! Please set checkpoint_activations=False.") |
|
layers.append(layer) |
|
self.layers = nn.ModuleList(layers) |
|
|
|
self.layer_norm_first = args.layer_norm_first |
|
self.layer_norm = LayerNorm(self.embedding_dim) |
|
self.layerdrop = args.encoder_layerdrop |
|
|
|
if self.use_rel_pos_enc: |
|
self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim // args.encoder_attention_heads, 160) |
|
|
|
self.apply(init_bert_params) |
|
|
|
def forward(self, x, padding_mask=None, layer=None, conv_pos=True): |
|
x, layer_results = self.extract_features(x, padding_mask, layer, conv_pos) |
|
|
|
if self.layer_norm_first and (layer is None or layer >= len(self.layers) - 1): |
|
x = self.layer_norm(x) |
|
|
|
return x, layer_results |
|
|
|
def extract_features(self, x, padding_mask=None, tgt_layer=None, conv_pos=True): |
|
|
|
if padding_mask is not None: |
|
x = index_put(x, padding_mask, 0) |
|
|
|
if conv_pos: |
|
x_conv = self.pos_conv(x.transpose(1, 2)) |
|
x_conv = x_conv.transpose(1, 2) |
|
x = x + x_conv |
|
|
|
if not self.layer_norm_first: |
|
x = self.layer_norm(x) |
|
|
|
|
|
x, pad_length = pad_to_multiple( |
|
x, self.required_seq_len_multiple, dim=-2, value=0 |
|
) |
|
if pad_length > 0 and padding_mask is None: |
|
padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool) |
|
padding_mask[:, -pad_length:] = True |
|
else: |
|
padding_mask, _ = pad_to_multiple( |
|
padding_mask, self.required_seq_len_multiple, dim=-1, value=True |
|
) |
|
x = F.dropout(x, p=self.dropout, training=self.training) |
|
|
|
|
|
x = x.transpose(0, 1) |
|
|
|
if self.use_rel_pos_enc: |
|
x_len = x.shape[0] |
|
pos_seq = torch.arange(0, x_len).long().to(x.device) |
|
pos_seq = pos_seq[:, None] - pos_seq[None, :] |
|
pos_k, pos_v = self.pos_emb(pos_seq) |
|
else: |
|
pos_k = None |
|
|
|
layer_results = [] |
|
r = None |
|
for i, layer in enumerate(self.layers): |
|
dropout_probability = np.random.random() |
|
if not self.training or (dropout_probability > self.layerdrop): |
|
x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_k) |
|
if tgt_layer is not None: |
|
|
|
if pad_length > 0: |
|
                        layer_results.append(x[:-pad_length])
|
else: |
|
|
|
layer_results.append(x) |
|
if i == tgt_layer: |
|
r = x |
|
break |
|
|
|
if r is not None: |
|
x = r |
|
|
|
|
|
x = x.transpose(0, 1) |
|
|
|
if pad_length > 0: |
|
x = x[:, :-pad_length] |
|
|
|
return x, layer_results |
|
|
|
def max_positions(self): |
|
"""Maximum output length supported by the encoder.""" |
|
return self.args.max_positions |
|
|
|
def upgrade_state_dict_named(self, state_dict, name): |
|
"""Upgrade a (possibly old) state dict for new versions of fairseq.""" |
|
return state_dict |
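
# Usage sketch (hedged): here x is a (batch, time, encoder_embed_dim) feature sequence,
# e.g. the output of a wav2vec-style convolutional feature extractor defined elsewhere.
#   encoder = TransformerEncoder(args)
#   features, layer_results = encoder(x, padding_mask=padding_mask)
#   hidden_l, _ = encoder(x, padding_mask=padding_mask, layer=11)   # stop after the layer at index 11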
|
|
|
|
|
class TransformerSentenceEncoderLayer(nn.Module): |
|
""" |
|
wav2vec-style transformer layer |
|
""" |
|
|
|
def __init__( |
|
self, |
|
        embedding_dim: int = 768,

        ffn_embedding_dim: int = 3072,

        num_attention_heads: int = 8,
|
dropout: float = 0.1, |
|
attention_dropout: float = 0.1, |
|
activation_dropout: float = 0.1, |
|
activation_fn: str = "relu", |
|
layer_norm_first: bool = False, |
|
has_relative_attention_bias: bool = False, |
|
scaling_for_att: float = 1.0, |
|
) -> None: |
|
|
|
super().__init__() |
|
|
|
self.embedding_dim = embedding_dim |
|
self.dropout = dropout |
|
self.activation_dropout = activation_dropout |
|
|
|
|
|
self.activation_fn = get_activation_fn(activation_fn) |
|
self.self_attn = MultiheadAttention( |
|
self.embedding_dim, |
|
num_attention_heads, |
|
dropout=attention_dropout, |
|
self_attention=True, |
|
has_relative_attention_bias=has_relative_attention_bias, |
|
scaling_for_att=scaling_for_att |
|
) |
|
|
|
self.dropout1 = nn.Dropout(dropout) |
|
self.dropout2 = nn.Dropout(self.activation_dropout) |
|
self.dropout3 = nn.Dropout(dropout) |
|
|
|
self.layer_norm_first = layer_norm_first |
|
|
|
|
|
self.self_attn_layer_norm = LayerNorm(self.embedding_dim) |
|
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) |
|
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) |
|
|
|
|
|
self.final_layer_norm = LayerNorm(self.embedding_dim) |
|
|
|
if has_relative_attention_bias: |
|
self.norm_k = LayerNorm(self.embedding_dim//num_attention_heads) |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
        self_attn_mask: Optional[torch.Tensor] = None,

        self_attn_padding_mask: Optional[torch.Tensor] = None,
|
need_weights: bool = False, |
|
att_args=None, |
|
pos_bias=None, |
|
): |
|
""" |
|
LayerNorm is applied either before or after the self-attention/ffn |
|
        modules, similar to the original Transformer implementation.
|
""" |
|
residual = x |
|
|
|
if self.layer_norm_first: |
|
x = self.self_attn_layer_norm(x) |
|
if pos_bias is not None: |
|
pos_bias = self.norm_k(pos_bias) |
|
x, attn = self.self_attn( |
|
query=x, |
|
key=x, |
|
value=x, |
|
key_padding_mask=self_attn_padding_mask, |
|
attn_mask=self_attn_mask, |
|
position_bias=pos_bias, |
|
) |
|
x = self.dropout1(x) |
|
x = residual + x |
|
|
|
residual = x |
|
x = self.final_layer_norm(x) |
|
x = self.activation_fn(self.fc1(x)) |
|
x = self.dropout2(x) |
|
x = self.fc2(x) |
|
x = self.dropout3(x) |
|
x = residual + x |
|
else: |
|
x, attn = self.self_attn( |
|
query=x, |
|
key=x, |
|
value=x, |
|
key_padding_mask=self_attn_padding_mask, |
|
position_bias=pos_bias, |
|
) |
|
|
|
x = self.dropout1(x) |
|
x = residual + x |
|
|
|
x = self.self_attn_layer_norm(x) |
|
|
|
residual = x |
|
x = self.activation_fn(self.fc1(x)) |
|
x = self.dropout2(x) |
|
x = self.fc2(x) |
|
x = self.dropout3(x) |
|
x = residual + x |
|
x = self.final_layer_norm(x) |
|
|
|
return x, attn |
|
|
|
|
|
class FairseqDropout(nn.Module): |
|
def __init__(self, p, module_name=None): |
|
super().__init__() |
|
self.p = p |
|
self.module_name = module_name |
|
self.apply_during_inference = False |
|
|
|
def forward(self, x, inplace: bool = False): |
|
if self.p > 0 and (self.training or self.apply_during_inference): |
|
return F.dropout(x, p=self.p, training=True, inplace=inplace) |
|
else: |
|
return x |
|
|
|
def make_generation_fast_( |
|
self, |
|
name: str, |
|
retain_dropout: bool = False, |
|
retain_dropout_modules: Optional[List[str]] = None, |
|
**kwargs |
|
): |
|
if retain_dropout: |
|
if retain_dropout_modules is not None and self.module_name is None: |
|
logger.warning( |
|
"Cannot enable dropout during inference for module {} " |
|
"because module_name was not set".format(name) |
|
) |
|
elif ( |
|
retain_dropout_modules is None |
|
or self.module_name in retain_dropout_modules |
|
): |
|
logger.info( |
|
"Enabling dropout during inference for module: {}".format(name) |
|
) |
|
self.apply_during_inference = True |
|
else: |
|
logger.info("Disabling dropout for module: {}".format(name)) |
|
|
|
|
|
class LearnedPositionalEmbedding(nn.Embedding): |
|
""" |
|
This module learns positional embeddings up to a fixed maximum size. |
|
Padding ids are ignored by either offsetting based on padding_idx |
|
or by setting padding_idx to None and ensuring that the appropriate |
|
position ids are passed to the forward function. |
|
""" |
|
|
|
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): |
|
super().__init__(num_embeddings, embedding_dim, padding_idx) |
|
self.onnx_trace = False |
|
if self.padding_idx is not None: |
|
self.max_positions = self.num_embeddings - self.padding_idx - 1 |
|
else: |
|
self.max_positions = self.num_embeddings |
|
|
|
def forward( |
|
self, |
|
input: Tensor, |
|
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, |
|
positions: Optional[Tensor] = None, |
|
): |
|
"""Input is expected to be of size [bsz x seqlen].""" |
|
assert (positions is None) or ( |
|
self.padding_idx is None |
|
), "If positions is pre-computed then padding_idx should not be set." |
|
|
|
if positions is None: |
|
if incremental_state is not None: |
|
|
|
|
|
positions = torch.zeros( |
|
(1, 1), device=input.device, dtype=input.dtype |
|
).fill_(int(self.padding_idx + input.size(1))) |
|
else: |
|
positions = utils_make_positions( |
|
input, self.padding_idx, onnx_trace=self.onnx_trace |
|
) |
|
positions = torch.clamp(positions, max=self.padding_idx + self.max_positions) |
|
return F.embedding( |
|
positions, |
|
self.weight, |
|
self.padding_idx, |
|
self.max_norm, |
|
self.norm_type, |
|
self.scale_grad_by_freq, |
|
self.sparse, |
|
) |
|
|
|
|
|
class SinusoidalPositionalEmbedding(nn.Module): |
|
"""This module produces sinusoidal positional embeddings of any length. |
|
|
|
Padding symbols are ignored. |
|
""" |
|
|
|
def __init__(self, embedding_dim, padding_idx, init_size=1024): |
|
super().__init__() |
|
self.embedding_dim = embedding_dim |
|
self.padding_idx = padding_idx if padding_idx is not None else 0 |
|
self.weights = SinusoidalPositionalEmbedding.get_embedding( |
|
init_size, embedding_dim, padding_idx |
|
) |
|
self.onnx_trace = False |
|
self.register_buffer("_float_tensor", torch.FloatTensor(1)) |
|
self.max_positions = int(1e5) |
|
|
|
def prepare_for_onnx_export_(self): |
|
self.onnx_trace = True |
|
|
|
@staticmethod |
|
def get_embedding( |
|
num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None |
|
): |
|
"""Build sinusoidal embeddings. |
|
|
|
This matches the implementation in tensor2tensor, but differs slightly |
|
from the description in Section 3.5 of "Attention Is All You Need". |
|
""" |
|
half_dim = embedding_dim // 2 |
|
emb = math.log(10000) / (half_dim - 1) |
|
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) |
|
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze( |
|
1 |
|
) * emb.unsqueeze(0) |
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view( |
|
num_embeddings, -1 |
|
) |
|
if embedding_dim % 2 == 1: |
|
|
|
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) |
|
if padding_idx is not None: |
|
emb[padding_idx, :] = 0 |
|
return emb |
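
    # The get_embedding construction above follows the tensor2tensor layout: with half_dim = d // 2,
    #   emb[pos, j]            = sin(pos * exp(-j * ln(10000) / (half_dim - 1)))
    #   emb[pos, half_dim + j] = cos(pos * exp(-j * ln(10000) / (half_dim - 1)))
    # i.e. all sine components come first, then all cosines, rather than interleaving them.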
|
|
|
def forward( |
|
self, |
|
input, |
|
incremental_state: Optional[Any] = None, |
|
timestep: Optional[Tensor] = None, |
|
positions: Optional[Any] = None, |
|
): |
|
"""Input is expected to be of size [bsz x seqlen].""" |
|
bspair = torch.onnx.operators.shape_as_tensor(input) |
|
bsz, seq_len = bspair[0], bspair[1] |
|
max_pos = self.padding_idx + 1 + seq_len |
|
if self.weights is None or max_pos > self.weights.size(0): |
|
|
|
self.weights = SinusoidalPositionalEmbedding.get_embedding( |
|
max_pos, self.embedding_dim, self.padding_idx |
|
) |
|
self.weights = self.weights.to(self._float_tensor) |
|
|
|
if incremental_state is not None: |
|
|
|
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len |
|
if self.onnx_trace: |
|
return ( |
|
self.weights.index_select(index=self.padding_idx + pos, dim=0) |
|
.unsqueeze(1) |
|
.repeat(bsz, 1, 1) |
|
) |
|
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) |
|
|
|
positions = utils_make_positions( |
|
input, self.padding_idx, onnx_trace=self.onnx_trace |
|
) |
|
if self.onnx_trace: |
|
flat_embeddings = self.weights.detach().index_select(0, positions.view(-1)) |
|
embedding_shape = torch.cat( |
|
(bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long)) |
|
) |
|
embeddings = torch.onnx.operators.reshape_from_tensor_shape( |
|
flat_embeddings, embedding_shape |
|
) |
|
return embeddings |
|
return ( |
|
self.weights.index_select(0, positions.view(-1)) |
|
.view(bsz, seq_len, -1) |
|
.detach() |
|
) |
|
|
|
|
|
try: |
|
from apex.normalization import FusedLayerNorm as _FusedLayerNorm |
|
|
|
has_fused_layernorm = True |
|
|
|
class FusedLayerNorm(_FusedLayerNorm): |
|
@torch.jit.unused |
|
def forward(self, x): |
|
if not x.is_cuda: |
|
return super().forward(x) |
|
else: |
|
with torch.cuda.device(x.device): |
|
return super().forward(x) |
|
|
|
except ImportError: |
|
has_fused_layernorm = False |
|
|
|
|
|
class Fp32LayerNorm(nn.LayerNorm): |
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
|
|
def forward(self, input): |
|
output = F.layer_norm( |
|
input.float(), |
|
self.normalized_shape, |
|
self.weight.float() if self.weight is not None else None, |
|
self.bias.float() if self.bias is not None else None, |
|
self.eps, |
|
) |
|
return output.type_as(input) |
|
|
|
|
|
class LayerDropModuleList(nn.ModuleList): |
|
""" |
|
A LayerDrop implementation based on :class:`torch.nn.ModuleList`. |
|
|
|
We refresh the choice of which layers to drop every time we iterate |
|
over the LayerDropModuleList instance. During evaluation we always |
|
iterate over all layers. |
|
|
|
Usage:: |
|
|
|
        layers = LayerDropModuleList(p=0.5, modules=[layer1, layer2, layer3])
|
for layer in layers: # this might iterate over layers 1 and 3 |
|
x = layer(x) |
|
for layer in layers: # this might iterate over all layers |
|
x = layer(x) |
|
for layer in layers: # this might not iterate over any layers |
|
x = layer(x) |
|
|
|
Args: |
|
p (float): probability of dropping out each layer |
|
modules (iterable, optional): an iterable of modules to add |
|
""" |
|
|
|
def __init__(self, p, modules=None): |
|
super().__init__(modules) |
|
self.p = p |
|
|
|
def __iter__(self): |
|
dropout_probs = torch.empty(len(self)).uniform_() |
|
for i, m in enumerate(super().__iter__()): |
|
if not self.training or (dropout_probs[i] > self.p): |
|
yield m |
|
|
|
|
|
class RelativePositionalEncoding(torch.nn.Module): |
|
def __init__(self, d_model, maxlen=1000, embed_v=False): |
|
super(RelativePositionalEncoding, self).__init__() |
|
|
|
self.d_model = d_model |
|
self.maxlen = maxlen |
|
self.pe_k = torch.nn.Embedding(2*maxlen, d_model) |
|
if embed_v: |
|
self.pe_v = torch.nn.Embedding(2*maxlen, d_model) |
|
self.embed_v = embed_v |
|
|
|
|
|
def forward(self, pos_seq, incremental_state=None): |
|
pos_seq[pos_seq < -self.maxlen] = -self.maxlen |
|
pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1 |
|
pos_seq = pos_seq + self.maxlen |
|
|
|
if incremental_state is not None: |
|
pos_seq = pos_seq[-1:] |
|
|
|
if self.embed_v: |
|
return self.pe_k(pos_seq), self.pe_v(pos_seq) |
|
else: |
|
return self.pe_k(pos_seq), None |
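
# Usage sketch (mirrors how the encoders above build their relative positions; head_dim
# and T are illustrative names):
#   pos_emb = RelativePositionalEncoding(head_dim, maxlen=160)
#   seq = torch.arange(T)
#   pos_seq = seq[:, None] - seq[None, :]     # (T, T) relative offsets
#   pos_k, pos_v = pos_emb(pos_seq)           # pos_k: (T, T, head_dim); pos_v is None unless embed_v=True
#   # pos_k is then passed to MultiheadAttention as position_bias.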
|
|
|
|
|
class MultiheadAttention(nn.Module): |
|
"""Multi-headed attention. |
|
|
|
See "Attention Is All You Need" for more details. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
embed_dim, |
|
num_heads, |
|
kdim=None, |
|
vdim=None, |
|
dropout=0.0, |
|
bias=True, |
|
add_bias_kv=False, |
|
add_zero_attn=False, |
|
self_attention=False, |
|
encoder_decoder_attention=False, |
|
q_noise=0.0, |
|
qn_block_size=8, |
|
has_relative_attention_bias=False, |
|
scaling_for_att=1.0 |
|
): |
|
super().__init__() |
|
self.embed_dim = embed_dim |
|
self.kdim = kdim if kdim is not None else embed_dim |
|
self.vdim = vdim if vdim is not None else embed_dim |
|
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim |
|
|
|
self.num_heads = num_heads |
|
self.dropout_module = FairseqDropout( |
|
dropout, module_name=self.__class__.__name__ |
|
) |
|
|
|
self.has_relative_attention_bias = has_relative_attention_bias |
|
|
|
self.head_dim = embed_dim // num_heads |
|
assert ( |
|
self.head_dim * num_heads == self.embed_dim |
|
), "embed_dim must be divisible by num_heads" |
|
self.scaling = self.head_dim ** -0.5 |
|
self.scaling_for_att = scaling_for_att |
|
|
|
self.self_attention = self_attention |
|
self.encoder_decoder_attention = encoder_decoder_attention |
|
|
|
assert not self.self_attention or self.qkv_same_dim, ( |
|
"Self-attention requires query, key and " "value to be of the same size" |
|
) |
|
|
|
self.k_proj = quant_noise( |
|
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size |
|
) |
|
self.v_proj = quant_noise( |
|
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size |
|
) |
|
self.q_proj = quant_noise( |
|
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size |
|
) |
|
|
|
self.out_proj = quant_noise( |
|
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size |
|
) |
|
|
|
if add_bias_kv: |
|
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) |
|
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) |
|
else: |
|
self.bias_k = self.bias_v = None |
|
|
|
self.add_zero_attn = add_zero_attn |
|
|
|
self.reset_parameters() |
|
|
|
self.onnx_trace = False |
|
|
|
def prepare_for_onnx_export_(self): |
|
self.onnx_trace = True |
|
|
|
def reset_parameters(self): |
|
if self.qkv_same_dim: |
|
|
|
|
|
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) |
|
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) |
|
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) |
|
else: |
|
nn.init.xavier_uniform_(self.k_proj.weight) |
|
nn.init.xavier_uniform_(self.v_proj.weight) |
|
nn.init.xavier_uniform_(self.q_proj.weight) |
|
|
|
nn.init.xavier_uniform_(self.out_proj.weight) |
|
if self.out_proj.bias is not None: |
|
nn.init.constant_(self.out_proj.bias, 0.0) |
|
if self.bias_k is not None: |
|
nn.init.xavier_normal_(self.bias_k) |
|
if self.bias_v is not None: |
|
nn.init.xavier_normal_(self.bias_v) |
|
|
|
def forward( |
|
self, |
|
query, |
|
key: Optional[Tensor], |
|
value: Optional[Tensor], |
|
key_padding_mask: Optional[Tensor] = None, |
|
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, |
|
need_weights: bool = True, |
|
static_kv: bool = False, |
|
attn_mask: Optional[Tensor] = None, |
|
before_softmax: bool = False, |
|
need_head_weights: bool = False, |
|
position_bias: Optional[Tensor] = None |
|
) -> Tuple[Tensor, Optional[Tensor]]: |
|
"""Input shape: Time x Batch x Channel |
|
|
|
Args: |
|
key_padding_mask (ByteTensor, optional): mask to exclude |
|
keys that are pads, of shape `(batch, src_len)`, where |
|
padding elements are indicated by 1s. |
|
need_weights (bool, optional): return the attention weights, |
|
                averaged over heads (default: True).
|
attn_mask (ByteTensor, optional): typically used to |
|
implement causal attention, where the mask prevents the |
|
attention from looking forward in time (default: None). |
|
before_softmax (bool, optional): return the raw attention |
|
weights and values before the attention softmax. |
|
need_head_weights (bool, optional): return the attention |
|
weights for each head. Implies *need_weights*. Default: |
|
return the average attention weights over all heads. |
|
""" |
|
if need_head_weights: |
|
need_weights = True |
|
|
|
is_tpu = query.device.type == "xla" |
|
|
|
tgt_len, bsz, embed_dim = query.size() |
|
src_len = tgt_len |
|
assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}" |
|
assert list(query.size()) == [tgt_len, bsz, embed_dim] |
|
if key is not None: |
|
src_len, key_bsz, _ = key.size() |
|
if not torch.jit.is_scripting(): |
|
assert key_bsz == bsz |
|
assert value is not None |
|
                assert value.shape[:2] == (src_len, bsz)
|
|
|
if ( |
|
not self.onnx_trace |
|
and not is_tpu |
|
and incremental_state is None |
|
and not static_kv |
|
|
|
|
|
and not torch.jit.is_scripting() |
|
and not self.has_relative_attention_bias |
|
): |
|
assert key is not None and value is not None |
|
return F.multi_head_attention_forward( |
|
query, |
|
key, |
|
value, |
|
self.embed_dim, |
|
self.num_heads, |
|
torch.empty([0]), |
|
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), |
|
self.bias_k, |
|
self.bias_v, |
|
self.add_zero_attn, |
|
self.dropout_module.p, |
|
self.out_proj.weight, |
|
self.out_proj.bias, |
|
self.training or self.dropout_module.apply_during_inference, |
|
key_padding_mask, |
|
need_weights, |
|
attn_mask, |
|
use_separate_proj_weight=True, |
|
q_proj_weight=self.q_proj.weight, |
|
k_proj_weight=self.k_proj.weight, |
|
v_proj_weight=self.v_proj.weight, |
|
) |
|
|
|
if incremental_state is not None: |
|
saved_state = self._get_input_buffer(incremental_state) |
|
if saved_state is not None and "prev_key" in saved_state: |
|
|
|
|
|
if static_kv: |
|
assert self.encoder_decoder_attention and not self.self_attention |
|
key = value = None |
|
else: |
|
saved_state = None |
|
|
|
if self.self_attention: |
|
q = self.q_proj(query) |
|
k = self.k_proj(query) |
|
v = self.v_proj(query) |
|
elif self.encoder_decoder_attention: |
|
|
|
q = self.q_proj(query) |
|
if key is None: |
|
assert value is None |
|
k = v = None |
|
else: |
|
k = self.k_proj(key) |
|
v = self.v_proj(key) |
|
|
|
else: |
|
assert key is not None and value is not None |
|
q = self.q_proj(query) |
|
k = self.k_proj(key) |
|
v = self.v_proj(value) |
|
q *= self.scaling |
|
q *= (1 / self.scaling_for_att) |
|
|
|
if self.bias_k is not None: |
|
assert self.bias_v is not None |
|
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) |
|
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) |
|
if attn_mask is not None: |
|
attn_mask = torch.cat( |
|
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 |
|
) |
|
if key_padding_mask is not None: |
|
key_padding_mask = torch.cat( |
|
[ |
|
key_padding_mask, |
|
key_padding_mask.new_zeros(key_padding_mask.size(0), 1), |
|
], |
|
dim=1, |
|
) |
|
|
|
q = ( |
|
q.contiguous() |
|
.view(tgt_len, bsz * self.num_heads, self.head_dim) |
|
.transpose(0, 1) |
|
) |
|
if k is not None: |
|
k = ( |
|
k.contiguous() |
|
.view(-1, bsz * self.num_heads, self.head_dim) |
|
.transpose(0, 1) |
|
) |
|
if v is not None: |
|
v = ( |
|
v.contiguous() |
|
.view(-1, bsz * self.num_heads, self.head_dim) |
|
.transpose(0, 1) |
|
) |
|
|
|
if saved_state is not None: |
|
|
|
if "prev_key" in saved_state: |
|
_prev_key = saved_state["prev_key"] |
|
assert _prev_key is not None |
|
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) |
|
if static_kv: |
|
k = prev_key |
|
else: |
|
assert k is not None |
|
k = torch.cat([prev_key, k], dim=1) |
|
src_len = k.size(1) |
|
if "prev_value" in saved_state: |
|
_prev_value = saved_state["prev_value"] |
|
assert _prev_value is not None |
|
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) |
|
if static_kv: |
|
v = prev_value |
|
else: |
|
assert v is not None |
|
v = torch.cat([prev_value, v], dim=1) |
|
prev_key_padding_mask: Optional[Tensor] = None |
|
if "prev_key_padding_mask" in saved_state: |
|
prev_key_padding_mask = saved_state["prev_key_padding_mask"] |
|
assert k is not None and v is not None |
|
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( |
|
key_padding_mask=key_padding_mask, |
|
prev_key_padding_mask=prev_key_padding_mask, |
|
batch_size=bsz, |
|
src_len=k.size(1), |
|
static_kv=static_kv, |
|
) |
|
|
|
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) |
|
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) |
|
saved_state["prev_key_padding_mask"] = key_padding_mask |
|
|
|
assert incremental_state is not None |
|
incremental_state = self._set_input_buffer(incremental_state, saved_state) |
|
assert k is not None |
|
assert k.size(1) == src_len |
|
|
|
|
|
|
|
if key_padding_mask is not None and key_padding_mask.dim() == 0: |
|
key_padding_mask = None |
|
|
|
if key_padding_mask is not None: |
|
assert key_padding_mask.size(0) == bsz |
|
assert key_padding_mask.size(1) == src_len |
|
|
|
if self.add_zero_attn: |
|
assert v is not None |
|
src_len += 1 |
|
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) |
|
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) |
|
if attn_mask is not None: |
|
attn_mask = torch.cat( |
|
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 |
|
) |
|
if key_padding_mask is not None: |
|
key_padding_mask = torch.cat( |
|
[ |
|
key_padding_mask, |
|
torch.zeros(key_padding_mask.size(0), 1).type_as( |
|
key_padding_mask |
|
), |
|
], |
|
dim=1, |
|
) |
|
|
|
attn_weights = torch.bmm(q, k.transpose(1, 2)) |
|
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) |
|
|
|
if position_bias is not None: |
|
|
|
|
|
reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0,1) |
|
|
|
B = torch.matmul(reshape_q, position_bias.transpose(-2, -1)) |
|
|
|
|
|
B = B.transpose(0, 1).view(bsz*self.num_heads, position_bias.size(0), position_bias.size(1)) |
|
|
|
attn_weights += B |
|
|
|
attn_weights *= self.scaling_for_att |
|
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] |
|
|
|
if attn_mask is not None: |
|
attn_mask = attn_mask.unsqueeze(0) |
|
if self.onnx_trace: |
|
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) |
|
attn_weights += attn_mask |
|
|
|
if key_padding_mask is not None: |
|
|
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) |
|
if not is_tpu: |
|
attn_weights = attn_weights.masked_fill( |
|
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), |
|
float("-inf"), |
|
) |
|
else: |
|
attn_weights = attn_weights.transpose(0, 2) |
|
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) |
|
attn_weights = attn_weights.transpose(0, 2) |
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) |
|
|
|
        if self.scaling_for_att > 1.0:
            # softmax is shift-invariant, so subtracting the (detached) row-wise
            # maximum keeps the extra attention scaling numerically stable
            attn_weights = attn_weights - attn_weights.detach().max(
                dim=-1, keepdim=True
            )[0]
|
|
|
if before_softmax: |
|
return attn_weights, v |
|
|
|
attn_weights_float = softmax( |
|
attn_weights, dim=-1, onnx_trace=self.onnx_trace |
|
) |
|
attn_weights = attn_weights_float.type_as(attn_weights) |
|
attn_probs = self.dropout_module(attn_weights) |
|
|
|
assert v is not None |
|
attn = torch.bmm(attn_probs, v) |
|
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] |
|
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
|
else: |
|
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) |
|
attn = self.out_proj(attn) |
|
attn_weights: Optional[Tensor] = None |
|
if need_weights: |
|
attn_weights = attn_weights_float.view( |
|
bsz, self.num_heads, tgt_len, src_len |
|
).transpose(1, 0) |
|
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)
|
|
|
return attn, attn_weights |
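    # Note on returned shapes (derived from the code above, added for clarity):
    # `attn` is (tgt_len, bsz, embed_dim); when `need_weights` is set,
    # `attn_weights` is (num_heads, bsz, tgt_len, src_len) and is averaged over
    # the head dimension to (bsz, tgt_len, src_len) unless `need_head_weights`
    # is also set.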
|
|
|
@staticmethod |
|
def _append_prev_key_padding_mask( |
|
key_padding_mask: Optional[Tensor], |
|
prev_key_padding_mask: Optional[Tensor], |
|
batch_size: int, |
|
src_len: int, |
|
static_kv: bool, |
|
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (batch_size, seq_len)
        if prev_key_padding_mask is not None and static_kv:
|
new_key_padding_mask = prev_key_padding_mask |
|
elif prev_key_padding_mask is not None and key_padding_mask is not None: |
|
new_key_padding_mask = torch.cat( |
|
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 |
|
) |
|
|
|
|
|
|
|
        # during incremental decoding, as the padding token enters and leaves
        # the frame, there will be a time when prev or current is None
        elif prev_key_padding_mask is not None:
|
if src_len > prev_key_padding_mask.size(1): |
|
filler = torch.zeros( |
|
(batch_size, src_len - prev_key_padding_mask.size(1)), |
|
device=prev_key_padding_mask.device, |
|
) |
|
new_key_padding_mask = torch.cat( |
|
[prev_key_padding_mask.float(), filler.float()], dim=1 |
|
) |
|
else: |
|
new_key_padding_mask = prev_key_padding_mask.float() |
|
elif key_padding_mask is not None: |
|
if src_len > key_padding_mask.size(1): |
|
filler = torch.zeros( |
|
(batch_size, src_len - key_padding_mask.size(1)), |
|
device=key_padding_mask.device, |
|
) |
|
new_key_padding_mask = torch.cat( |
|
[filler.float(), key_padding_mask.float()], dim=1 |
|
) |
|
else: |
|
new_key_padding_mask = key_padding_mask.float() |
|
else: |
|
new_key_padding_mask = prev_key_padding_mask |
|
return new_key_padding_mask |
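    # Worked example (illustrative shapes): with a cached mask `prev` of shape
    # (bsz, 2) and the current step's mask `cur` of shape (bsz, 1), the
    # non-static-kv branch returns torch.cat([prev.float(), cur.float()], dim=1)
    # with shape (bsz, 3); when one side is missing, it is filled with zeros
    # (i.e. "not padding") so the result always covers src_len positions.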
|
|
|
@torch.jit.export |
|
def reorder_incremental_state( |
|
self, |
|
incremental_state: Dict[str, Dict[str, Optional[Tensor]]], |
|
new_order: Tensor, |
|
): |
|
"""Reorder buffered internal state (for incremental generation).""" |
|
input_buffer = self._get_input_buffer(incremental_state) |
|
if input_buffer is not None: |
|
for k in input_buffer.keys(): |
|
input_buffer_k = input_buffer[k] |
|
if input_buffer_k is not None: |
|
if self.encoder_decoder_attention and input_buffer_k.size( |
|
0 |
|
) == new_order.size(0): |
|
break |
|
input_buffer[k] = input_buffer_k.index_select(0, new_order) |
|
incremental_state = self._set_input_buffer(incremental_state, input_buffer) |
|
return incremental_state |
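    # Sketch: during beam search the cached "prev_key"/"prev_value" buffers of
    # shape (bsz, num_heads, seq_len, head_dim) are index_select-ed along dim 0
    # with `new_order`, so surviving hypotheses keep their attention state;
    # encoder-decoder buffers whose batch dimension already matches `new_order`
    # are assumed to be in the right order and are left untouched.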
|
|
|
def _get_input_buffer( |
|
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] |
|
) -> Dict[str, Optional[Tensor]]: |
|
result = self.get_incremental_state(incremental_state, "attn_state") |
|
if result is not None: |
|
return result |
|
else: |
|
empty_result: Dict[str, Optional[Tensor]] = {} |
|
return empty_result |
|
|
|
def _set_input_buffer( |
|
self, |
|
incremental_state: Dict[str, Dict[str, Optional[Tensor]]], |
|
buffer: Dict[str, Optional[Tensor]], |
|
): |
|
return self.set_incremental_state(incremental_state, "attn_state", buffer) |
|
|
|
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): |
|
return attn_weights |
|
|
|
def upgrade_state_dict_named(self, state_dict, name): |
|
prefix = name + "." if name != "" else "" |
|
items_to_add = {} |
|
keys_to_remove = [] |
|
for k in state_dict.keys(): |
|
            if k.endswith(prefix + "in_proj_weight"):
                # in_proj_weight used to be q + k + v with the same dimensions
                dim = int(state_dict[k].shape[0] / 3)
|
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] |
|
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] |
|
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] |
|
|
|
keys_to_remove.append(k) |
|
|
|
k_bias = prefix + "in_proj_bias" |
|
if k_bias in state_dict.keys(): |
|
dim = int(state_dict[k].shape[0] / 3) |
|
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] |
|
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][ |
|
dim : 2 * dim |
|
] |
|
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] |
|
|
|
keys_to_remove.append(prefix + "in_proj_bias") |
|
|
|
for k in keys_to_remove: |
|
del state_dict[k] |
|
|
|
for key, value in items_to_add.items(): |
|
state_dict[key] = value |
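    # Migration example (hedged, shapes illustrative): an old checkpoint key
    # "<name>.in_proj_weight" of shape (3 * embed_dim, embed_dim) is replaced by
    # "<name>.q_proj.weight", "<name>.k_proj.weight" and "<name>.v_proj.weight",
    # each of shape (embed_dim, embed_dim); the fused "in_proj_bias" is split
    # the same way.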
|
|
|
|
|
class ConvFeatureExtractionModel(nn.Module): |
|
def __init__( |
|
self, |
|
conv_layers: List[Tuple[int, int, int]], |
|
dropout: float = 0.0, |
|
mode: str = "default", |
|
conv_bias: bool = False, |
|
): |
|
super().__init__() |
|
|
|
assert mode in {"default", "layer_norm"} |
|
|
|
def block( |
|
n_in, |
|
n_out, |
|
k, |
|
stride, |
|
is_layer_norm=False, |
|
is_group_norm=False, |
|
conv_bias=False, |
|
): |
|
def make_conv(): |
|
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) |
|
nn.init.kaiming_normal_(conv.weight) |
|
return conv |
|
|
|
            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"
|
|
|
if is_layer_norm: |
|
return nn.Sequential( |
|
make_conv(), |
|
nn.Dropout(p=dropout), |
|
nn.Sequential( |
|
TransposeLast(), |
|
                        # `dim` is the enclosing loop variable, equal to n_out
                        # when block() is called
                        Fp32LayerNorm(dim, elementwise_affine=True),
|
TransposeLast(), |
|
), |
|
nn.GELU(), |
|
) |
|
elif is_group_norm: |
|
return nn.Sequential( |
|
make_conv(), |
|
nn.Dropout(p=dropout), |
|
Fp32GroupNorm(dim, dim, affine=True), |
|
nn.GELU(), |
|
) |
|
else: |
|
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) |
|
|
|
in_d = 1 |
|
self.conv_layers = nn.ModuleList() |
|
for i, cl in enumerate(conv_layers): |
|
assert len(cl) == 3, "invalid conv definition: " + str(cl) |
|
(dim, k, stride) = cl |
|
|
|
self.conv_layers.append( |
|
block( |
|
in_d, |
|
dim, |
|
k, |
|
stride, |
|
is_layer_norm=mode == "layer_norm", |
|
is_group_norm=mode == "default" and i == 0, |
|
conv_bias=conv_bias, |
|
) |
|
) |
|
in_d = dim |
|
|
|
def forward(self, x): |
|
|
|
|
|
x = x.unsqueeze(1) |
|
|
|
for conv in self.conv_layers: |
|
x = conv(x) |
|
|
|
return x |
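
# A minimal usage sketch for ConvFeatureExtractionModel, added as documentation
# only and never called in this file. The two-layer `conv_layers` configuration
# below is an illustrative assumption; real checkpoints define their own stack
# (e.g. a wav2vec 2.0-style seven-layer front-end).
def _demo_conv_feature_extraction():
    extractor = ConvFeatureExtractionModel(
        conv_layers=[(512, 10, 5), (512, 3, 2)], dropout=0.0, mode="default"
    )
    wav = torch.randn(2, 16000)  # (batch, raw audio samples)
    feats = extractor(wav)       # -> (batch, 512, num_frames)
    return feats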
|
|
|
|
|
class TransposeLast(nn.Module): |
|
def __init__(self, deconstruct_idx=None): |
|
super().__init__() |
|
self.deconstruct_idx = deconstruct_idx |
|
|
|
def forward(self, x): |
|
if self.deconstruct_idx is not None: |
|
x = x[self.deconstruct_idx] |
|
return x.transpose(-2, -1) |
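    # Example: a (B, T, C) activation becomes (B, C, T) (and vice versa); if
    # `deconstruct_idx` is set, that element of a tuple/list input is selected
    # before transposing.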
|
|
|
|
|
class Fp32GroupNorm(nn.GroupNorm): |
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
|
|
def forward(self, input): |
|
output = F.group_norm( |
|
input.float(), |
|
self.num_groups, |
|
self.weight.float() if self.weight is not None else None, |
|
self.bias.float() if self.bias is not None else None, |
|
self.eps, |
|
) |
|
return output.type_as(input) |
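    # The statistics are always computed in float32 (input, weight and bias are
    # upcast and the result is cast back to the input dtype), which keeps group
    # normalization numerically stable when the rest of the model runs in fp16.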
|
|
|
|
|
class GradMultiply(torch.autograd.Function): |
|
@staticmethod |
|
def forward(ctx, x, scale): |
|
ctx.scale = scale |
|
res = x.new(x) |
|
return res |
|
|
|
@staticmethod |
|
def backward(ctx, grad): |
|
return grad * ctx.scale, None |
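    # Usage sketch: `y = GradMultiply.apply(x, 0.5)` behaves like the identity in
    # the forward pass (it returns a copy of `x`) and multiplies the incoming
    # gradient by 0.5 in the backward pass, e.g. to down-weight the gradients
    # flowing into the convolutional feature extractor.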
|
|
|
|
|
class Rotate3D(nn.Module): |
|
""" |
|
    Each forward call rotates the leading three dimensions once, so repeated
    application cycles through the layouts:
    (T, B, D) --> (B, D, T) --> (D, T, B) --> (T, B, D)
|
""" |
|
def __init__(self): |
|
super().__init__() |
|
|
|
def forward(self, x): |
|
return x.permute(1, 2, 0) |
|
|
|
|
|
class SamePad(nn.Module): |
|
def __init__(self, kernel_size, causal=False): |
|
super().__init__() |
|
if causal: |
|
self.remove = kernel_size - 1 |
|
else: |
|
self.remove = 1 if kernel_size % 2 == 0 else 0 |
|
|
|
def forward(self, x): |
|
if self.remove > 0: |
|
x = x[:, :, : -self.remove] |
|
return x |
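
# Hedged usage sketch for SamePad, added as documentation only and never called
# here: it is typically chained after an nn.Conv1d whose padding is
# kernel_size // 2 (or kernel_size - 1 in the causal case), trimming the extra
# frame(s) so the output length matches the input length. The layer sizes below
# are illustrative assumptions.
def _demo_same_pad():
    conv = nn.Conv1d(8, 8, kernel_size=4, padding=4 // 2)
    trim = SamePad(kernel_size=4)
    x = torch.randn(2, 8, 100)  # (batch, channels, time)
    y = trim(conv(x))           # conv yields 101 frames; SamePad trims back to 100
    return y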
|
|