import functools

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Placeholder hub id carried over from the original; pass ``model_name`` to override.
_DEFAULT_MODEL_NAME = "username/model_name"


@functools.lru_cache(maxsize=2)
def _load_classifier(model_name):
    """Load the tokenizer/model pair for *model_name*, caching the result.

    Loading a checkpoint is far more expensive than running inference, so the
    pair is built once per name and reused by later ``predict`` calls. The
    cache is bounded so switching between many checkpoints cannot keep every
    model resident in memory.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    # from_pretrained already returns the model in eval mode; made explicit so
    # dropout is clearly disabled for inference.
    model.eval()
    return tokenizer, model


def predict(text, model_name=_DEFAULT_MODEL_NAME):
    """Return the predicted class id(s) for *text*.

    Args:
        text: A single string, or a sequence of strings classified as one batch.
        model_name: Hub id or local path of a sequence-classification
            checkpoint. Defaults to the id the function originally hard-coded.

    Returns:
        An ``int`` class id when *text* is a single string, or a ``list`` of
        ints (one per input, in input order) when *text* is a sequence. An
        empty sequence yields ``[]`` without loading the model.
    """
    if isinstance(text, str):
        texts, batched = [text], False
    else:
        texts, batched = list(text), True
        if not texts:
            return []

    tokenizer, model = _load_classifier(model_name)
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        # Pad only when rows need aligning: padding a lone sequence is a no-op
        # but raises on tokenizers that define no pad token.
        padding=len(texts) > 1,
        # Clip inputs longer than the tokenizer's model_max_length rather than
        # handing the model sequences it may not support.
        truncation=True,
    )
    with torch.no_grad():
        logits = model(**inputs).logits  # (batch, num_labels)

    # Per-row argmax; the original un-dimensioned argmax flattened the tensor,
    # which is only correct for a single input.
    class_ids = logits.argmax(dim=-1).tolist()
    return class_ids if batched else class_ids[0]