kowalsky commited on
Commit
78a5823
·
verified ·
1 Parent(s): 8436c02

Upload 18 files

Browse files
config.json ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "return_dict": true,
3
+ "output_hidden_states": false,
4
+ "output_attentions": false,
5
+ "torchscript": false,
6
+ "torch_dtype": null,
7
+ "use_bfloat16": false,
8
+ "tf_legacy_loss": false,
9
+ "pruned_heads": {},
10
+ "tie_word_embeddings": true,
11
+ "is_encoder_decoder": false,
12
+ "is_decoder": false,
13
+ "cross_attention_hidden_size": null,
14
+ "add_cross_attention": false,
15
+ "tie_encoder_decoder": false,
16
+ "max_length": 20,
17
+ "min_length": 0,
18
+ "do_sample": false,
19
+ "early_stopping": false,
20
+ "num_beams": 1,
21
+ "num_beam_groups": 1,
22
+ "diversity_penalty": 0.0,
23
+ "temperature": 1.0,
24
+ "top_k": 50,
25
+ "top_p": 1.0,
26
+ "typical_p": 1.0,
27
+ "repetition_penalty": 1.0,
28
+ "length_penalty": 1.0,
29
+ "no_repeat_ngram_size": 0,
30
+ "encoder_no_repeat_ngram_size": 0,
31
+ "bad_words_ids": null,
32
+ "num_return_sequences": 1,
33
+ "chunk_size_feed_forward": 0,
34
+ "output_scores": false,
35
+ "return_dict_in_generate": false,
36
+ "forced_bos_token_id": null,
37
+ "forced_eos_token_id": null,
38
+ "remove_invalid_values": false,
39
+ "exponential_decay_length_penalty": null,
40
+ "suppress_tokens": null,
41
+ "begin_suppress_tokens": null,
42
+ "architectures": [
43
+ "BertForMaskedLM"
44
+ ],
45
+ "finetuning_task": null,
46
+ "id2label": {
47
+ "0": "LABEL_0",
48
+ "1": "LABEL_1"
49
+ },
50
+ "label2id": {
51
+ "LABEL_0": 0,
52
+ "LABEL_1": 1
53
+ },
54
+ "tokenizer_class": null,
55
+ "prefix": null,
56
+ "bos_token_id": null,
57
+ "pad_token_id": 0,
58
+ "eos_token_id": null,
59
+ "sep_token_id": null,
60
+ "decoder_start_token_id": null,
61
+ "task_specific_params": null,
62
+ "problem_type": null,
63
+ "_name_or_path": "",
64
+ "transformers_version": "4.35.2",
65
+ "gradient_checkpointing": false,
66
+ "model_type": "bert",
67
+ "vocab_size": 30522,
68
+ "hidden_size": 768,
69
+ "num_hidden_layers": 12,
70
+ "num_attention_heads": 12,
71
+ "hidden_act": "gelu",
72
+ "intermediate_size": 3072,
73
+ "hidden_dropout_prob": 0.1,
74
+ "attention_probs_dropout_prob": 0.1,
75
+ "max_position_embeddings": 512,
76
+ "type_vocab_size": 2,
77
+ "initializer_range": 0.02,
78
+ "layer_norm_eps": 1e-12,
79
+ "position_embedding_type": "absolute",
80
+ "use_cache": true,
81
+ "classifier_dropout": null
82
+ }
data_loader.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import Dict, List, Union
4
+ import sys
5
+
6
+ project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
7
+ sys.path.append(project_dir)
8
+
9
+ from utils import structure_data
10
+
11
+
12
+ def load_dataset(dataset_name: str) -> Dict[str, Union[str, List[str]]]:
13
+ """
14
+ Load training dataset or validation dataset.
15
+
16
+ Args:
17
+ dataset_name (str): The name of the dataset. Should be either 'training_dataset' or 'validation_dataset'.
18
+
19
+ Returns:
20
+ dataset (Dict[str, Union[str. List[str]]]): A dictionary representing the
21
+ loaded dataset with keys 'text', 'ner', and 'intent'.
22
+
23
+ Raises:
24
+ ValueError: If the provided dataset_name is not one of the valid_names.
25
+ FileNotFoundError: If the dataset file is not found in the specified path.
26
+ """
27
+
28
+ valid_names = ["training_dataset", "validation_dataset"]
29
+
30
+ if dataset_name not in valid_names:
31
+ raise ValueError(f"Invalid dataset name. Expected one of {valid_names}, got {dataset_name}")
32
+
33
+ path = f"{dataset_name}.json"
34
+
35
+ if not os.path.exists(path):
36
+ raise FileNotFoundError(f"Dataset file not found at {path}")
37
+
38
+ with open(path, 'r') as f:
39
+ dataset = json.load(f)
40
+
41
+ dataset = structure_data(dataset)
42
+
43
+ return dataset
data_module.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ from torch.utils.data import DataLoader
3
+ import torch
4
+ from typing import Dict
5
+
6
+
7
+ class DataModule(pl.LightningDataModule):
8
+ """
9
+ Lightning DataModule for handling training and validation datasets.
10
+
11
+ Args:
12
+ training_set (torch.utils.data.Dataset): Training dataset.
13
+ validation_set (torch.utils.data.Dataset): Validation dataset.
14
+
15
+ Attributes:
16
+ training_set (torch.utils.data.Dataset): Training dataset.
17
+ validation_set (torch.utils.data.Dataset): Validation dataset.
18
+ train_ds (torch.utils.data.Dataset): Alias for the training dataset during setup.
19
+ val_ds (torch.utils.data.Dataset): Alias for the validation dataset during setup.
20
+
21
+ Methods:
22
+ setup(self, stage: Optional[str] = None):
23
+ Setup method to load and preprocess datasets.
24
+
25
+ train_dataloader(self) -> DataLoader:
26
+ Return a DataLoader for the training dataset.
27
+
28
+ val_dataloader(self) -> DataLoader:
29
+ Return a DataLoader for the validation dataset.
30
+ """
31
+ def __init__(self, training_set, validation_set):
32
+ super().__init__()
33
+ self.training_set = training_set
34
+ self.validation_set = validation_set
35
+
36
+ def setup(self, stage: str):
37
+ self.train_ds = self.training_set
38
+ self.val_ds = self.validation_set
39
+
40
+ def train_dataloader(self):
41
+ return DataLoader(self.train_ds, batch_size=1, shuffle=True)
42
+
43
+ def val_dataloader(self):
44
+ return DataLoader(self.val_ds, batch_size=1, shuffle=False)
data_processing.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader, Dataset
2
+ import torch
3
+ from transformers import BertTokenizerFast, BertModel
4
+ from transformers import BertConfig, BertPreTrainedModel
5
+ import numpy as np
6
+ from typing import Dict, List, Union, Tuple
7
+ from utils import ner_labels_to_ids, intent_labels_to_ids, structure_data
8
+
9
+ class tokenized_dataset(Dataset):
10
+ """
11
+ A Pytorch Dataset for tokenizing and encoding text data for a BERT-based model.
12
+
13
+ Args:
14
+ dataset (dict): A dictionary containing 'text', 'ner', and 'intent' keys.
15
+ tokenizer (BertTokenizerFast): A tokenizer for processing text input.
16
+ max_len (int, optionl): Maximum length of tokenized sequences (default: 128).
17
+
18
+ Attributes:
19
+ len (int): Number of samples in the dataset.
20
+
21
+ Methods:
22
+ __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
23
+ Retrieve and preprocess a single sample from the dataset.
24
+
25
+ __len__(self) -> int:
26
+ Get the total number of samples int the dataset.
27
+
28
+ Returns:
29
+ Dict[str, torch.Tensor]: A dictionary containing tokenized and encoded text, NER and intent labels.
30
+ """
31
+ def __init__(self, dataset: Dict[str, List[str]], tokenizer: BertTokenizerFast, max_len: int = 128):
32
+ self.len = len(dataset['text'])
33
+ self.ner_labels_to_ids = ner_labels_to_ids()
34
+ self.intent_labels_to_ids = intent_labels_to_ids()
35
+ self.text = dataset['text']
36
+ self.intent = dataset['intent']
37
+ self.ner = dataset['entities']
38
+ self.tokenizer = tokenizer()
39
+ self.max_len = max_len
40
+
41
+ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
42
+ # step 1: get the sentence, ner label, and intent_label
43
+ sentence = self.text[index].strip()
44
+ intent_label = self.intent[index].strip()
45
+ ner_labels = self.ner[index]
46
+
47
+ # step 2: use tokenizer to encode a sentence (includes padding/truncation up to max length)
48
+ # BertTokenizerFast provides a handy "return_offsets_mapping" which highlights where each token starts and ends
49
+ encoding = self.tokenizer(
50
+ sentence,
51
+ return_offsets_mapping=True,
52
+ padding='max_length',
53
+ truncation=True,
54
+ max_length=self.max_len
55
+ )
56
+
57
+ # step 3: create ner token labels only for first word pieces of each tokenized word
58
+ tokenized_ner_labels = [self.ner_labels_to_ids[label] for label in ner_labels]
59
+ # create an empty array of -100 of length max_length
60
+ encoded_ner_labels = np.ones(len(encoding['offset_mapping']), dtype=int) * -100
61
+
62
+ # set only labels whose first offset position is 0 and the second is not 0
63
+ i = 0
64
+ prev = -1
65
+ for idx, mapping in enumerate(encoding['offset_mapping']):
66
+ if mapping[0] == mapping[1] == 0:
67
+ continue
68
+ if mapping[0] != prev:
69
+ # overwrite label
70
+ encoded_ner_labels[idx] = tokenized_ner_labels[i]
71
+ prev = mapping[1]
72
+ i += 1
73
+ else:
74
+ prev = mapping[1]
75
+
76
+ # create intent token labels
77
+ tokenized_intent_label = self.intent_labels_to_ids[intent_label]
78
+
79
+ # step 4: turn everything into Pytorch tensors
80
+ item = {key: torch.as_tensor(val) for key, val in encoding.items()}
81
+ item['ner_labels'] = torch.as_tensor(encoded_ner_labels)
82
+ item['intent_labels'] = torch.as_tensor(tokenized_intent_label)
83
+
84
+ return item
85
+
86
+ def __len__(self) -> int:
87
+ return self.len
88
+
inference.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import sys
4
+
5
+ project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
6
+ sys.path.append(project_dir)
7
+
8
+ from model import MultiTaskBertModel
9
+ from data_loader import load_dataset
10
+ from utils import bert_config, tokenizer, intent_ids_to_labels, intent_labels_to_ids
11
+
12
+
13
+ def load_model(model_path):
14
+ """
15
+ Load the pre-trained model weights from the specified path.
16
+
17
+ Args:
18
+ model_path (str): Path to the pre-trained model weights.
19
+
20
+ Returns:
21
+ model (MultiTaskBertModel): Loaded model with pre-trained weights.
22
+ """
23
+ # Initialize model with configuration and dataset information
24
+ config = bert_config()
25
+ dataset = load_dataset("training_dataset")
26
+ model = MultiTaskBertModel(config, dataset)
27
+
28
+ # Load the model weights from the specified path
29
+ model.load_state_dict(torch.load(model_path))
30
+
31
+ model.eval()
32
+
33
+ return model
34
+
35
+ def preprocess_input(input_data):
36
+ """
37
+ Preprocess the input text data for inference.
38
+
39
+ Args:
40
+ input_data (str): Input text data to be preprocessed.
41
+
42
+ Returns:
43
+ input_ids (torch.Tensor): Tensor of input IDs after tokenization.
44
+ attention_mask (torch.Tensor): Tensor of attention mask indicating input tokens.
45
+ offset_mapping (torch.Tensor): Tensor of offset mappings for input tokens.
46
+ """
47
+ # Tokenize the input text and get offset mappings
48
+ tok = tokenizer()
49
+ preprocessed_input = tok(input_data,
50
+ return_offsets_mapping=True,
51
+ padding='max_length',
52
+ truncation=True,
53
+ max_length=128)
54
+
55
+ # Convert preprocessed inputs to PyTorch tensors
56
+ input_ids = torch.tensor([preprocessed_input['input_ids']])
57
+ attention_mask = torch.tensor([preprocessed_input['attention_mask']])
58
+ offset_mapping = torch.tensor(preprocessed_input['offset_mapping'])
59
+ return input_ids, attention_mask, offset_mapping
60
+
61
+ def perform_inference(model, input_ids, attention_mask):
62
+
63
+ with torch.no_grad():
64
+
65
+ ner_logits, intent_logits = model.forward(input_ids, attention_mask)
66
+
67
+ return ner_logits, intent_logits
68
+
69
+ def align_ner_predictions_with_input(predictions, offset_mapping, input_text):
70
+ aligned_predictions = []
71
+ current_word_idx = 0
72
+
73
+ # Iterate through each prediction and its offset mapping
74
+ for prediction, (start, end) in zip(predictions, offset_mapping):
75
+ if start == end:
76
+ continue
77
+ # Find the corresponding word in the input text
78
+ word = input_text[start:end]
79
+
80
+ # Check if the current word is a special token or part of padding
81
+ if not word.strip():
82
+ continue
83
+
84
+ # Assign the prediction to the word
85
+ aligned_predictions.append((word, prediction))
86
+
87
+ return aligned_predictions
88
+
89
+ def convert_intent_to_label(intent_logit):
90
+ labels = intent_labels_to_ids()
91
+ intent_labels = intent_ids_to_labels(labels)
92
+ return intent_labels[int(intent_logit)]
93
+
94
+
95
+ def main(input_data):
96
+ """
97
+ Main function to perform inference using the pre-trained model.
98
+ """
99
+ # Load the pre-trained model
100
+ model_path = "pytorch_model.bin"
101
+ model = load_model(model_path)
102
+
103
+ # Preprocess the input text
104
+ input_ids, attention_mask, offset_mapping = preprocess_input(input_data)
105
+
106
+ # Perform inference using the pre-trained model
107
+ ner_logits, intent_logits = perform_inference(model, input_ids, attention_mask)
108
+
109
+ # Post-process the model outputs and print the results
110
+ ner_logits = torch.argmax(ner_logits.view(-1, 9), dim=1)
111
+ intent_logits = torch.argmax(intent_logits)
112
+
113
+ ner_logits = align_ner_predictions_with_input(ner_logits, offset_mapping, input_data)
114
+ intent_label = convert_intent_to_label(intent_logits)
115
+
116
+ return ner_logits, intent_label
117
+
118
+
119
+ if __name__ == "__main__":
120
+
121
+ input_data = "I want to schedule a meeting for the 15th of this month at 2:30 PM."
122
+ ner_logits, intent_label = main(input_data)
123
+
124
+ print(f"Ner logits: {ner_logits}")
125
+ print(f"Intent logits: {intent_label}")
intent_labels.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"'Schedule Appointment'": 0, "'Schedule Meeting'": 1, "'Set Alarm'": 2, "'Set Reminder'": 3, "'Set Timer'": 4}
metrics.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchmetrics import Metric
3
+
4
+ class MyAccuracy(Metric):
5
+ """
6
+ Accuracy metric costomized for handling sequences with padding.
7
+
8
+ Methods:
9
+ update(self, logits, labels, num_labels): Update the accuracy based on
10
+ model predictions and ground truth labels.
11
+
12
+ compute(self): Compute the accuracy.
13
+
14
+ Attributes:
15
+ total (torch.Tensor): Total number of non-padding elements.
16
+ correct (torch.Tensor): Number of correctly predicted non-padding elements.
17
+ """
18
+ def __init__(self):
19
+ super().__init__()
20
+ self.add_state('total', default=torch.tensor(0), dist_reduce_fx='sum')
21
+ self.add_state('correct', default=torch.tensor(0), dist_reduce_fx='sum')
22
+
23
+ def update(self, logits: torch.Tensor, labels: torch.Tensor, num_labels: int) -> None:
24
+ """
25
+ Args:
26
+ logits (torch.Tensor): Model predictions.
27
+ labels (torch.Tensor): Ground truth labels.
28
+ num_labels (int): Number of unique labels.
29
+ """
30
+ flattened_targets = labels.view(-1) # shape (batch_size, sequence_len)
31
+ active_logits = logits.view(-1, num_labels) # shape (batch_size * sequence_len, num_labels)
32
+ flattened_predictions = torch.argmax(active_logits, axis=1) # shape (batch_size * sequence_len)
33
+
34
+ # compute accuracy only at active labels
35
+ active_accuracy = labels.view(-1) != -100 # shape (batch_size, sequnce_len)
36
+ ac_labels = torch.masked_select(flattened_targets, active_accuracy)
37
+ predictions = torch.masked_select(flattened_predictions, active_accuracy)
38
+
39
+ self.correct += torch.sum(ac_labels == predictions)
40
+ self.total += torch.numel(ac_labels)
41
+
42
+ def compute(self) -> torch.Tensor:
43
+ """
44
+ Calculate the accuracy.
45
+ """
46
+ return self.correct.float() / self.total.float()
model.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertModel
2
+ import torch
3
+ import onnx
4
+ import pytorch_lightning as pl
5
+ import wandb
6
+ from metrics import MyAccuracy
7
+ from utils import num_unique_labels
8
+ from typing import Dict, Tuple, List, Optional
9
+
10
+ class MultiTaskBertModel(pl.LightningModule):
11
+
12
+ """
13
+ Multi-task Bert model for Named Entity Recognition (NER) and Intent Classification
14
+
15
+ Args:
16
+ config (BertConfig): Bert model configuration.
17
+ dataset (Dict[str, Union[str, List[str]]]): A dictionary containing keys 'text', 'ner', and 'intent'.
18
+ """
19
+
20
+ def __init__(self, config, dataset):
21
+ super().__init__()
22
+
23
+ self.num_ner_labels, self.num_intent_labels = num_unique_labels(dataset)
24
+
25
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
26
+
27
+ self.model = BertModel(config=config)
28
+
29
+ self.ner_classifier = torch.nn.Linear(config.hidden_size, self.num_ner_labels)
30
+ self.intent_classifier = torch.nn.Linear(config.hidden_size, self.num_intent_labels)
31
+
32
+ # log hyperparameters
33
+ self.save_hyperparameters()
34
+
35
+ self.accuracy = MyAccuracy()
36
+
37
+ def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
38
+
39
+ """
40
+ Perform a forward pass through Multi-task Bert model.
41
+
42
+ Args:
43
+ input_ids (torch.Tensor, torch.shape: (batch, length_of_tokenized_sequences)): Input token IDs.
44
+ attention_mask (Optional[torch.Tensor]): Attention mask for input tokens.
45
+
46
+ Returns:
47
+ Tuple[torch.Tensor,torch.Tensor]: NER logits, Intent logits.
48
+ """
49
+
50
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
51
+
52
+ sequence_output = outputs[0]
53
+ sequence_output = self.dropout(sequence_output)
54
+ ner_logits = self.ner_classifier(sequence_output)
55
+
56
+ pooled_output = outputs[1]
57
+ pooled_output = self.dropout(pooled_output)
58
+ intent_logits = self.intent_classifier(pooled_output)
59
+
60
+ return ner_logits, intent_logits
61
+
62
+ def training_step(self: pl.LightningModule, batch, batch_idx: int) -> torch.Tensor:
63
+ """
64
+ Perform a training step for the Multi-task BERT model.
65
+
66
+ Args:
67
+ batch: Input batch.
68
+ batch_idx (int): Index of the batch.
69
+
70
+ Returns:
71
+ torch.Tensor: Loss value
72
+ """
73
+ loss, ner_logits, intent_logits, ner_labels, intent_labels = self._common_step(batch, batch_idx)
74
+ accuracy_ner = self.accuracy(ner_logits, ner_labels, self.num_ner_labels)
75
+ accuracy_intent = self.accuracy(intent_logits, intent_labels, self.num_intent_labels)
76
+ self.log_dict({'training_loss': loss, 'ner_accuracy': accuracy_ner, 'intent_accuracy': accuracy_intent},
77
+ on_step=False, on_epoch=True, prog_bar=True)
78
+ return loss
79
+
80
+ def on_validation_epoch_start(self):
81
+ self.validation_step_outputs_ner = []
82
+ self.validation_step_outputs_intent = []
83
+
84
+ def validation_step(self, batch, batch_idx: int) -> torch.Tensor:
85
+ """
86
+ Perform a validation step for the Multi-task BERT model.
87
+
88
+ Args:
89
+ batch: Input batch.
90
+ batch_idx (int): Index of the batch.
91
+
92
+ Returns:
93
+ torch.Tensor: Loss value.
94
+ """
95
+ loss, ner_logits, intent_logits, ner_labels, intent_labels = self._common_step(batch, batch_idx)
96
+ # self.log('val_loss', loss)
97
+ accuracy_ner = self.accuracy(ner_logits, ner_labels, self.num_ner_labels)
98
+ accuracy_intent = self.accuracy(intent_logits, intent_labels, self.num_intent_labels)
99
+ self.log_dict({'validation_loss': loss, 'val_ner_accuracy': accuracy_ner, 'val_intent_accuracy': accuracy_intent},
100
+ on_step=False, on_epoch=True, prog_bar=True)
101
+
102
+ self.validation_step_outputs_ner.append(ner_logits)
103
+ self.validation_step_outputs_intent.append(intent_logits)
104
+ return loss
105
+
106
+ def on_validation_epoch_end(self):
107
+ """
108
+ Perform actions at the end of validation epoch to track the training process in WandB.
109
+ """
110
+ validation_step_outputs_ner = self.validation_step_outputs_ner
111
+ validation_step_outputs_intent = self.validation_step_outputs_intent
112
+
113
+ dummy_input = torch.zeros((1, 128), device=self.device, dtype=torch.long)
114
+ model_filename = f"model_{str(self.global_step).zfill(5)}.onnx"
115
+ torch.onnx.export(self, dummy_input, model_filename)
116
+ artifact = wandb.Artifact(name="model.ckpt", type="model")
117
+ artifact.add_file(model_filename)
118
+ self.logger.experiment.log_artifact(artifact)
119
+
120
+ flattened_logits_ner = torch.flatten(torch.cat(validation_step_outputs_ner))
121
+ flattened_logits_intent = torch.flatten(torch.cat(validation_step_outputs_intent))
122
+ self.logger.experiment.log(
123
+ {"valid/ner_logits": wandb.Histogram(flattened_logits_ner.to('cpu')),
124
+ "valid/intent_logits": wandb.Histogram(flattened_logits_intent.to('cpu')),
125
+ "global_step": self.global_step}
126
+ )
127
+
128
+ def _common_step(self, batch, batch_idx):
129
+ """
130
+ Common steps for both training and validation. Calculate loss for both NER and intent layer.
131
+
132
+ Returns:
133
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
134
+ Combiner loss value, NER logits, intent logits, NER labels, intent labels.
135
+ """
136
+ ids = batch['input_ids']
137
+ mask = batch['attention_mask']
138
+ ner_labels = batch['ner_labels']
139
+ intent_labels = batch['intent_labels']
140
+
141
+ ner_logits, intent_logits = self.forward(input_ids=ids, attention_mask=mask)
142
+
143
+ criterion = torch.nn.CrossEntropyLoss()
144
+
145
+ ner_loss = criterion(ner_logits.view(-1, self.num_ner_labels), ner_labels.view(-1).long())
146
+ intent_loss = criterion(intent_logits.view(-1, self.num_intent_labels), intent_labels.view(-1).long())
147
+
148
+ loss = ner_loss + intent_loss
149
+ return loss, ner_logits, intent_logits, ner_labels, intent_labels
150
+
151
+ def configure_optimizers(self):
152
+ optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
153
+ return optimizer
ner_labels.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"O": 0, "B-DATE": 1, "I-DATE": 2, "B-TIME": 3, "I-TIME": 4, "B-TASK": 5, "I-TASK": 6, "B-DUR": 7, "I-DUR": 8}
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbd376379912824c97a8c347e155db8e458526183f8939c2c6b2b780ea8698cc
3
+ size 438053110
requirements.txt ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp==3.9.1
2
+ aiosignal==1.3.1
3
+ anyio==3.7.1
4
+ appdirs==1.4.4
5
+ argon2-cffi==21.3.0
6
+ argon2-cffi-bindings==21.2.0
7
+ arrow==1.2.3
8
+ asttokens==2.2.1
9
+ async-lru==2.0.4
10
+ attrs==23.1.0
11
+ Babel==2.12.1
12
+ backcall==0.2.0
13
+ beautifulsoup4==4.12.2
14
+ bleach==6.0.0
15
+ blinker==1.6.2
16
+ certifi==2023.5.7
17
+ cffi==1.15.1
18
+ charset-normalizer==3.2.0
19
+ click==8.1.7
20
+ colorama==0.4.6
21
+ coloredlogs==15.0.1
22
+ comm==0.1.4
23
+ contourpy==1.1.0
24
+ cvxopt==1.3.2
25
+ cvxpy==1.3.2
26
+ cycler==0.11.0
27
+ debugpy==1.6.7
28
+ decorator==5.1.1
29
+ defusedxml==0.7.1
30
+ docker-pycreds==0.4.0
31
+ ecos==2.0.12
32
+ executing==1.2.0
33
+ fastjsonschema==2.18.0
34
+ filelock==3.12.4
35
+ Flask==2.3.3
36
+ flatbuffers==23.5.26
37
+ fonttools==4.42.0
38
+ fqdn==1.5.1
39
+ frozenlist==1.4.1
40
+ fsspec==2023.9.2
41
+ gitdb==4.0.11
42
+ GitPython==3.1.41
43
+ huggingface-hub==0.17.2
44
+ humanfriendly==10.0
45
+ hypothesis==6.97.1
46
+ idna==3.4
47
+ iniconfig==2.0.0
48
+ ipykernel==6.25.0
49
+ ipython==8.14.0
50
+ isoduration==20.11.0
51
+ itsdangerous==2.1.2
52
+ jedi==0.19.0
53
+ Jinja2==3.1.2
54
+ joblib==1.3.1
55
+ json5==0.9.14
56
+ jsonpointer==2.4
57
+ jsonschema==4.18.6
58
+ jsonschema-specifications==2023.7.1
59
+ jupyter-events==0.7.0
60
+ jupyter-lsp==2.2.0
61
+ jupyter_client==8.3.0
62
+ jupyter_core==5.3.1
63
+ jupyter_server==2.7.0
64
+ jupyter_server_terminals==0.4.4
65
+ jupyterlab==4.0.4
66
+ jupyterlab-pygments==0.2.2
67
+ jupyterlab_server==2.24.0
68
+ kiwisolver==1.4.4
69
+ lightning==2.1.3
70
+ lightning-utilities==0.10.1
71
+ lxml==4.9.3
72
+ MarkupSafe==2.1.3
73
+ matplotlib==3.7.2
74
+ matplotlib-inline==0.1.6
75
+ mistune==3.0.1
76
+ mpmath==1.3.0
77
+ multidict==6.0.4
78
+ nbclient==0.8.0
79
+ nbconvert==7.7.3
80
+ nbformat==5.9.2
81
+ nest-asyncio==1.5.7
82
+ networkx==3.2.1
83
+ nnfs==0.5.1
84
+ notebook_shim==0.2.3
85
+ numpy==1.25.1
86
+ onnx==1.15.0
87
+ onnxruntime==1.17.0
88
+ osqp==0.6.3
89
+ overrides==7.4.0
90
+ packaging==23.1
91
+ pandas==2.0.3
92
+ pandocfilters==1.5.0
93
+ parso==0.8.3
94
+ pickleshare==0.7.5
95
+ Pillow==10.0.0
96
+ platformdirs==3.10.0
97
+ pluggy==1.4.0
98
+ praw==7.7.1
99
+ prawcore==2.4.0
100
+ prometheus-client==0.17.1
101
+ prompt-toolkit==3.0.39
102
+ protobuf==4.25.2
103
+ psutil==5.9.5
104
+ pure-eval==0.2.2
105
+ pyarrow==14.0.0
106
+ pycparser==2.21
107
+ pygame==2.5.0
108
+ Pygments==2.16.1
109
+ pyparsing==3.0.9
110
+ PyPDF2==3.0.1
111
+ pyreadline3==3.4.1
112
+ pytest==8.0.0
113
+ python-dateutil==2.8.2
114
+ python-docx==1.1.0
115
+ python-json-logger==2.0.7
116
+ pytorch-lightning==2.1.3
117
+ pytz==2023.3
118
+ pywin32==306
119
+ pywinpty==2.0.11
120
+ PyYAML==6.0.1
121
+ pyzmq==25.1.0
122
+ qdldl==0.1.7.post0
123
+ referencing==0.30.2
124
+ regex==2023.8.8
125
+ requests==2.31.0
126
+ rfc3339-validator==0.1.4
127
+ rfc3986-validator==0.1.1
128
+ rpds-py==0.9.2
129
+ safetensors==0.3.3
130
+ scikit-learn==1.3.0
131
+ scipy==1.11.1
132
+ scs==3.2.3
133
+ seaborn==0.12.2
134
+ Send2Trash==1.8.2
135
+ sentry-sdk==1.39.2
136
+ setproctitle==1.3.3
137
+ six==1.16.0
138
+ smmap==5.0.1
139
+ sniffio==1.3.0
140
+ sortedcontainers==2.4.0
141
+ soupsieve==2.4.1
142
+ stack-data==0.6.2
143
+ sympy==1.12
144
+ terminado==0.17.1
145
+ threadpoolctl==3.2.0
146
+ tinycss2==1.2.1
147
+ tokenizers==0.13.3
148
+ torch==2.1.2
149
+ torchaudio==2.1.2
150
+ torchmetrics==1.3.0.post0
151
+ torchvision==0.16.2
152
+ tornado==6.3.2
153
+ tqdm==4.66.1
154
+ traitlets==5.9.0
155
+ transformers==4.33.2
156
+ typing_extensions==4.8.0
157
+ tzdata==2023.3
158
+ update-checker==0.18.0
159
+ uri-template==1.3.0
160
+ urllib3==2.0.4
161
+ wandb==0.16.2
162
+ wcwidth==0.2.6
163
+ webcolors==1.13
164
+ webencodings==0.5.1
165
+ websocket-client==1.6.1
166
+ Werkzeug==2.3.7
167
+ windows-curses==2.3.1
168
+ yarl==1.9.4
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ ["tokenizer/tokenizer_config.json", "tokenizer/special_tokens_map.json", "tokenizer/vocab.txt", "tokenizer/added_tokens.json", "tokenizer/tokenizer.json"]
training_dataset.json ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "text": "Set a timer for 10 minutes.",
4
+ "intent": "'Set Timer'",
5
+ "entities": "O O O O B-DUR I-DUR"
6
+ },
7
+ {
8
+ "text": "Remind me about the meeting at 3 PM tomorrow.",
9
+ "intent": "'Set Reminder'",
10
+ "entities": "O O O O O O B-TIME I-TIME B-DATE"
11
+ },
12
+ {
13
+ "text": "Schedule an appointment for next Friday at 9 AM.",
14
+ "intent": "'Schedule Appointment'",
15
+ "entities": "O O O O B-DATE I-DATE O B-TIME I-TIME"
16
+ },
17
+ {
18
+ "text": "Can you set a reminder for my doctor's appointment on Monday?",
19
+ "intent": "'Set Reminder'",
20
+ "entities": "O O O O O O O O O O B-DATE"
21
+ },
22
+ {
23
+ "text": "I want to schedule a meeting for the 15th of this month at 2:30 PM.",
24
+ "intent": "'Schedule Meeting'",
25
+ "entities": "O O O O O O O O B-DATE I-DATE I-DATE I-DATE O B-TIME I-TIME I-TIME I-TIME"
26
+ },
27
+ {
28
+ "text": "Set an alarm for 7 AM.",
29
+ "intent": "'Set Alarm'",
30
+ "entities": "O O O O B-TIME I-TIME"
31
+ },
32
+ {
33
+ "text": "Remind me to call John in 30 minutes.",
34
+ "intent": "'Set Reminder'",
35
+ "entities": "O O O B-TASK I-TASK O B-DUR I-DUR"
36
+ },
37
+ {
38
+ "text": "\"Schedule a meeting for next Wednesday afternoon.\"",
39
+ "intent": "'Schedule Meeting'",
40
+ "entities": "O O O O B-DATE I-DATE B-TIME"
41
+ },
42
+ {
43
+ "text": "Can you set a timer for cooking for 1 hour?",
44
+ "intent": "'Set Timer'",
45
+ "entities": "O O O O O O B-TASK O B-DUR I-DUR"
46
+ },
47
+ {
48
+ "text": "Remind me about the project deadline at 5 PM on Friday.",
49
+ "intent": "'Set Reminder'",
50
+ "entities": "O O O O B-TASK I-TASK O B-TIME I-TIME O B-DATE"
51
+ },
52
+ {
53
+ "text": "Schedule a doctor's appointment for March 20th at 10:30 AM.",
54
+ "intent": "'Schedule Appointment'",
55
+ "entities": "O O O O O B-DATE I-DATE O B-TIME I-TIME I-TIME I-TIME"
56
+ },
57
+ {
58
+ "text": "Set a timer for a 15-minute break.",
59
+ "intent": "'Set Timer'",
60
+ "entities": "O O O O O B-DUR I-DUR I-DUR B-TASK"
61
+ },
62
+ {
63
+ "text": "Remind me to buy groceries tomorrow morning.",
64
+ "intent": "'Set Reminder'",
65
+ "entities": "O O O B-TASK I-TASK B-DATE B-TIME"
66
+ },
67
+ {
68
+ "text": "Schedule a conference call for the first Monday of next month at 3 PM.",
69
+ "intent": "'Schedule Meeting'",
70
+ "entities": "O O O O O O B-DATE I-DATE I-DATE I-DATE I-DATE O B-TIME I-TIME"
71
+ },
72
+ {
73
+ "text": "Can you remind me to send the report at 4:30 PM today?",
74
+ "intent": "'Set Reminder'",
75
+ "entities": "O O O O O B-TASK I-TASK I-TASK O B-TIME I-TIME I-TIME I-TIME B-DATE"
76
+ },
77
+ {
78
+ "text": "Set a timer for a 20-minute workout session.",
79
+ "intent": "'Set Timer'",
80
+ "entities": "O O O O O B-DUR I-DUR I-DUR B-TASK I-TASK"
81
+ },
82
+ {
83
+ "text": "Remind me to water the plants every Tuesday and Thursday at 9 AM.",
84
+ "intent": "'Set Reminder'",
85
+ "entities": "O O O B-TASK I-TASK I-TASK O B-DATE O B-DATE O B-TIME I-TIME"
86
+ },
87
+ {
88
+ "text": "Schedule a team meeting for next Monday morning at 10:30.",
89
+ "intent": "'Schedule Meeting'",
90
+ "entities": "O O O O O B-DATE I-DATE B-TIME I-TIME I-TIME I-TIME I-TIME"
91
+ },
92
+ {
93
+ "text": "Can you set an alarm for 6:45 AM?",
94
+ "intent": "'Set Alarm'",
95
+ "entities": "O O O O O O B-TIME I-TIME I-TIME I-TIME"
96
+ },
97
+ {
98
+ "text": "Remind me about the webinar in 2 days at 2 PM.",
99
+ "intent": "'Set Reminder'",
100
+ "entities": "O O O O B-TASK O B-DUR I-DUR O B-TIME I-TIME"
101
+ },
102
+ {
103
+ "text": "Schedule a dentist appointment for April 5th at 11:00 in the morning.",
104
+ "intent": "'Schedule Appointment'",
105
+ "entities": "O O B-TASK I-TASK O B-DATE I-DATE O B-TIME I-TIME I-TIME I-TIME I-TIME I-TIME"
106
+ },
107
+ {
108
+ "text": "Set a timer for a 5-minute meditation session.",
109
+ "intent": "'Set Timer'",
110
+ "entities": "O O O O O B-DUR I-DUR I-DUR B-TASK I-TASK"
111
+ },
112
+ {
113
+ "text": "Remind me to call Sarah next Wednesday afternoon.",
114
+ "intent": "'Set Reminder'",
115
+ "entities": "O O O B-TASK I-TASK B-DATE I-DATE B-TIME"
116
+ },
117
+ {
118
+ "text": "Schedule a review meeting for the end of the month at 4:30 PM.",
119
+ "intent": "'Schedule Meeting'",
120
+ "entities": "O O B-TASK I-TASK O O B-DATE I-DATE I-DATE I-DATE O B-TIME I-TIME I-TIME I-TIME"
121
+ },
122
+ {
123
+ "text": "Can you remind me to pay bills on the last day of the month?",
124
+ "intent": "'Set Reminder'",
125
+ "entities": "O O O O O B-TASK I-TASK O O B-DATE I-DATE I-DATE I-DATE I-DATE"
126
+ },
127
+ {
128
+ "text": "Set a timer for 45 minutes for a study session.",
129
+ "intent": "'Set Timer'",
130
+ "entities": "O O O O B-DUR I-DUR O O B-TASK I-TASK"
131
+ },
132
+ {
133
+ "text": "Remind me to pick up the laundry every Friday afternoon.",
134
+ "intent": "'Set Reminder'",
135
+ "entities": "O O O B-TASK I-TASK I-TASK I-TASK O B-DATE B-TIME"
136
+ },
137
+ {
138
+ "text": "Schedule a client meeting for the 10th of next month at 2 PM.",
139
+ "intent": "'Schedule Meeting'",
140
+ "entities": "O O B-TASK I-TASK O O B-DATE I-DATE I-DATE I-DATE O B-TIME I-TIME"
141
+ },
142
+ {
143
+ "text": "Can you set an alarm for 7:30 AM tomorrow?",
144
+ "intent": "'Set Alarm'",
145
+ "entities": "O O O O O O B-TIME I-TIME I-TIME I-TIME B-DATE"
146
+ },
147
+ {
148
+ "text": "Remind me about the presentation at 4 PM today.",
149
+ "intent": "'Set Reminder'",
150
+ "entities": "O O O O B-TASK O B-TIME I-TIME B-DATE"
151
+ },
152
+ {
153
+ "text": "Schedule a doctor's appointment for May 15th in the evening.",
154
+ "intent": "'Schedule Appointment'",
155
+ "entities": "O O B-TASK I-TASK O B-DATE I-DATE O O B-TIME"
156
+ },
157
+ {
158
+ "text": "Set a timer for a 10-minute break between study sessions.",
159
+ "intent": "'Set Timer'",
160
+ "entities": "O O O O O B-DUR I-DUR I-DUR O O O O"
161
+ },
162
+ {
163
+ "text": "Remind me to send the report at 9 AM tomorrow.",
164
+ "intent": "'Set Reminder'",
165
+ "entities": "O O O B-TASK I-TASK I-TASK O B-TIME I-TIME B-DATE"
166
+ },
167
+ {
168
+ "text": "Schedule a team lunch for next Friday at noon.",
169
+ "intent": "'Schedule Meeting'",
170
+ "entities": "O O B-TASK I-TASK O B-DATE I-DATE O B-TIME"
171
+ },
172
+ {
173
+ "text": "Can you remind me to buy groceries on Saturday afternoon?",
174
+ "intent": "'Set Reminder'",
175
+ "entities": "O O O O O B-TASK I-TASK O B-DATE B-TIME"
176
+ }
177
+ ]
utils.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertTokenizerFast, BertConfig
2
+ from typing import Dict, List, Union, Tuple
3
+
4
+
5
+ def num_unique_labels(dataset: Dict[str, Union[str, List[str]]]) -> Tuple[int, int]:
6
+ """
7
+ Calculate the number of NER labels and INTENT labels in the dataset.
8
+
9
+ Args:
10
+ dataset (dict): A dictionary containing 'text', 'entities' and 'intent' keys.
11
+
12
+ Returns:
13
+ Tuple: Number of unique NER and INTENT lables.
14
+ """
15
+ one_dimensional_ner = [tag for subset in dataset['entities'] for tag in subset]
16
+ return len(set(one_dimensional_ner)), len(set(dataset['intent']))
17
+
18
+ def ner_labels_to_ids() -> Dict[str, int]:
19
+ """
20
+ Map NER labels to corresponding numeric IDs.
21
+
22
+ Returns:
23
+ Dict[str, int]: A dictionary where keys are NER labels, and values are their corresponding IDs.
24
+ """
25
+ labels_to_ids_ner = {
26
+ 'O': 0,
27
+ 'B-DATE': 1,
28
+ 'I-DATE': 2,
29
+ 'B-TIME': 3,
30
+ 'I-TIME': 4,
31
+ 'B-TASK': 5,
32
+ 'I-TASK': 6,
33
+ 'B-DUR': 7,
34
+ 'I-DUR': 8
35
+ }
36
+ return labels_to_ids_ner
37
+
38
+ def ner_ids_to_labels(ner_labels_to_ids) -> Dict[int, str]:
39
+ """
40
+ Map numeric IDs to corresponding NER labels.
41
+
42
+ Returns:
43
+ Dict[int, str]: A dictionary where keys are numeric IDs, and values are their corresponding NER labels.
44
+ """
45
+ ner_ids_to_labels = {v: k for k, v in ner_labels_to_ids.items()}
46
+ return ner_ids_to_labels
47
+
48
+ def intent_labels_to_ids() -> Dict[str, int]:
49
+ """
50
+ Map intent labels to corresponding numeric values.
51
+
52
+ Returns:
53
+ Dict[str, int]: A dictionary where keys are intent labels, and values are their corresponding numeric IDs.
54
+ """
55
+ intent_labels_to_ids = {
56
+ "'Schedule Appointment'": 0,
57
+ "'Schedule Meeting'": 1,
58
+ "'Set Alarm'": 2,
59
+ "'Set Reminder'": 3,
60
+ "'Set Timer'": 4
61
+ }
62
+ return intent_labels_to_ids
63
+
64
+ def intent_ids_to_labels(intent_labels_to_ids) -> Dict[int, str]:
65
+ """
66
+ Map numeric values to corresponding intent labels.
67
+
68
+ Returns:
69
+ Dict[int, str]: A dictionary where keys are numeric IDs, and values are their corresponding intent labels.
70
+ """
71
+ intent_ids_to_labels = {v: k for k, v in intent_labels_to_ids.items()}
72
+ return intent_ids_to_labels
73
+
74
+ def tokenizer() -> BertTokenizerFast:
75
+ tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
76
+ return tokenizer
77
+
78
+ def bert_config() -> BertConfig:
79
+ config = BertConfig.from_pretrained('bert-base-uncased')
80
+ return config
81
+
82
+ def structure_data(dataset):
83
+ structured_data = {'text': [], 'entities': [], 'intent': []}
84
+ for sample in dataset:
85
+ structured_data['text'].append(sample['text'])
86
+ structured_data['entities'].append(sample['entities'].split())
87
+ structured_data['intent'].append(sample['intent'])
88
+ return structured_data
validation_dataset.json ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "text": "Set a reminder for my dentist appointment next Tuesday at 9 AM.",
4
+ "intent": "'Set Reminder'",
5
+ "entities": "O O O O O B-TASK I-TASK B-DATE I-DATE O B-TIME I-TIME"
6
+ },
7
+ {
8
+ "text": "Can you schedule a meeting for Friday afternoon?",
9
+ "intent": "'Schedule Meeting'",
10
+ "entities": "O O O O O O B-DATE B-TIME"
11
+ },
12
+ {
13
+ "text": "Set an alarm for 7:30 AM to wake up.",
14
+ "intent": "'Set Alarm'",
15
+ "entities": "O O O O B-TIME I-TIME I-TIME O O O"
16
+ },
17
+ {
18
+ "text": "Remind me about the call with the client in 2 hours.",
19
+ "intent": "'Set Reminder'",
20
+ "entities": "O O O O B-TASK I-TASK I-TASK I-TASK O B-TIME I-TIME"
21
+ },
22
+ {
23
+ "text": "Schedule a doctor's appointment for May 10th at 11:45 AM.",
24
+ "intent": "'Schedule Appointment'",
25
+ "entities": "O O B-TASK I-TASK O B-DATE I-DATE O B-TIME I-TIME I-TIME I-TIME"
26
+ },
27
+ {
28
+ "text": "Schedule a conference call for next Monday morning at 9:00 sharp.",
29
+ "intent": "'Schedule Meeting'",
30
+ "entities": "O O B-TASK I-TASK O B-DATE I-DATE B-TIME I-TIME I-TIME I-TIME I-TIME O"
31
+ },
32
+ {
33
+ "text": "Set a timer for a 5-minute break between study sessions.",
34
+ "intent": "'Set Timer'",
35
+ "entities": "O O O O O B-DUR I-DUR I-DUR O O O O"
36
+ },
37
+ {
38
+ "text": "Remind me to water the plants every Sunday and Wednesday at 7 AM.",
39
+ "intent": "'Set Reminder'",
40
+ "entities": "O O O B-TASK I-TASK I-TASK B-DATE I-DATE O B-DATE O B-TIME I-TIME"
41
+ },
42
+ {
43
+ "text": "Can you schedule a dentist appointment for next month?",
44
+ "intent": "'Schedule Appointment'",
45
+ "entities": "O O O O B-TASK I-TASK O B-DATE I-DATE"
46
+ },
47
+ {
48
+ "text": "Set an alarm for 4:30 PM to walk the dog.",
49
+ "intent": "'Set Alarm'",
50
+ "entities": "O O O O B-TIME I-TIME I-TIME I-TIME O B-TASK I-TASK I-TASK"
51
+ },
52
+ {
53
+ "text": "Set a timer for 20 minutes while I meditate.",
54
+ "intent": "'Set Timer'",
55
+ "entities": "O O O O B-DUR I-DUR O O O"
56
+ },
57
+ {
58
+ "text": "Remind me to pick up dry cleaning this Thursday after work.",
59
+ "intent": "'Set Reminder'",
60
+ "entities": "O O O B-TASK I-TASK I-TASK I-TASK O B-DATE O O"
61
+ },
62
+ {
63
+ "text": "Schedule a team lunch for next Wednesday at noon.",
64
+ "intent": "'Schedule Meeting'",
65
+ "entities": "O O B-TASK I-TASK O B-DATE I-DATE O B-TIME"
66
+ },
67
+ {
68
+ "text": "Can you set an alarm for 8:15 AM tomorrow?",
69
+ "intent": "'Set Alarm'",
70
+ "entities": "O O O O O O B-TIME I-TIME I-TIME I-TIME B-DATE"
71
+ },
72
+ {
73
+ "text": "Remind me about the doctor's appointment on the 25th at 10:30 AM.",
74
+ "intent": "'Set Reminder'",
75
+ "entities": "O O O O B-TASK I-TASK O O B-DATE O B-TIME I-TIME I-TIME I-TIME"
76
+ },
77
+ {
78
+ "text": "Remind me to submit the report by next Monday at 5 PM.",
79
+ "intent": "'Set Reminder'",
80
+ "entities": "O O O B-TASK I-TASK I-TASK O B-DATE I-DATE O B-TIME I-TIME"
81
+ },
82
+ {
83
+ "text": "Set a timer for 45 minutes for a cooking session.",
84
+ "intent": "'Set Timer'",
85
+ "entities": "O O O O B-DUR I-DUR O O B-TASK I-TASK"
86
+ },
87
+ {
88
+ "text": "Schedule a client meeting for the 10th of next month at 2:30 PM.",
89
+ "intent": "'Schedule Meeting'",
90
+ "entities": "O O B-TASK I-TASK O O B-DATE I-DATE I-DATE I-DATE O B-TIME I-TIME I-TIME I-TIME"
91
+ },
92
+ {
93
+ "text": "Can you set an alarm for 9 AM every weekday?",
94
+ "intent": "'Set Alarm'",
95
+ "entities": "O O O O O O B-TIME I-TIME O B-DATE"
96
+ },
97
+ {
98
+ "text": "Remind me about the appointment with the lawyer on April 15th at 11 AM.",
99
+ "intent": "'Set Reminder'",
100
+ "entities": "O O O O B-TASK I-TASK I-TASK I-TASK O B-DATE I-DATE O B-TIME I-TIME"
101
+ },
102
+ {
103
+ "text": "Schedule a study group session for next Friday evening at 6:30 PM.",
104
+ "intent": "'Schedule Meeting'",
105
+ "entities": "O O B-TASK I-TASK I-TASK O B-DATE I-DATE O O B-TIME I-TIME I-TIME I-TIME"
106
+ },
107
+ {
108
+ "text": "Set a timer for a 10-minute power nap.",
109
+ "intent": "'Set Timer'",
110
+ "entities": "O O O O O B-DUR I-DUR I-DUR O O"
111
+ },
112
+ {
113
+ "text": "Remind me to buy tickets for the concert on Thursday at noon.",
114
+ "intent": "'Set Reminder'",
115
+ "entities": "O O O B-TASK I-TASK O O O O B-DATE O B-TIME"
116
+ },
117
+ {
118
+ "text": "Set an alarm for 7:45 AM to start morning exercises.",
119
+ "intent": "'Set Alarm'",
120
+ "entities": "O O O O B-TIME I-TIME I-TIME I-TIME O O O O"
121
+ }
122
+ ]
vocab.txt ADDED
The diff for this file is too large to render. See raw diff