from transformers import BertModel
import torch
import onnx
import pytorch_lightning as pl
import wandb
from metrics import MyAccuracy
from utils import num_unique_labels
from typing import Dict, Tuple, List, Optional, Union


class MultiTaskBertModel(pl.LightningModule):
"""
Multi-task Bert model for Named Entity Recognition (NER) and Intent Classification
Args:
config (BertConfig): Bert model configuration.
dataset (Dict[str, Union[str, List[str]]]): A dictionary containing keys 'text', 'ner', and 'intent'.
"""
    def __init__(self, config, dataset):
        super().__init__()

        self.num_ner_labels, self.num_intent_labels = num_unique_labels(dataset)

        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.model = BertModel(config=config)
        self.ner_classifier = torch.nn.Linear(config.hidden_size, self.num_ner_labels)
        self.intent_classifier = torch.nn.Linear(config.hidden_size, self.num_intent_labels)

        # log hyperparameters
        self.save_hyperparameters()

        self.accuracy = MyAccuracy()
    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Perform a forward pass through the multi-task BERT model.

        Args:
            input_ids (torch.Tensor): Input token IDs of shape (batch_size, sequence_length).
            attention_mask (Optional[torch.Tensor]): Attention mask for the input tokens.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: NER logits, intent logits.
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        ner_logits = self.ner_classifier(sequence_output)

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        intent_logits = self.intent_classifier(pooled_output)

        return ner_logits, intent_logits
    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        """
        Perform a training step for the multi-task BERT model.

        Args:
            batch: Input batch.
            batch_idx (int): Index of the batch.

        Returns:
            torch.Tensor: Loss value.
        """
        loss, ner_logits, intent_logits, ner_labels, intent_labels = self._common_step(batch, batch_idx)

        accuracy_ner = self.accuracy(ner_logits, ner_labels, self.num_ner_labels)
        accuracy_intent = self.accuracy(intent_logits, intent_labels, self.num_intent_labels)

        self.log_dict({'training_loss': loss, 'ner_accuracy': accuracy_ner, 'intent_accuracy': accuracy_intent},
                      on_step=False, on_epoch=True, prog_bar=True)
        return loss
    def on_validation_epoch_start(self):
        self.validation_step_outputs_ner = []
        self.validation_step_outputs_intent = []
    def validation_step(self, batch, batch_idx: int) -> torch.Tensor:
        """
        Perform a validation step for the multi-task BERT model.

        Args:
            batch: Input batch.
            batch_idx (int): Index of the batch.

        Returns:
            torch.Tensor: Loss value.
        """
        loss, ner_logits, intent_logits, ner_labels, intent_labels = self._common_step(batch, batch_idx)
        # self.log('val_loss', loss)

        accuracy_ner = self.accuracy(ner_logits, ner_labels, self.num_ner_labels)
        accuracy_intent = self.accuracy(intent_logits, intent_labels, self.num_intent_labels)

        self.log_dict({'validation_loss': loss, 'val_ner_accuracy': accuracy_ner, 'val_intent_accuracy': accuracy_intent},
                      on_step=False, on_epoch=True, prog_bar=True)
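        # Keep the raw logits so on_validation_epoch_end can log their
        # distributions as W&B histograms.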
        self.validation_step_outputs_ner.append(ner_logits)
        self.validation_step_outputs_intent.append(intent_logits)

        return loss
    def on_validation_epoch_end(self):
        """
        Perform actions at the end of the validation epoch to track the training process in W&B.
        """
        validation_step_outputs_ner = self.validation_step_outputs_ner
        validation_step_outputs_intent = self.validation_step_outputs_intent
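        # Export the current model to ONNX with a fixed-size dummy input
        # (batch of 1, sequence length 128) and upload it as a W&B artifact.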
        dummy_input = torch.zeros((1, 128), device=self.device, dtype=torch.long)
        model_filename = f"model_{str(self.global_step).zfill(5)}.onnx"
        torch.onnx.export(self, dummy_input, model_filename)

        artifact = wandb.Artifact(name="model.ckpt", type="model")
        artifact.add_file(model_filename)
        self.logger.experiment.log_artifact(artifact)
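        # Log the flattened validation logits as histograms so their
        # distributions can be monitored over the course of training.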
        flattened_logits_ner = torch.flatten(torch.cat(validation_step_outputs_ner))
        flattened_logits_intent = torch.flatten(torch.cat(validation_step_outputs_intent))
        self.logger.experiment.log(
            {"valid/ner_logits": wandb.Histogram(flattened_logits_ner.to('cpu')),
             "valid/intent_logits": wandb.Histogram(flattened_logits_intent.to('cpu')),
             "global_step": self.global_step}
        )
    def _common_step(self, batch, batch_idx):
        """
        Shared logic for training and validation: compute the loss for both the NER and intent heads.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                Combined loss value, NER logits, intent logits, NER labels, intent labels.
        """
        ids = batch['input_ids']
        mask = batch['attention_mask']
        ner_labels = batch['ner_labels']
        intent_labels = batch['intent_labels']

        ner_logits, intent_logits = self.forward(input_ids=ids, attention_mask=mask)
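        # Both heads use cross-entropy. The NER logits are flattened to
        # (batch * seq_len, num_ner_labels) for the token-level loss, and the
        # two task losses are summed with equal weight.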
        criterion = torch.nn.CrossEntropyLoss()
        ner_loss = criterion(ner_logits.view(-1, self.num_ner_labels), ner_labels.view(-1).long())
        intent_loss = criterion(intent_logits.view(-1, self.num_intent_labels), intent_labels.view(-1).long())
        loss = ner_loss + intent_loss

        return loss, ner_logits, intent_logits, ner_labels, intent_labels
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer
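
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API).
# It assumes a dataset dictionary of the form
# {'text': [...], 'ner': [...], 'intent': [...]} as described in the class
# docstring, and dataloaders that yield batches containing 'input_ids',
# 'attention_mask', 'ner_labels' and 'intent_labels'; building those batches
# is outside this module. A WandbLogger is assumed because the model logs
# artifacts and histograms through self.logger.experiment.
#
# if __name__ == "__main__":
#     from transformers import BertConfig
#     from pytorch_lightning.loggers import WandbLogger
#
#     dataset = {"text": ["book a flight to Paris"], "ner": [...], "intent": [...]}
#     model = MultiTaskBertModel(BertConfig(), dataset)
#
#     trainer = pl.Trainer(max_epochs=3, logger=WandbLogger(project="multitask-bert"))
#     trainer.fit(model, train_dataloaders=..., val_dataloaders=...)
# ---------------------------------------------------------------------------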