# legacydemo/src/upvote_predictor.py
from typing import Optional, Tuple

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from transformers import BertTokenizer


class UpvotePredictor:
    """Scores an answer's credibility with a fine-tuned BERT classifier."""

    def __init__(self, model_path: str):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load the full serialized model (not just a state_dict) onto the CPU,
        # then move it to the selected device for inference.
        self.upvote_ml_model = torch.load(
            model_path, map_location=torch.device("cpu"), weights_only=False
        )
        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased", do_lower_case=True
        )
        self.upvote_ml_model.to(self.device)
        self.upvote_ml_model.eval()

    def get_upvote_prediction(
        self, question: str, answer: str, question_context: Optional[str] = None
    ) -> Tuple[str, pd.DataFrame]:
        """Classify whether `answer` reads as a credible suggestion.

        Note: only `answer` is encoded; `question` and `question_context`
        are currently unused by the classifier.
        """
        llm_response_input_ids = []
        llm_response_attention_masks = []
        # Tokenize the answer, padding/truncating to a fixed 256-token length.
        encoded_dict = self.tokenizer.encode_plus(
            answer,
            add_special_tokens=True,
            max_length=256,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        llm_response_input_ids.append(encoded_dict["input_ids"])
        llm_response_attention_masks.append(encoded_dict["attention_mask"])
        llm_response_input_ids = torch.cat(llm_response_input_ids, dim=0)
        llm_response_attention_masks = torch.cat(llm_response_attention_masks, dim=0)
        test_dataset = TensorDataset(
            llm_response_input_ids, llm_response_attention_masks
        )
        test_dataloader = DataLoader(
            test_dataset,  # The evaluation samples.
            sampler=SequentialSampler(test_dataset),  # Pull out batches sequentially.
            batch_size=1,  # Evaluate with this batch size.
        )
        predictions = []
        for batch in test_dataloader:
            b_input_ids = batch[0].to(self.device)
            b_input_mask = batch[1].to(self.device)
            with torch.no_grad():
                output = self.upvote_ml_model(
                    b_input_ids, token_type_ids=None, attention_mask=b_input_mask
                )
            logits = output.logits.detach().cpu().numpy()
            # Take the highest-scoring class (0 = not credible, 1 = credible).
            pred_flat = np.argmax(logits, axis=1).flatten()
            predictions.extend(list(pred_flat))
        if predictions[0] == 0:
            return ("Not credible suggestion", pd.DataFrame())
        return ("Credible suggestion", pd.DataFrame())