import os
from datasets import load_dataset
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
if os.path.exists("trained_model"):
    print("Model already exists. Skipping training.")
else:
    print("Starting training...")
    # Small demo subset: only the first 100 training examples
    ds = load_dataset("Azu/Handwritten-Mathematical-Expression-Convert-LaTeX", split="train[:100]")
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

    def preprocess(ex):
        img = ex["image"].convert("RGB")
        inputs = processor(images=img, return_tensors="pt")
        labels = processor.tokenizer(
            ex["label"], truncation=True, padding="max_length", max_length=128
        ).input_ids
        # Replace padding token ids with -100 so they are ignored by the loss
        labels = [tok if tok != processor.tokenizer.pad_token_id else -100 for tok in labels]
        ex["pixel_values"] = inputs.pixel_values[0]
        ex["labels"] = labels
        return ex

    ds = ds.map(preprocess, remove_columns=["image", "label"])

    model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    training_args = Seq2SeqTrainingArguments(
        output_dir="trained_model",
        per_device_train_batch_size=2,
        num_train_epochs=1,
        learning_rate=5e-5,
        logging_steps=10,
        save_steps=500,
        fp16=False,
        push_to_hub=False,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=ds,
        tokenizer=processor.tokenizer,
        data_collator=default_data_collator,
    )

    trainer.train()
    print("Training completed")

    model.save_pretrained("trained_model")
    processor.save_pretrained("trained_model")
    print("Model saved to trained_model/")
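
# Optional sanity check (illustrative sketch, not part of the original training flow):
# load the saved model and processor and run inference on a local image. The file
# name "sample.png" is a hypothetical placeholder; the block is skipped if the
# image or the trained model directory is missing.
if os.path.exists("trained_model") and os.path.exists("sample.png"):
    from PIL import Image

    processor = TrOCRProcessor.from_pretrained("trained_model")
    model = VisionEncoderDecoderModel.from_pretrained("trained_model")
    image = Image.open("sample.png").convert("RGB")
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values, max_length=128)
    print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])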