| | from transformers import pipeline |
| | import string |
| |
|
| | |
| | model_checkpoint = "results/checkpoint-16000" |
| | question_answerer = pipeline("question-answering", model = model_checkpoint) |
| |
|
| | def predict(question, context): |
| | answer = question_answerer(question = question, |
| | context = context) |
| | |
| | exclude = set(string.punctuation) |
| | |
| | text = answer['answer'] |
| | text = ''.join(ch for ch in text if ch not in exclude) |
| | |
| | answer['answer'] = text |
| |
|
| | return answer |
| | |
| | if __name__ == '__main__': |
| | question = 'Combi cao bao nhiêu.' |
| | context = 'Combi là sinh viên năm 2 trường Ecole Polytechnique. Chiều cao của Combi là 1m73, cân nặng là 63kg. Combi thích học Machine Learning vì Machine Learning cần nhiều toán.' |
| |
|
| | print(predict(question, context)) |