File size: 6,178 Bytes
78a5823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from transformers import BertModel
import torch
import onnx
import pytorch_lightning as pl
import wandb
from metrics import MyAccuracy
from utils import num_unique_labels
from typing import Dict, Tuple, List, Optional

class MultiTaskBertModel(pl.LightningModule):

    """
    Multi-task Bert model for Named Entity Recognition (NER) and Intent Classification.

    A single Bert encoder feeds two linear heads: a per-token head for NER and a
    per-sequence head (on the pooled output) for intent classification. The two
    cross-entropy losses are summed into one training objective.

    Args:
        config (BertConfig): Bert model configuration.
        dataset (Dict[str, Union[str, List[str]]]): A dictionary containing keys 'text', 'ner', and 'intent'.
            It is passed to ``num_unique_labels`` to size the two classification heads.
    """

    def __init__(self, config, dataset):
        super().__init__()

        self.num_ner_labels, self.num_intent_labels = num_unique_labels(dataset)

        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

        # The encoder is built from the configuration object only.
        self.model = BertModel(config=config)

        self.ner_classifier = torch.nn.Linear(config.hidden_size, self.num_ner_labels)
        self.intent_classifier = torch.nn.Linear(config.hidden_size, self.num_intent_labels)

        # log hyperparameters
        self.save_hyperparameters()

        self.accuracy = MyAccuracy()

        # Per-epoch buffers of validation logits. Created here so that
        # validation_step can be called even if on_validation_epoch_start
        # has not run yet.
        self.validation_step_outputs_ner: List[torch.Tensor] = []
        self.validation_step_outputs_intent: List[torch.Tensor] = []

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:

        """
        Perform a forward pass through Multi-task Bert model.

        Args:
            input_ids (torch.Tensor, torch.shape: (batch, length_of_tokenized_sequences)): Input token IDs.
            attention_mask (Optional[torch.Tensor]): Attention mask for input tokens.

        Returns:
            Tuple[torch.Tensor,torch.Tensor]: NER logits (one score vector per token),
            Intent logits (one score vector per sequence).
        """

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        # Positional indexing (rather than attribute access) keeps this working
        # when the encoder returns a plain tuple.
        sequence_output = self.dropout(outputs[0])
        ner_logits = self.ner_classifier(sequence_output)

        pooled_output = self.dropout(outputs[1])
        intent_logits = self.intent_classifier(pooled_output)

        return ner_logits, intent_logits

    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        """
        Perform a training step for the Multi-task BERT model.

        Args:
            batch: Input batch.
            batch_idx (int): Index of the batch.

        Returns:
            torch.Tensor: Loss value
        """
        loss, ner_logits, intent_logits, ner_labels, intent_labels = self._common_step(batch, batch_idx)
        accuracy_ner = self.accuracy(ner_logits, ner_labels, self.num_ner_labels)
        accuracy_intent = self.accuracy(intent_logits, intent_labels, self.num_intent_labels)
        self.log_dict({'training_loss': loss, 'ner_accuracy': accuracy_ner, 'intent_accuracy': accuracy_intent},
                      on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def on_validation_epoch_start(self):
        """
        Reset the buffers that collect validation logits for the coming epoch.
        """
        self.validation_step_outputs_ner = []
        self.validation_step_outputs_intent = []

    def validation_step(self, batch, batch_idx: int) -> torch.Tensor:
        """
        Perform a validation step for the Multi-task BERT model.

        Args:
            batch: Input batch.
            batch_idx (int): Index of the batch.

        Returns:
            torch.Tensor: Loss value.
        """
        loss, ner_logits, intent_logits, ner_labels, intent_labels = self._common_step(batch, batch_idx)
        accuracy_ner = self.accuracy(ner_logits, ner_labels, self.num_ner_labels)
        accuracy_intent = self.accuracy(intent_logits, intent_labels, self.num_intent_labels)
        self.log_dict({'validation_loss': loss, 'val_ner_accuracy': accuracy_ner, 'val_intent_accuracy': accuracy_intent},
                      on_step=False, on_epoch=True, prog_bar=True)

        # The logits are only needed for histogram logging at epoch end, so keep
        # detached CPU copies instead of holding accelerator memory for the
        # whole validation epoch.
        self.validation_step_outputs_ner.append(ner_logits.detach().cpu())
        self.validation_step_outputs_intent.append(intent_logits.detach().cpu())
        return loss

    def on_validation_epoch_end(self):
        """
        Perform actions at the end of validation epoch to track the training process in WandB.

        Exports the current model to ONNX and logs the file as an artifact, then
        logs histograms of the NER and intent logits collected during the epoch.
        """
        # Export with a fixed-size dummy batch of token ids (no attention mask).
        dummy_input = torch.zeros((1, 128), device=self.device, dtype=torch.long)
        model_filename = f"model_{str(self.global_step).zfill(5)}.onnx"
        torch.onnx.export(self, dummy_input, model_filename)
        artifact = wandb.Artifact(name="model.ckpt", type="model")
        artifact.add_file(model_filename)
        self.logger.experiment.log_artifact(artifact)

        flattened_logits_ner = self._flatten_step_outputs(self.validation_step_outputs_ner)
        flattened_logits_intent = self._flatten_step_outputs(self.validation_step_outputs_intent)

        # Nothing to plot if no validation batch was processed this epoch.
        if flattened_logits_ner is not None and flattened_logits_intent is not None:
            self.logger.experiment.log(
                {"valid/ner_logits": wandb.Histogram(flattened_logits_ner),
                 "valid/intent_logits": wandb.Histogram(flattened_logits_intent),
                 "global_step": self.global_step}
            )

        # Release the buffered logits now rather than at the next epoch start.
        self.validation_step_outputs_ner.clear()
        self.validation_step_outputs_intent.clear()

    @staticmethod
    def _flatten_step_outputs(step_outputs: List[torch.Tensor]) -> Optional[torch.Tensor]:
        """
        Merge per-batch logits into a single 1-D tensor.

        Each batch is flattened before concatenation so that batches whose
        shapes differ (for example a different padded sequence length) can
        still be combined.

        Args:
            step_outputs (List[torch.Tensor]): Logits collected over one epoch.

        Returns:
            Optional[torch.Tensor]: 1-D tensor of all logits, or None if the list is empty.
        """
        if not step_outputs:
            return None
        return torch.cat([logits.flatten() for logits in step_outputs])

    def _common_step(self, batch, batch_idx) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Common steps for both training and validation. Calculate loss for both NER and intent layer.

        Args:
            batch: Mapping with the keys 'input_ids', 'attention_mask', 'ner_labels' and 'intent_labels'.
            batch_idx: Index of the batch (unused here).

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            Combined loss value, NER logits, intent logits, NER labels, intent labels.
        """
        ids = batch['input_ids']
        mask = batch['attention_mask']
        ner_labels = batch['ner_labels']
        intent_labels = batch['intent_labels']

        ner_logits, intent_logits = self.forward(input_ids=ids, attention_mask=mask)

        criterion = torch.nn.CrossEntropyLoss()

        # Collapse all leading dimensions so every token (NER) or sequence
        # (intent) becomes one classification example. reshape gives the same
        # result as view but also accepts non-contiguous tensors.
        ner_loss = criterion(ner_logits.reshape(-1, self.num_ner_labels), ner_labels.reshape(-1).long())
        intent_loss = criterion(intent_logits.reshape(-1, self.num_intent_labels), intent_labels.reshape(-1).long())

        # Both tasks contribute equally to the objective.
        loss = ner_loss + intent_loss
        return loss, ner_logits, intent_logits, ner_labels, intent_labels

    def configure_optimizers(self):
        """
        Create the optimizer used for training.

        Returns:
            torch.optim.Adam: Adam over all model parameters with a learning rate of 1e-5.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer