"""Basic model. Predicts tags for every token""" |
|
from typing import Dict, Optional, List, Any |
|
|
|
import numpy |
|
import torch |
|
import torch.nn.functional as F |
|
from allennlp.data import Vocabulary |
|
from allennlp.models.model import Model |
|
from allennlp.modules import TimeDistributed, TextFieldEmbedder |
|
from allennlp.nn import InitializerApplicator, RegularizerApplicator |
|
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits |
|
from allennlp.training.metrics import CategoricalAccuracy |
|
from overrides import overrides |
|
from torch.nn.modules.linear import Linear |
|
|
|
|
|
@Model.register("seq2labels")
class Seq2Labels(Model):
    """
    This ``Seq2Labels`` model embeds a sequence of text with a ``TextFieldEmbedder``, then
    predicts a pair of tags for each token in the sequence: a correction label and an
    error-detection tag.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    predictor_dropout : ``float``, optional (default=``0.0``)
        Dropout probability applied to the encoded text before the label projection layer.
    labels_namespace : ``str``, optional (default=``"labels"``)
        Vocabulary namespace holding the per-token correction labels.
        Unless you did something unusual, the default value should be what you want.
    detect_namespace : ``str``, optional (default=``"d_tags"``)
        Vocabulary namespace holding the per-token error-detection tags.
    verbose_metrics : ``bool``, optional (default=``False``)
        If true, metrics will be returned per label class in addition
        to the overall statistics.
    label_smoothing : ``float``, optional (default=``0.0``)
        Label smoothing factor applied to the cross-entropy loss over the correction labels.
    confidence : ``float``, optional (default=``0.0``)
        A constant bias added to the predicted probability of the label class with index 0
        in the ``labels`` namespace.
    del_confidence : ``float``, optional (default=``0.0``)
        A constant bias added to the predicted probability of the label class with index 1
        in the ``labels`` namespace.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 predictor_dropout: float = 0.0,
                 labels_namespace: str = "labels",
                 detect_namespace: str = "d_tags",
                 verbose_metrics: bool = False,
                 label_smoothing: float = 0.0,
                 confidence: float = 0.0,
                 del_confidence: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(Seq2Labels, self).__init__(vocab, regularizer)

        self.label_namespaces = [labels_namespace,
                                 detect_namespace]
        self.text_field_embedder = text_field_embedder
        self.num_labels_classes = self.vocab.get_vocab_size(labels_namespace)
        self.num_detect_classes = self.vocab.get_vocab_size(detect_namespace)
        self.label_smoothing = label_smoothing
        self.confidence = confidence
        self.del_conf = del_confidence
        self.incorr_index = self.vocab.get_token_index("INCORRECT",
                                                       namespace=detect_namespace)

        self._verbose_metrics = verbose_metrics
        self.predictor_dropout = TimeDistributed(torch.nn.Dropout(predictor_dropout))

        # Per-token projections from the embedder's output dimension to the label and
        # detection vocabulary sizes (the embedder is expected to expose a token embedder
        # under the key 'bert').
        self.tag_labels_projection_layer = TimeDistributed(
            Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_labels_classes))
        self.tag_detect_projection_layer = TimeDistributed(
            Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_detect_classes))

        self.metrics = {"accuracy": CategoricalAccuracy()}

        initializer(self)

    @overrides
    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                labels: torch.LongTensor = None,
                d_tags: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors. At its most basic, using a ``SingleIdTokenIndexer``, this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        labels : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold correction labels of shape
            ``(batch_size, num_tokens)``.
        d_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold detection tags of shape
            ``(batch_size, num_tokens)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original words in the sentence to be tagged under a 'words' key.

        Returns
        -------
        An output dictionary consisting of:
        logits_labels : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, label_vocab_size)`` representing
            unnormalised log probabilities of the correction labels.
        logits_d_tags : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, detect_vocab_size)`` representing
            unnormalised log probabilities of the detection tags.
        class_probabilities_labels : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, label_vocab_size)`` representing
            a distribution over the correction labels per token.
        class_probabilities_d_tags : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, detect_vocab_size)`` representing
            a distribution over the detection tags per token.
        max_error_probability : torch.FloatTensor
            A tensor of shape ``(batch_size,)`` holding, for each sequence, the highest per-token
            probability of the ``INCORRECT`` detection tag.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised. Only returned when both ``labels`` and ``d_tags``
            are provided.
        """
        encoded_text = self.text_field_embedder(tokens)
        batch_size, sequence_length, _ = encoded_text.size()
        mask = get_text_field_mask(tokens)
        logits_labels = self.tag_labels_projection_layer(self.predictor_dropout(encoded_text))
        logits_d = self.tag_detect_projection_layer(encoded_text)

        class_probabilities_labels = F.softmax(logits_labels, dim=-1).view(
            [batch_size, sequence_length, self.num_labels_classes])
        class_probabilities_d = F.softmax(logits_d, dim=-1).view(
            [batch_size, sequence_length, self.num_detect_classes])
        # Probability of the INCORRECT detection tag for every (masked) token; its maximum over
        # the sequence is reported as "max_error_probability".
        error_probs = class_probabilities_d[:, :, self.incorr_index] * mask
        incorr_prob = torch.max(error_probs, dim=-1)[0]

        # Add the confidence biases to the probabilities of the first two label classes.
        probability_change = [self.confidence, self.del_conf] + [0] * (self.num_labels_classes - 2)
        class_probabilities_labels += torch.FloatTensor(probability_change).repeat(
            (batch_size, sequence_length, 1)).to(class_probabilities_labels.device)
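
        # A small worked illustration of the bias above (the numbers are made up, not from the
        # source): with confidence=0.2 and del_confidence=0.1, a token whose predicted label
        # distribution starts as [0.50, 0.30, 0.20, ...] becomes [0.70, 0.40, 0.20, ...], nudging
        # the argmax in ``decode`` towards the first two classes of the ``labels`` namespace.
        # Which labels sit at indices 0 and 1 depends on the fitted vocabulary. The shifted vector
        # is no longer a normalised distribution, which is acceptable here because, within this
        # module, it is only consumed by the argmax in ``decode``.
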
output_dict = {"logits_labels": logits_labels, |
|
"logits_d_tags": logits_d, |
|
"class_probabilities_labels": class_probabilities_labels, |
|
"class_probabilities_d_tags": class_probabilities_d, |
|
"max_error_probability": incorr_prob} |
|
if labels is not None and d_tags is not None: |
|
loss_labels = sequence_cross_entropy_with_logits(logits_labels, labels, mask, |
|
label_smoothing=self.label_smoothing) |
|
loss_d = sequence_cross_entropy_with_logits(logits_d, d_tags, mask) |
|
for metric in self.metrics.values(): |
|
metric(logits_labels, labels, mask.float()) |
|
metric(logits_d, d_tags, mask.float()) |
|
output_dict["loss"] = loss_labels + loss_d |
|
|
|
if metadata is not None: |
|
output_dict["words"] = [x["words"] for x in metadata] |
|
return output_dict |
|
|
|
    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does a simple position-wise argmax over each token, converts indices to string labels, and
        adds a key with the resulting tags to the dictionary for each label namespace
        (by default ``labels`` and ``d_tags``).
        """
        for label_namespace in self.label_namespaces:
            all_predictions = output_dict[f'class_probabilities_{label_namespace}']
            all_predictions = all_predictions.cpu().data.numpy()
            if all_predictions.ndim == 3:
                predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])]
            else:
                predictions_list = [all_predictions]
            all_tags = []

            for predictions in predictions_list:
                # Argmax over the class dimension, then map each index back to its string tag.
                argmax_indices = numpy.argmax(predictions, axis=-1)
                tags = [self.vocab.get_token_from_index(x, namespace=label_namespace)
                        for x in argmax_indices]
                all_tags.append(tags)
            output_dict[label_namespace] = all_tags
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics_to_return = {metric_name: metric.get_metric(reset) for
                             metric_name, metric in self.metrics.items()}
        return metrics_to_return
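

# A minimal usage sketch (hypothetical variable names, not part of the original module). It assumes
# an AllenNLP ``Vocabulary`` that already contains the "labels" and "d_tags" namespaces, and a
# ``TextFieldEmbedder`` whose token embedder is registered under the key 'bert', as the constructor
# above expects:
#
#     model = Seq2Labels(vocab=vocab,
#                        text_field_embedder=embedder,
#                        confidence=0.2)
#     output = model(tokens_batch)      # no gold labels/d_tags passed, so no "loss" key is returned
#     output = model.decode(output)     # adds per-token string tags under "labels" and "d_tags"
#     print(model.get_metrics())        # {"accuracy": ...}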