import argparse

import pandas as pd
import torch
from evaluate import load
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    default_data_collator,
)

from dataset import TextlineDataset
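
# dataset.py is not shown in this file. Based on how it is used below
# (default_data_collator plus compute_metrics), TextlineDataset is assumed to
# return items shaped like {"pixel_values": FloatTensor, "labels": LongTensor},
# with padded label positions set to -100 so they are ignored by the loss.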

parser = argparse.ArgumentParser(description="Fine-tune TrOCR on text-line images.")

parser.add_argument('--root_path', type=str, default="",
                    help='Root path to data files.')
parser.add_argument('--tr_data_path', type=str, default="/path/to/train/data.csv",
                    help='Path to the .csv file containing the training data.')
parser.add_argument('--val_data_path', type=str, default="/path/to/val/data.csv",
                    help='Path to the .csv file containing the validation data.')
parser.add_argument('--output_path', type=str, default="/output/path/",
                    help='Path for saving training results.')
parser.add_argument('--batch_size', type=int, default=24,
                    help='Batch size per device.')
parser.add_argument('--epochs', type=int, default=13,
                    help='Number of training epochs.')

args = parser.parse_args()
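
# Example invocation (script name and paths are placeholders):
#   python train.py --root_path /data/lines \
#       --tr_data_path train.csv --val_data_path val.csv \
#       --output_path ./checkpoints --batch_size 24 --epochs 13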

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Start from the pretrained handwritten TrOCR base checkpoint.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
model.to(device)

cer_metric = load("cer")
wer_metric = load("wer")
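# CER = character-level edit distance / number of reference characters;
# WER is the analogous word-level rate. Lower is better for both.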

train_df = pd.read_csv(args.tr_data_path)
val_df = pd.read_csv(args.val_data_path)

train_df.reset_index(drop=True, inplace=True)
val_df.reset_index(drop=True, inplace=True)
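
# Note: the CSV schema (image path and transcription columns) is interpreted
# inside TextlineDataset; the column names are an implementation detail of
# dataset.py and are not assumed here.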

train_dataset = TextlineDataset(root_dir=args.root_path,
                                df=train_df,
                                processor=processor,
                                augment=False)

eval_dataset = TextlineDataset(root_dir=args.root_path,
                               df=val_df,
                               processor=processor,
                               augment=False)

print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

# Map the decoder's special tokens onto the tokenizer's: the CLS token id is
# used as the decoder-start token, SEP as end-of-sequence.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.sep_token_id

model.config.vocab_size = model.config.decoder.vocab_size

# Beam-search settings used by predict_with_generate during evaluation.
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
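
# Recent transformers releases expect generation settings on
# model.generation_config rather than model.config and may warn otherwise,
# e.g. model.generation_config.num_beams = 4. They are kept on model.config
# here to match the original script.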

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    load_best_model_at_end=True,
    metric_for_best_model="cer",
    greater_is_better=False,
    num_train_epochs=args.epochs,
    save_total_limit=2,
    output_dir=args.output_path,
    optim="adamw_torch",
)
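
# metric_for_best_model="cer" refers to the "cer" key returned by
# compute_metrics below (the Trainer adds the "eval_" prefix internally);
# greater_is_better=False because a lower CER is better. In newer transformers
# versions `evaluation_strategy` has been renamed `eval_strategy`.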

def compute_metrics(pred):
    """Compute CER and WER on a batch of generated predictions."""
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    # Label padding is stored as -100 (ignored by the loss); map it back to the
    # real pad token id so the tokenizer can decode the references.
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer, "wer": wer}

trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)
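
# Passing the feature extractor as `tokenizer` makes the Trainer save the image
# preprocessing config alongside checkpoints. Newer transformers versions
# deprecate the `tokenizer` argument in favor of `processing_class`.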

trainer.train()

# Save the final model and the processor next to the training checkpoints.
model.save_pretrained(args.output_path)
processor.save_pretrained(args.output_path + "/processor")
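
# To reload the fine-tuned model for inference (a sketch):
#   model = VisionEncoderDecoderModel.from_pretrained(args.output_path)
#   processor = TrOCRProcessor.from_pretrained(args.output_path + "/processor")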