| """Tweaked version of corresponding AllenNLP file""" | |
| import logging | |
| from copy import deepcopy | |
| from typing import Dict | |
| import torch | |
| import torch.nn.functional as F | |
| from allennlp.modules.token_embedders.token_embedder import TokenEmbedder | |
| from allennlp.nn import util | |
| from transformers import AutoModel, PreTrainedModel | |
| logger = logging.getLogger(__name__) | |


class PretrainedBertModel:
    """
    In some instances you may want to load the same BERT model twice
    (e.g. to use as a token embedder and also as a pooling layer).
    This factory provides a cache so that you don't actually have to load the model twice.
    """

    _cache: Dict[str, PreTrainedModel] = {}

    @classmethod
    def load(cls, model_name: str, cache_model: bool = True) -> PreTrainedModel:
        if model_name in cls._cache:
            return PretrainedBertModel._cache[model_name]

        model = AutoModel.from_pretrained(model_name)
        if cache_model:
            cls._cache[model_name] = model

        return model
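

# Illustrative usage of the cache factory (a sketch, not part of the original file;
# the model name below is a placeholder). Loading the same name twice returns the
# single cached instance, so the weights are only read from disk once:
#
#   token_bert = PretrainedBertModel.load("bert-base-cased")
#   pooling_bert = PretrainedBertModel.load("bert-base-cased")
#   assert token_bert is pooling_bert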


class BertEmbedder(TokenEmbedder):
    """
    A ``TokenEmbedder`` that produces BERT embeddings for your tokens.
    Should be paired with a ``BertIndexer``, which produces wordpiece ids.

    You most likely want to use ``PretrainedBertEmbedder``
    for one of the named pretrained models, not this base class.

    Parameters
    ----------
    bert_model: ``BertModel``
        The BERT model being wrapped.
    top_layer_only: ``bool``, optional (default = ``False``)
        If ``True``, then only return the top layer instead of applying the scalar mix.
    max_pieces : int, optional (default: 512)
        The BERT embedder uses positional embeddings and so has a corresponding
        maximum length for its input ids. Assuming the inputs are windowed
        and padded appropriately by this length, the embedder will split them into a
        large batch, feed them into BERT, and recombine the output as if it were a
        longer sequence.
    num_start_tokens : int, optional (default: 1)
        The number of starting special tokens input to BERT (usually 1, i.e., [CLS])
    num_end_tokens : int, optional (default: 1)
        The number of ending tokens input to BERT (usually 1, i.e., [SEP])
    scalar_mix_parameters: ``List[float]``, optional, (default = None)
        If not ``None``, use these scalar mix parameters to weight the representations
        produced by different layers. These mixing weights are not updated during
        training.
    """

    def __init__(
        self,
        bert_model: PreTrainedModel,
        top_layer_only: bool = False,
        max_pieces: int = 512,
        num_start_tokens: int = 1,
        num_end_tokens: int = 1,
    ) -> None:
        super().__init__()
        self.bert_model = deepcopy(bert_model)
        self.output_dim = bert_model.config.hidden_size
        self.max_pieces = max_pieces
        self.num_start_tokens = num_start_tokens
        self.num_end_tokens = num_end_tokens
        self._scalar_mix = None

    def set_weights(self, freeze):
        for param in self.bert_model.parameters():
            param.requires_grad = not freeze
        return

    def get_output_dim(self) -> int:
        return self.output_dim

    def forward(
        self,
        input_ids: torch.LongTensor,
        offsets: torch.LongTensor = None,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.
            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        """
        batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
        initial_dims = list(input_ids.shape[:-1])

        # The embedder may receive an input tensor that has a sequence length longer than can
        # be fit. In that case, we should expect the wordpiece indexer to create padded windows
        # of length `self.max_pieces` for us, and have them concatenated into one long sequence.
        # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
        # We can then split the sequence into sub-sequences of that length, and concatenate them
        # along the batch dimension so we effectively have one huge batch of partial sentences.
        # This can then be fed into BERT without any sentence length issues. Keep in mind
        # that the memory consumption can dramatically increase for large batches with extremely
        # long sentences.
        needs_split = full_seq_len > self.max_pieces
        last_window_size = 0
        if needs_split:
            # Split the flattened list by the window size, `max_pieces`
            split_input_ids = list(input_ids.split(self.max_pieces, dim=-1))

            # We want all sequences to be the same length, so pad the last sequence
            last_window_size = split_input_ids[-1].size(-1)
            padding_amount = self.max_pieces - last_window_size
            split_input_ids[-1] = F.pad(split_input_ids[-1], pad=[0, padding_amount], value=0)

            # Now combine the sequences along the batch dimension
            input_ids = torch.cat(split_input_ids, dim=0)
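            # After this concatenation the window dimension is folded into the batch
            # dimension: the leading dim becomes num_windows * batch_size and the last
            # dim is self.max_pieces, so every padded window is run through BERT as an
            # independent sequence below.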

        input_mask = (input_ids != 0).long()

        # input_ids may have extra dimensions, so we reshape down to 2-d
        # before calling the BERT model and then reshape back at the end.
        all_encoder_layers = self.bert_model(
            input_ids=util.combine_initial_dims(input_ids),
            attention_mask=util.combine_initial_dims(input_mask),
        )[0]
        if len(all_encoder_layers[0].shape) == 3:
            all_encoder_layers = torch.stack(all_encoder_layers)
        elif len(all_encoder_layers[0].shape) == 2:
            all_encoder_layers = torch.unsqueeze(all_encoder_layers, dim=0)
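        # Either way, `all_encoder_layers` now carries a leading layer dimension:
        # (num_layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim).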

        if needs_split:
            # First, unpack the output embeddings into one long sequence again
            unpacked_embeddings = torch.split(all_encoder_layers, batch_size, dim=1)
            unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2)

            # Next, select indices of the sequence such that it will result in embeddings representing the original
            # sentence. To capture maximal context, the indices will be the middle part of each embedded window
            # sub-sequence (plus any leftover start and final edge windows), e.g.,
            #  0     1 2    3  4   5    6    7     8     9   10   11   12    13 14  15
            # "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
            # with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
            # and final windows with indices [0, 1] and [14, 15] respectively.

            # Find the stride as half the max pieces, ignoring the special start and end tokens
            # Calculate an offset to extract the centermost embeddings of each window
            stride = (self.max_pieces - self.num_start_tokens - self.num_end_tokens) // 2
            stride_offset = stride // 2 + self.num_start_tokens
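            # With the example above (max_pieces = 8, one [CLS] and one [SEP]) this gives
            # stride = (8 - 1 - 1) // 2 = 3 and stride_offset = 3 // 2 + 1 = 2, so the
            # max-context positions are those with 1 < i % 8 < 5, i.e. i % 8 in {2, 3, 4}.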

            first_window = list(range(stride_offset))

            max_context_windows = [
                i
                for i in range(full_seq_len)
                if stride_offset - 1 < i % self.max_pieces < stride_offset + stride
            ]

            # Look back over what's left, unless it's the whole self.max_pieces window
            if full_seq_len % self.max_pieces == 0:
                lookback = self.max_pieces
            else:
                lookback = full_seq_len % self.max_pieces

            final_window_start = full_seq_len - lookback + stride_offset + stride
            final_window = list(range(final_window_start, full_seq_len))

            select_indices = first_window + max_context_windows + final_window

            initial_dims.append(len(select_indices))

            recombined_embeddings = unpacked_embeddings[:, :, select_indices]
        else:
            recombined_embeddings = all_encoder_layers
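
        # Whether or not the input was split, `recombined_embeddings` is
        # (num_layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim);
        # when it was split, sequence_length equals len(select_indices).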

        # Recombine the outputs of all layers
        # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim)
        # recombined = torch.cat(combined, dim=2)
        input_mask = (recombined_embeddings != 0).long()

        if self._scalar_mix is not None:
            mix = self._scalar_mix(recombined_embeddings, input_mask)
        else:
            mix = recombined_embeddings[-1]

        # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)
        if offsets is None:
            # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
            dims = initial_dims if needs_split else input_ids.size()
            return util.uncombine_initial_dims(mix, dims)
        else:
            # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
            offsets2d = util.combine_initial_dims(offsets)
            # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
            range_vector = util.get_range_vector(
                offsets2d.size(0), device=util.get_device_of(mix)
            ).unsqueeze(1)
            # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
            selected_embeddings = mix[range_vector, offsets2d]

            return util.uncombine_initial_dims(selected_embeddings, offsets.size())
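

# Illustrative sketch of how the embedder can be driven (not part of the original
# file; the model name and tensor sizes below are placeholders):
#
#   bert = PretrainedBertModel.load("bert-base-cased")
#   embedder = BertEmbedder(bert_model=bert, max_pieces=512)
#   wordpiece_ids = torch.randint(1, 1000, (2, 7))   # (batch_size, num_wordpieces), 0 is padding
#   embeddings = embedder(wordpiece_ids)             # (2, 7, bert.config.hidden_size)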


# @TokenEmbedder.register("bert-pretrained")
class PretrainedBertEmbedder(BertEmbedder):
    """
    Parameters
    ----------
    pretrained_model: ``str``
        Either the name of the pretrained model to use (e.g. 'bert-base-uncased'),
        or the path to the .tar.gz file with the model weights.
        If the name is a key in the list of pretrained models at
        https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L41
        the corresponding path will be used; otherwise it will be interpreted as a path or URL.
    requires_grad : ``bool``, optional (default = False)
        If True, compute gradient of BERT parameters for fine tuning.
    top_layer_only: ``bool``, optional (default = ``False``)
        If ``True``, then only return the top layer instead of applying the scalar mix.
    special_tokens_fix: ``int``, optional (default = 0)
        If nonzero, resize the wordpiece embedding matrix to reserve a row for an
        added special token (see ``__init__``).
    scalar_mix_parameters: ``List[float]``, optional, (default = None)
        If not ``None``, use these scalar mix parameters to weight the representations
        produced by different layers. These mixing weights are not updated during
        training.
    """

    def __init__(
        self,
        pretrained_model: str,
        requires_grad: bool = False,
        top_layer_only: bool = False,
        special_tokens_fix: int = 0,
    ) -> None:
        model = PretrainedBertModel.load(pretrained_model)

        for param in model.parameters():
            param.requires_grad = requires_grad

        super().__init__(
            bert_model=model,
            top_layer_only=top_layer_only,
        )

        if special_tokens_fix:
            try:
                vocab_size = self.bert_model.embeddings.word_embeddings.num_embeddings
            except AttributeError:
                # reserve more space
                vocab_size = self.bert_model.word_embedding.num_embeddings + 5
            self.bert_model.resize_token_embeddings(vocab_size + 1)
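

# Illustrative sketch (not part of the original file; the model name is a placeholder).
# Passing special_tokens_fix=1 grows the wordpiece embedding matrix by one row so that
# one extra special-token id has an embedding slot:
#
#   embedder = PretrainedBertEmbedder(
#       pretrained_model="bert-base-cased",
#       requires_grad=False,
#       special_tokens_fix=1,
#   )
#   embedder.set_weights(freeze=True)   # keep the BERT weights frozen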