zhangfan commited on
Commit
0147d71
·
1 Parent(s): 8ad14fc

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +33 -0
README.md CHANGED
@@ -39,3 +39,36 @@ The following hyperparameters were used during training:
39
  - Pytorch 1.10.1+cu102
40
  - Datasets 1.17.0
41
  - Tokenizers 0.11.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  - Pytorch 1.10.1+cu102
40
  - Datasets 1.17.0
41
  - Tokenizers 0.11.0
42
+
43
+ ## Usage (HuggingFace Transformers)
44
+ You can use the model like this:
45
+
46
+ ```python
47
+ import torch
48
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
49
+
50
+ # label_list
51
+ label_list = ['matched', 'unmatched']
52
+
53
+ # Load model from HuggingFace Hub
54
+ tokenizer = AutoTokenizer.from_pretrained("Fan-s/reddit-tc-bert", use_fast=True)
55
+ model = AutoModelForSequenceClassification.from_pretrained("Fan-s/reddit-tc-bert")
56
+
57
+ # set the input
58
+ post = "don't make gravy with asbestos."
59
+ response = "i'd expect someone with a culinary background to know that. since we're talking about school dinner ladies, they need to learn this pronto."
60
+
61
+ def predict(post, response, max_seq_length=128):
62
+ with torch.no_grad():
63
+ args = (post, response)
64
+ input = tokenizer(*args, padding="max_length", max_length=max_seq_length, truncation=True, return_tensors="pt")
65
+ output = model(**input)
66
+ logits = output.logits
67
+ item = torch.argmax(logits, dim=1)
68
+ predict_label = label_list[item]
69
+ return predict_label, logits
70
+
71
+ # predict whether the two sentences match
72
+ predict_label = predict(post, response)
73
+ print("predict_label:", predict_label)
74
+ ```