MasteredUltraInstinct committed on
Commit
bdd2112
·
verified ·
1 Parent(s): 326ef6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -7
app.py CHANGED
@@ -1,26 +1,39 @@
1
  from datasets import load_dataset
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
3
 
4
- # Load the handwritten math dataset
5
  ds = load_dataset("Azu/Handwritten-Mathematical-Expression-Convert-LaTeX", split="train[:1000]")
6
 
 
7
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
 
8
 
 
9
  def preprocess(ex):
10
  img = ex["image"].convert("RGB")
11
  inputs = processor(images=img, return_tensors="pt")
12
- labels = processor.tokenizer(ex["label"], truncation=True, padding="max_length", max_length=128).input_ids
 
 
 
 
 
 
 
 
 
13
  ex["pixel_values"] = inputs.pixel_values[0]
14
  ex["labels"] = labels
15
  return ex
16
 
 
17
  ds = ds.map(preprocess, remove_columns=["image", "label"])
18
 
19
-
20
- model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
21
  model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
22
  model.config.pad_token_id = processor.tokenizer.pad_token_id
23
 
 
24
  training_args = Seq2SeqTrainingArguments(
25
  output_dir="trained_model",
26
  per_device_train_batch_size=2,
@@ -32,6 +45,7 @@ training_args = Seq2SeqTrainingArguments(
32
  push_to_hub=False,
33
  )
34
 
 
35
  trainer = Seq2SeqTrainer(
36
  model=model,
37
  args=training_args,
@@ -40,6 +54,10 @@ trainer = Seq2SeqTrainer(
40
  data_collator=default_data_collator,
41
  )
42
 
43
- trainer.train()
44
- model.save_pretrained("trained_model")
45
- processor.save_pretrained("trained_model")
 
 
 
 
 
1
  from datasets import load_dataset
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
3
 
4
# Dataset of handwritten mathematical expressions paired with LaTeX labels.
DATASET_ID = "Azu/Handwritten-Mathematical-Expression-Convert-LaTeX"
# TrOCR checkpoint used for both the processor and the model.
CHECKPOINT = "microsoft/trocr-base-handwritten"

# Keep only the first 1000 training examples for a quick fine-tune.
ds = load_dataset(DATASET_ID, split="train[:1000]")

# Processor bundles the image feature extractor and the tokenizer;
# model is the matching vision-encoder/decoder checkpoint.
processor = TrOCRProcessor.from_pretrained(CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)
10
 
11
# Preprocess function
def preprocess(ex):
    """Convert one raw dataset example into TrOCR training features.

    Adds ``pixel_values`` (encoder input) and ``labels`` (token ids with
    padding masked to -100) to the example dict and returns it.
    """
    img = ex["image"].convert("RGB")
    inputs = processor(images=img, return_tensors="pt")

    # The "label" column may hold the LaTeX string directly, or a
    # ClassLabel index into the feature's name table — the previous
    # revision tokenized it as a string, this one assumed an index.
    # Handle both; TODO confirm the actual dataset schema.
    raw_label = ex["label"]
    if isinstance(raw_label, str):
        label_str = raw_label
    else:
        label_str = ds.features["label"].int2str(raw_label)

    labels = processor.tokenizer(
        label_str,
        truncation=True,
        padding="max_length",
        max_length=128
    ).input_ids
    # Replace pad-token ids with -100 so the cross-entropy loss ignores
    # padding (required because default_data_collator does no masking).
    pad_id = processor.tokenizer.pad_token_id
    labels = [tok if tok != pad_id else -100 for tok in labels]

    ex["pixel_values"] = inputs.pixel_values[0]
    ex["labels"] = labels
    return ex
28
 
29
# Materialize pixel_values/labels for every example and drop the raw columns.
ds = ds.map(preprocess, remove_columns=["image", "label"])

# Seq2seq generation needs the decoder-start and pad token ids copied
# over from the tokenizer into the model config.
tok = processor.tokenizer
model.config.decoder_start_token_id = tok.cls_token_id
model.config.pad_token_id = tok.pad_token_id
35
 
36
+ # Training arguments
37
  training_args = Seq2SeqTrainingArguments(
38
  output_dir="trained_model",
39
  per_device_train_batch_size=2,
 
45
  push_to_hub=False,
46
  )
47
 
48
+ # Trainer
49
  trainer = Seq2SeqTrainer(
50
  model=model,
51
  args=training_args,
 
54
  data_collator=default_data_collator,
55
  )
56
 
57
# Train and save
if __name__ == "__main__":
    # NOTE(review): the status messages were mojibake ("πŸš€"/"βœ…" —
    # double-decoded UTF-8); restored to the intended characters.
    print("🚀 Training started")
    trainer.train()
    print("✅ Training completed")
    # Save the model weights and the processor together so inference
    # can reload both from the same directory.
    model.save_pretrained("trained_model")
    processor.save_pretrained("trained_model")