# Text2Text Generation · Transformers · PyTorch · Safetensors · Czech · English · mt5
import json
from typing import List

from adaptor.adapter import Adapter
from adaptor.evaluators.generative import BLEU
from adaptor.lang_module import LangModule
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy
from datasets import load_dataset

from priming_objective import Priming

training_arguments = AdaptationArguments(output_dir="train_dir_SQuAD_random_large",
                                         learning_rate=2e-5,  # we set LR=2e-4 for pre-training experiments
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         # stopping_strategy=StoppingStrategy.NUM_STEPS_TOTAL,
                                         do_train=True,
                                         do_eval=True,
                                         warmup_steps=1000,
                                         max_steps=10000,
                                         gradient_accumulation_steps=30,
                                         eval_steps=500,
                                         logging_steps=10,
                                         save_steps=500,
                                         num_train_epochs=5,  # note: max_steps above takes precedence when both are set
                                         evaluation_strategy="steps",
                                         save_total_limit=10,
                                         stopping_patience=10)
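# Note (derived from this script's settings): with the per-objective
# batch_size=1 used below, gradient_accumulation_steps=30 gives an effective
# batch of roughly 30 examples per optimizer update.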
eval_examples = 200

# priming
num_demonstrations = 3


# Documents the priming input format: demonstrations and the current example
# are concatenated into one space-separated prompt. (Not called during training
# below; the Priming objective is expected to build prompts of this shape.)
def _construct_priming_prompt(previous_examples: List[str], current_example: str) -> str:
    return " ".join(previous_examples + [current_example])


lang_module = LangModule("google/mt5-large")  # shared multilingual backbone for both objectives

# priming
per_type_examples = {}  # (appears unused below; possibly a leftover per-category demonstration cache)

qa_en = load_dataset("squad")
qa_train = qa_en["train"].filter(lambda entry: len(entry["context"]) < 2000)

# "▁" is the SentencePiece word-boundary marker used by mT5's tokenizer
val_metrics = [BLEU(additional_sep_char="▁")]

# SQuAD QA dataset & objective:


def _get_en_qa_categories(data) -> List[str]:
    # Bucket questions by their leading wh-word; "To ..." questions keep two
    # words so the category stays distinctive.
    return [question.split()[0] if not question.startswith("To")
            else " ".join(question.split()[:2])
            for question in data["question"]]


q_answering_en = Priming(lang_module,
                         max_eval_samples=eval_examples,
                         demos_selection_strategy="random",
                         texts_or_path=qa_train["question"],
                         text_pair_or_path=qa_train["context"],
                         val_texts_or_path=qa_en["validation"]["question"][-eval_examples:],
                         val_text_pair_or_path=qa_en["validation"]["context"][-eval_examples:],
                         labels_or_path=[a["text"][0] for a in qa_train["answers"]],
                         val_labels_or_path=[a["text"][0] for a in qa_en["validation"]["answers"]][-eval_examples:],
                         train_question_categories=_get_en_qa_categories(qa_train),
                         val_question_categories=_get_en_qa_categories(qa_en["validation"])[-eval_examples:],
                         batch_size=1,
                         val_evaluators=val_metrics,
                         source_lang_id="en",
                         objective_id="AQA-en")

# Czech data & objective

with open("czech_squad_4-sents.json") as f:
    squad_cs_dataset = json.load(f)
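# Expected entry shape, inferred from the loop below (an assumption; the file
# itself is not shown here):
# { "<id>": { "question": str, "context": str,
#             "answers": {"text": [str, ...]}, "answer_type": str }, ... }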

skipped = 0

questions_cs = []
contexts_cs = []
answers_cs = []
categories_cs = []

for entry in squad_cs_dataset.values():
    if len(entry["context"]) > 800:  # Czech contexts are capped more tightly than the English ones (2000)
        skipped += 1
        continue

    questions_cs.append(entry["question"])
    contexts_cs.append(entry["context"])
    answers_cs.append(entry["answers"]["text"][0])
    categories_cs.append(entry["answer_type"])

print("Skipped cs examples: %s" % skipped)

q_answering_cs = Priming(lang_module,
                         max_eval_samples=eval_examples,
                         demos_selection_strategy="random",
                         texts_or_path=questions_cs[:-eval_examples],
                         text_pair_or_path=contexts_cs[:-eval_examples],
                         val_texts_or_path=questions_cs[-eval_examples:],
                         val_text_pair_or_path=contexts_cs[-eval_examples:],
                         labels_or_path=answers_cs[:-eval_examples],
                         val_labels_or_path=answers_cs[-eval_examples:],
                         train_question_categories=categories_cs[:-eval_examples],
                         val_question_categories=categories_cs[-eval_examples:],
                         batch_size=1,
                         val_evaluators=val_metrics,
                         source_lang_id="cs",
                         objective_id="SQUAD-cs")

# ParallelSchedule trains both objectives side by side, alternating their batches
schedule = ParallelSchedule(objectives=[q_answering_en, q_answering_cs],
                            args=training_arguments)

adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()
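# A minimal inference sketch (assumptions: the final checkpoint is exported in
# standard transformers format, and the exact path under output_dir depends on
# how adaptor saves per-objective models):
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint_dir = "train_dir_SQuAD_random_large"  # hypothetical checkpoint location

tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_dir)

# Build a primed input: demonstrations ("question context answer") followed by
# the current question and context. The demo layout mirrors
# _construct_priming_prompt above and is an assumption about the objective's
# internal prompt format.
demos = ["Who wrote Hamlet? Hamlet is a play by William Shakespeare. William Shakespeare"]
query = "Kdo napsal Hamleta? Hamlet je hra od Williama Shakespeara."  # "Who wrote Hamlet? Hamlet is a play by William Shakespeare."
prompt = _construct_priming_prompt(demos, query)

inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))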