File size: 3,986 Bytes
78a5823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3f6ef0
78a5823
 
 
 
 
 
 
b3f6ef0
78a5823
 
 
 
 
 
 
 
 
 
 
 
 
b3f6ef0
 
78a5823
 
 
b3f6ef0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import os
import sys

# Absolute path of the directory one level above the folder containing this file.
project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Put that directory on the import path before the project imports below run.
# NOTE(review): presumably model, data_loader and utils live there - confirm.
sys.path.append(project_dir)

from model import MultiTaskBertModel
from data_loader import load_dataset
from utils import bert_config, tokenizer, intent_ids_to_labels, intent_labels_to_ids


def load_model(model_path):
    """
    Load the pre-trained model weights from the specified path.

    The model is built from ``bert_config()`` and the "training_dataset"
    dataset, filled with the saved state dict, and switched to evaluation
    mode before it is returned.

    Args:
        model_path (str): Path to the pre-trained model weights.

    Returns:
        model (MultiTaskBertModel): Loaded model with pre-trained weights,
            in evaluation mode.
    """
    # Initialize model with configuration and dataset information
    config = bert_config()
    dataset = load_dataset("training_dataset")
    model = MultiTaskBertModel(config, dataset)

    # Remap every stored tensor onto the CPU while deserializing, so that a
    # checkpoint written from a GPU run still loads on a CPU-only machine.
    # Nothing in this file moves the model to another device, so CPU is where
    # the weights are needed.
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)

    # The model is only used for inference here.
    model.eval()

    return model

def preprocess_input(input_data):
    """
    Tokenize a piece of text and package the result as tensors.

    The text is padded/truncated to 128 tokens by the tokenizer call.

    Args:
        input_data (str): Raw text to run through the tokenizer.

    Returns:
        input_ids (torch.Tensor): Token IDs, wrapped in one extra leading
            dimension.
        attention_mask (torch.Tensor): Attention mask, wrapped the same way.
        offset_mapping (torch.Tensor): Offset mapping reported by the
            tokenizer, converted as-is (no extra leading dimension).
    """
    encode = tokenizer()
    tokenizer_options = {
        "return_offsets_mapping": True,
        "padding": "max_length",
        "truncation": True,
        "max_length": 128,
    }
    encoded = encode(input_data, **tokenizer_options)

    # IDs and mask are wrapped in a list before conversion; offsets are not.
    ids_tensor = torch.tensor([encoded["input_ids"]])
    mask_tensor = torch.tensor([encoded["attention_mask"]])
    offsets_tensor = torch.tensor(encoded["offset_mapping"])
    return ids_tensor, mask_tensor, offsets_tensor

def perform_inference(model, input_ids, attention_mask):
    """
    Run the model on one prepared input without tracking gradients.

    Args:
        model: Callable model (e.g. a ``torch.nn.Module``) that accepts
            ``(input_ids, attention_mask)`` and returns a pair of outputs.
        input_ids (torch.Tensor): Token IDs passed as the first argument.
        attention_mask (torch.Tensor): Attention mask passed as the second
            argument.

    Returns:
        ner_logits, intent_logits: The two values returned by the model, in
            that order.
    """
    with torch.no_grad():
        # Call the module itself rather than ``model.forward`` so that any
        # registered forward hooks are run as well.
        ner_logits, intent_logits = model(input_ids, attention_mask)

    return ner_logits, intent_logits
    
def align_ner_predictions_with_input(predictions, offset_mapping, input_text):
    """
    Pair each token-level prediction with the text span it covers.

    Args:
        predictions: Iterable of per-token predictions.
        offset_mapping: Iterable of ``(start, end)`` character offsets into
            ``input_text``, consumed in lockstep with ``predictions``.
        input_text (str): Text the offsets index into.

    Returns:
        list: ``(text_span, prediction)`` tuples in token order. Tokens whose
            span is empty (``start == end``) or contains only whitespace are
            left out.
    """
    aligned_predictions = []

    # Iterate through each prediction and its offset mapping
    for prediction, (start, end) in zip(predictions, offset_mapping):
        # A zero-width span covers no input text, so there is nothing to
        # report for it.
        # NOTE(review): presumably this is how special and padding tokens
        # show up - confirm against the tokenizer.
        if start == end:
            continue

        # Find the corresponding piece of the input text
        word = input_text[start:end]

        # Skip spans that contain nothing but whitespace
        if not word.strip():
            continue

        # Assign the prediction to the word
        aligned_predictions.append((word, prediction))

    return aligned_predictions

def convert_intent_to_label(intent_logit):
    """
    Translate a predicted intent index into its label.

    Args:
        intent_logit: Value convertible with ``int()``. Despite the name,
            the caller in this file passes the argmax index, not a raw logit.

    Returns:
        The entry stored under ``int(intent_logit)`` in the mapping built by
        ``intent_ids_to_labels(intent_labels_to_ids())``.
    """
    label_to_id = intent_labels_to_ids()
    id_to_label = intent_ids_to_labels(label_to_id)
    intent_index = int(intent_logit)
    return id_to_label[intent_index]


def main():
    """
    Interactive entry point: load the model, read one line of text from the
    user, run inference on it, and print the NER and intent predictions.
    """
    # Weights are read from a path relative to the current working directory.
    model = load_model("pytorch_model.bin")

    input_data = input("Enter the text to analyze: ")
    input_ids, attention_mask, offset_mapping = preprocess_input(input_data)

    raw_ner, raw_intent = perform_inference(model, input_ids, attention_mask)

    # Reduce the raw model outputs to predicted class indices.
    # NOTE(review): the 9 is hard-coded; presumably it is the number of NER
    # labels - confirm against the model definition.
    ner_class_ids = torch.argmax(raw_ner.view(-1, 9), dim=1)
    intent_class_id = torch.argmax(raw_intent)

    # Attach NER predictions to their text spans and name the intent.
    aligned_ner = align_ner_predictions_with_input(ner_class_ids, offset_mapping, input_data)
    intent_label = convert_intent_to_label(intent_class_id)

    print(f"Ner logits: {aligned_ner}")
    print(f"Intent logits: {intent_label}")
    

# Run the interactive inference flow only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()