zhangfan
commited on
Commit
·
0147d71
1
Parent(s):
8ad14fc
Update README.md
Browse files
README.md
CHANGED
@@ -39,3 +39,36 @@ The following hyperparameters were used during training:
|
|
39 |
- Pytorch 1.10.1+cu102
|
40 |
- Datasets 1.17.0
|
41 |
- Tokenizers 0.11.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
- Pytorch 1.10.1+cu102
|
40 |
- Datasets 1.17.0
|
41 |
- Tokenizers 0.11.0
|
42 |
+
|
43 |
+
## Usage (HuggingFace Transformers)
|
44 |
+
You can use the model like this:
|
45 |
+
|
46 |
+
```python
|
47 |
+
import torch
|
48 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
49 |
+
|
50 |
+
# label_list
|
51 |
+
label_list = ['matched', 'unmatched']
|
52 |
+
|
53 |
+
# Load model from HuggingFace Hub
|
54 |
+
tokenizer = AutoTokenizer.from_pretrained("Fan-s/reddit-tc-bert", use_fast=True)
|
55 |
+
model = AutoModelForSequenceClassification.from_pretrained("Fan-s/reddit-tc-bert")
|
56 |
+
|
57 |
+
# set the input
|
58 |
+
post = "don't make gravy with asbestos."
|
59 |
+
response = "i'd expect someone with a culinary background to know that. since we're talking about school dinner ladies, they need to learn this pronto."
|
60 |
+
|
61 |
+
def predict(post, response, max_seq_length=128):
|
62 |
+
with torch.no_grad():
|
63 |
+
args = (post, response)
|
64 |
+
input = tokenizer(*args, padding="max_length", max_length=max_seq_length, truncation=True, return_tensors="pt")
|
65 |
+
output = model(**input)
|
66 |
+
logits = output.logits
|
67 |
+
item = torch.argmax(logits, dim=1)
|
68 |
+
predict_label = label_list[item]
|
69 |
+
return predict_label, logits
|
70 |
+
|
71 |
+
# predict whether the two sentences match
|
72 |
+
predict_label = predict(post, response)
|
73 |
+
print("predict_label:", predict_label)
|
74 |
+
```
|