Spaces:
Sleeping
Sleeping
import torch | |
from transformers import BertForTokenClassification | |
from .config_train import device, model_load_path, tokenizer | |
from .DataProcessing import read_input | |
from .load_data import sorted_tags | |
class Key_Ner_Predictor:
    """Predict which tags occur in a sentence using a BERT token classifier."""

    def __init__(self, model_path, tokenizer, device, tag_map):
        """
        Initialize the Key_Ner_Predictor with the model, tokenizer, and device.

        Args:
            model_path (str): Path to the pre-trained model.
            tokenizer (BertTokenizer): Tokenizer to process input sentences.
            device (torch.device): Device to run the model on.
            tag_map (Dict[int, str]): Mapping of indices to tags.
        """
        self.model = BertForTokenClassification.from_pretrained(model_path).to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.tag_map = tag_map

    def predict(self, sentence):
        """
        Predict the distinct tags that occur in the given sentence.

        Args:
            sentence (str): Input sentence to predict.

        Returns:
            Tuple[str, List[str]]: The sentence decoded back from the token
            ids (special tokens skipped) and the unique predicted tags, in
            order of first occurrence, with '<pad>' dropped and spaces in
            tag names replaced by underscores.
        """
        # Pre-process the raw text with the project's input reader.
        sentence = read_input(sentence)

        # Tokenize the sentence into a batch of one sequence.
        input_ids = self.tokenizer.encode(sentence, return_tensors="pt").to(self.device)

        # Attend to every position that is not a padding token.
        attention_masks = (input_ids != self.tokenizer.pad_token_id).float().to(self.device)

        # Inference only: evaluation mode and no gradient tracking.
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(input_ids, token_type_ids=None, attention_mask=attention_masks)
        logits = outputs.logits

        # Highest-scoring tag index per token, for the single sequence in the batch.
        tag_indices = torch.argmax(logits, dim=2).cpu().numpy()[0]

        # dict.fromkeys de-duplicates while keeping first-occurrence order, so
        # the result is deterministic (a set's iteration order is not).
        unique_tags = dict.fromkeys(self.tag_map[int(idx)] for idx in tag_indices)

        # '<pad>' is not always predicted; filtering it out (instead of
        # set.remove) avoids a KeyError when it is absent.
        predicted_tags = [
            tag.replace(" ", "_") for tag in unique_tags if tag != '<pad>'
        ]

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True), predicted_tags
# Build the module-level predictor at import time from the training
# configuration; each tag's index is its position in sorted_tags.
predictor = Key_Ner_Predictor(
    model_load_path,
    tokenizer,
    device,
    {index: tag for index, tag in enumerate(sorted_tags)},
)
# Example usage:
# Define the sentence to predict
# sentence = "Tôi muốn đi cắm trại ngắm hoàng hôn trên biển cùng gia đình"
# Get the prediction
# original_sentence, predicted_tags = predictor.predict(sentence)
# Print the sentence and its predicted tags
# print("Sentence:", original_sentence)
# print("Predicted Tags:", predicted_tags)