import torch

from datasets import load_dataset

from transformers import (
    Trainer,
    T5Config,
    T5TokenizerFast,
    TrainingArguments,
    DataCollatorForSeq2Seq,
    T5ForConditionalGeneration
)

# Paths and constants. `base_model` names the reference architecture that the
# config below mirrors; the model itself is initialised from scratch.
base_model = "t5-small"
data_path = "src/data/clean_corpus.jsonl"
tokeniser_path = "src/tokeniser/"
output_dir = "checkpoints/"
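
# The corpus is consumed by tokenise_function below: each JSONL line is expected
# to carry a "transliteration" object with a Cyrillic "src" and a Latin "tgt",
# e.g. (values illustrative only):
# {"transliteration": {"src": "<Cyrillic text>", "tgt": "<Latin text>"}}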

tokeniser = T5TokenizerFast.from_pretrained(tokeniser_path)
vocab_size = tokeniser.vocab_size
pad_token_id = tokeniser.pad_token_id
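# Note: vocab_size reports the base vocabulary only; if the custom tokeniser has
# tokens added on top of it, len(tokeniser) is the safer value to size the model with.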

# t5-small-sized architecture built around the custom vocabulary.
config = T5Config(
    vocab_size = vocab_size,
    d_model = 512,
    d_ff = 2048,
    num_layers = 6,
    num_heads = 8,
    pad_token_id = pad_token_id,
    decoder_start_token_id = pad_token_id  # T5 starts decoding from the pad token
)

# Fresh (randomly initialised) model; no pretrained weights are loaded.
model = T5ForConditionalGeneration(config)


def tokenise_function(example: dict) -> dict:
    """
    Tokenise a batch of Cyrillic-to-Latin transliteration pairs.
    """
    inputs = [f"Cyrillic2Latin: {item['src']}" for item in example["transliteration"]]
    targets = [item["tgt"] for item in example["transliteration"]]

    model_inputs = tokeniser(
        inputs, max_length = 128, truncation = True, padding = "max_length"
    )
    labels = tokeniser(
        targets, max_length = 128, truncation = True, padding = "max_length"
    )["input_ids"]

    # Mask label padding with -100 so the loss ignores it.
    labels = [
        [token if token != pad_token_id else -100 for token in label]
        for label in labels
    ]
    model_inputs["labels"] = labels

    return model_inputs


dataset = load_dataset("json", data_files = data_path, split = "train")

dataset_split = dataset.train_test_split(test_size = 0.25)
train_dataset = dataset_split["train"]
val_dataset = dataset_split["test"]

tokenised_train = train_dataset.map(
    tokenise_function, batched = True, remove_columns = ["transliteration"]
)
tokenised_eval = val_dataset.map(
    tokenise_function, batched = True, remove_columns = ["transliteration"]
)

data_collator = DataCollatorForSeq2Seq(tokenizer = tokeniser, model = model)

training_args = TrainingArguments(
    output_dir = output_dir,
    overwrite_output_dir = True,
    num_train_epochs = 2,
    per_device_train_batch_size = 32,
    gradient_accumulation_steps = 2,  # effective batch size of 64 per device
    save_strategy = "steps",
    save_steps = 500,
    save_total_limit = 3,
    eval_strategy = "epoch",
    logging_dir = "logs",
    fp16 = torch.cuda.is_available()  # mixed precision only when a GPU is present
)

trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = tokenised_train,
    eval_dataset = tokenised_eval,
    data_collator = data_collator,
    processing_class = tokeniser
)

trainer.train()
model.save_pretrained(output_dir)
tokeniser.save_pretrained(output_dir)

print("DalaT5 training complete.")
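
# A minimal sketch of how the finished checkpoint could be reloaded for
# inference, kept as a comment so it does not run as part of training.
# The example input is illustrative only.
#
#   model = T5ForConditionalGeneration.from_pretrained(output_dir)
#   tokeniser = T5TokenizerFast.from_pretrained(output_dir)
#   batch = tokeniser("Cyrillic2Latin: <Cyrillic text>", return_tensors = "pt")
#   output_ids = model.generate(**batch, max_new_tokens = 128)
#   print(tokeniser.decode(output_ids[0], skip_special_tokens = True))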