legacydemo / src /upvote_predictor.py
gupta-amulya's picture
Enhance SemanticSearcher integration and refine UpvotePredictor output handling
20df6e4
raw
history blame
2.46 kB
from typing import Optional

import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from transformers import BertTokenizer
class UpvotePredictor:
    """Label an answer as a credible or non-credible suggestion.

    Holds a classification model restored with ``torch.load`` and the
    ``bert-base-uncased`` tokenizer used to encode answers for it.
    """

    # Answers are padded or truncated to exactly this many tokens.
    _MAX_SEQUENCE_LENGTH = 256

    # Class index 0 maps to the negative label; any other index is positive.
    _NOT_CREDIBLE_LABEL = "Not credible suggestion"
    _CREDIBLE_LABEL = "Credible suggestion"

    def __init__(self, model_path: str):
        """Load the model and tokenizer and put the model in eval mode.

        Args:
            model_path: Path to a file readable by ``torch.load``. The loaded
                object must support ``.to()``, ``.eval()`` and being called
                with ``input_ids`` / ``attention_mask`` tensors.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects,
        # which can execute code embedded in the file. It is kept because the
        # whole module object (not just a state dict) is restored here; only
        # point model_path at files from a trusted source.
        self.upvote_ml_model = torch.load(
            model_path, map_location=torch.device("cpu"), weights_only=False
        )
        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased", do_lower_case=True
        )
        self.upvote_ml_model.to(self.device)
        self.upvote_ml_model.eval()

    def get_upvote_prediction(
        self, question: str, answer: str, question_context: Optional[str] = None
    ) -> str:
        """Classify ``answer`` and return a human-readable credibility label.

        Only ``answer`` is fed to the model. ``question`` and
        ``question_context`` are accepted for interface compatibility but are
        not used by the prediction.

        Args:
            question: The question being answered (currently unused).
            answer: The answer text to classify.
            question_context: Optional extra context (currently unused).

        Returns:
            ``"Not credible suggestion"`` when the highest-scoring class index
            is 0, otherwise ``"Credible suggestion"``.
        """
        # The tokenizer already returns (1, max_length) tensors, i.e. a batch
        # of one, so no Dataset/DataLoader wrapping is needed.
        encoded = self.tokenizer(
            answer,
            add_special_tokens=True,
            max_length=self._MAX_SEQUENCE_LENGTH,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        # Inference only: skip gradient tracking.
        with torch.no_grad():
            output = self.upvote_ml_model(
                input_ids, token_type_ids=None, attention_mask=attention_mask
            )

        # Index of the highest logit for the single sample in the batch.
        predicted_class = int(output.logits.argmax(dim=1)[0].item())

        if predicted_class == 0:
            return self._NOT_CREDIBLE_LABEL
        return self._CREDIBLE_LABEL