import torch
from datasets import load_dataset
from transformers import (
    Trainer,
    T5Config,
    T5TokenizerFast,
    TrainingArguments,
    DataCollatorForSeq2Seq,
    T5ForConditionalGeneration
)
# Path config
base_model = "t5-small"
data_path = "src/data/clean_corpus.jsonl"
tokeniser_path = "src/tokeniser/"
output_dir = "checkpoints/"
# Load tokeniser
tokeniser = T5TokenizerFast.from_pretrained(tokeniser_path)
vocab_size = tokeniser.vocab_size
pad_token_id = tokeniser.pad_token_id
# Use custom vocab size for the model
config = T5Config(
    vocab_size = vocab_size,
    d_model = 512,
    d_ff = 2048,
    num_layers = 6,
    num_heads = 8,
    pad_token_id = pad_token_id,
    decoder_start_token_id = pad_token_id
)
model = T5ForConditionalGeneration(config)
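# Optional sanity check (an addition, not part of the original script): report
# the size of the freshly initialised model so the custom config can be verified
# at a glance before training starts.
print(f"Initialised DalaT5 with {model.num_parameters():,} parameters.")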
def tokenise_function(example: dict) -> dict:
    """
    Tokenise a batch of transliteration pairs, prefixing each source string
    with the task instruction expected by the model.
    """
    inputs = [f"Cyrillic2Latin: {item['src']}" for item in example["transliteration"]]
    targets = [item["tgt"] for item in example["transliteration"]]
    model_inputs = tokeniser(
        inputs, max_length = 128, truncation = True, padding = "max_length"
    )
    labels = tokeniser(
        targets, max_length = 128, truncation = True, padding = "max_length"
    )["input_ids"]
    # Replace padding token ids in the labels with -100 so the padded
    # positions are ignored by the cross-entropy loss.
    labels = [
        [(token if token != pad_token_id else -100) for token in seq]
        for seq in labels
    ]
    model_inputs["labels"] = labels
    return model_inputs
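# For reference, tokenise_function assumes each row of clean_corpus.jsonl carries
# a "transliteration" field holding a {"src": ..., "tgt": ...} pair. The row below
# is an illustrative placeholder (not taken from the actual corpus) and the check
# is left commented out so it does not affect the training run:
#
# sample = {"transliteration": [{"src": "<Cyrillic text>", "tgt": "<Latin text>"}]}
# encoded = tokenise_function(sample)
# assert len(encoded["input_ids"][0]) == 128 and len(encoded["labels"][0]) == 128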
# Load dataset
dataset = load_dataset("json", data_files = data_path, split = "train")
# Split dataset into train and validation sets (75/25 split)
dataset_split = dataset.train_test_split(test_size = 0.25)
train_dataset = dataset_split["train"]
val_dataset = dataset_split["test"]
# Tokenise datasets
tokenised_train = train_dataset.map(tokenise_function, batched = True, remove_columns = ["transliteration"])
tokenised_eval = val_dataset.map(tokenise_function, batched = True, remove_columns = ["transliteration"])
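# Optional spot check (an addition, not part of the original pipeline): decode the
# first tokenised training example back to text to confirm the task prefix and
# truncation behave as expected. Left commented out by default.
# print(tokeniser.decode(tokenised_train[0]["input_ids"], skip_special_tokens = True))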
# Data collator
data_collator = DataCollatorForSeq2Seq(tokenizer = tokeniser, model = model)
# Training args
training_args = TrainingArguments(
    output_dir = output_dir,
    overwrite_output_dir = True,
    num_train_epochs = 2,
    per_device_train_batch_size = 32,
    gradient_accumulation_steps = 2,
    save_strategy = "steps",
    save_steps = 500,
    save_total_limit = 3,
    eval_strategy = "epoch",
    logging_dir = "logs",
    fp16 = torch.cuda.is_available()
)
# Trainer
trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = tokenised_train,
    eval_dataset = tokenised_eval,
    data_collator = data_collator,
    processing_class = tokeniser
)
# Train
trainer.train()
model.save_pretrained(output_dir)
tokeniser.save_pretrained(output_dir)
print("DalaT5 training complete.")