# court-records-htr / train_trocr.py
import pandas as pd
import torch
import argparse
from evaluate import load
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
from dataset import TextlineDataset

parser = argparse.ArgumentParser('arguments for the code')
parser.add_argument('--root_path', type=str, default="",
                    help='Root path to data files.')
parser.add_argument('--tr_data_path', type=str, default="/path/to/train/data.csv",
                    help='Path to .csv file containing the training data.')
parser.add_argument('--val_data_path', type=str, default="/path/to/val/data.csv",
                    help='Path to .csv file containing the validation data.')
parser.add_argument('--output_path', type=str, default="/output/path/",
                    help='Path for saving training results.')
parser.add_argument('--batch_size', type=int, default=24,
                    help='Batch size per device.')
parser.add_argument('--epochs', type=int, default=13,
                    help='Number of training epochs.')
args = parser.parse_args()
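
# Example invocation (illustrative paths, not from the original repo):
#   python train_trocr.py --root_path /data/textlines \
#       --tr_data_path /data/train.csv --val_data_path /data/val.csv \
#       --output_path ./trocr-output --batch_size 24 --epochs 13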
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device: ', device)
# Initialize processor and model
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
model.to(device)
# Initialize metrics
cer_metric = load("cer")
wer_metric = load("wer")
# Load train and validation data to dataframes
train_df = pd.read_csv(args.tr_data_path)
val_df = pd.read_csv(args.val_data_path)
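# Each CSV is assumed to hold the text line image paths and transcriptions
# that TextlineDataset (see dataset.py) expects; the column names are defined there.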
# Reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
val_df.reset_index(drop=True, inplace=True)
# Create train and validation datasets
train_dataset = TextlineDataset(root_dir=args.root_path,
                                df=train_df,
                                processor=processor,
                                augment=False)
eval_dataset = TextlineDataset(root_dir=args.root_path,
                               df=val_df,
                               processor=processor,
                               augment=False)
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))
# Define model configuration
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
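# Note: these config values are read by model.generate() during evaluation
# (predict_with_generate=True below), i.e. 4-beam search capped at 64 tokens per line.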
# Set arguments for model training
# For all arguments see https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    load_best_model_at_end=True,
    metric_for_best_model='cer',
    greater_is_better=False,
    #fp16=True,
    num_train_epochs=args.epochs,
    save_total_limit=2,
    output_dir=args.output_path,
    optim="adamw_torch",
)
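# 'cer' refers to the key returned by compute_metrics below; the Trainer logs it
# as eval_cer and keeps the checkpoint with the lowest value (greater_is_better=False).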
# Function for computing CER and WER metrics for the prediction results
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    # Replace -100 (positions ignored by the loss) with the pad token id before decoding
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer, "wer": wer}
# Instantiate trainer
# For all parameters see: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)
# Train the model
trainer.train()
#trainer.train(resume_from_checkpoint = True)
model.save_pretrained(args.output_path)
processor.save_pretrained(args.output_path + "/processor")
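
# To reload the fine-tuned model and processor later (sketch):
#   model = VisionEncoderDecoderModel.from_pretrained(args.output_path)
#   processor = TrOCRProcessor.from_pretrained(args.output_path + "/processor")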