# PolSOL / train.py: fine-tune TrOCR (microsoft/trocr-base-handwritten) on handwritten mathematical expressions converted to LaTeX.
import os
from datasets import load_dataset
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
# Skip training if a saved model is already present.
if os.path.exists("trained_model"):
    print("✅ Model already exists. Skipping training.")
else:
    print("🚀 Starting training...")

    # Load a small subset (100 samples) of a handwritten-expression-to-LaTeX dataset.
    ds = load_dataset("Azu/Handwritten-Mathematical-Expression-Convert-LaTeX", split="train[:100]")
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    def preprocess(ex):
        # Convert to RGB and let the processor resize/normalize into TrOCR pixel values.
        img = ex["image"].convert("RGB")
        inputs = processor(images=img, return_tensors="pt")
        labels = processor.tokenizer(
            ex["label"], truncation=True, padding="max_length", max_length=128
        ).input_ids
        # Replace padding token ids with -100 so they are ignored by the cross-entropy loss.
        labels = [l if l != processor.tokenizer.pad_token_id else -100 for l in labels]
        ex["pixel_values"] = inputs.pixel_values[0]
        ex["labels"] = labels
        return ex

    ds = ds.map(preprocess, remove_columns=["image", "label"])
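
    # Optional sanity check: a minimal sketch, assuming TrOCR-base's 3x384x384 input
    # resolution and the max_length=128 label padding used above.
    sample = ds[0]
    assert len(sample["pixel_values"]) == 3   # channel dimension
    assert len(sample["labels"]) == 128       # padded label length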
    model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
    # The decoder needs explicit start and pad token ids for generation and loss masking.
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    training_args = Seq2SeqTrainingArguments(
        output_dir="trained_model",
        per_device_train_batch_size=2,
        num_train_epochs=1,
        learning_rate=5e-5,
        logging_steps=10,
        save_steps=500,
        fp16=False,
        push_to_hub=False,
    )
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=ds,
        tokenizer=processor.tokenizer,
        data_collator=default_data_collator,
    )
    trainer.train()
    print("✅ Training completed")

    # Save both the model weights and the processor so inference can reload them together.
    model.save_pretrained("trained_model")
    processor.save_pretrained("trained_model")
    print("✅ Model saved to trained_model/")