d-delaurier committed
Commit 0fcfeda · 1 Parent(s): 9a2d48a

Create tensorflow_train.py

Files changed (1)
  1. tensorflow_train.py +424 -0
tensorflow_train.py ADDED
@@ -0,0 +1,424 @@
# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import os

os.environ["USE_TF"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import datetime
import hashlib
import multiprocessing as mp
import time

import numpy as np
import psutil
import tensorflow as tf
from tensorflow.keras import mixed_precision
from tqdm.auto import tqdm

from doctr.models import login_to_hub, push_to_hf_hub

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
if any(gpu_devices):
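    # Memory growth makes TF allocate GPU memory on demand instead of claiming
    # the whole device upfront, so the script can share the GPU with other processes.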
    tf.config.experimental.set_memory_growth(gpu_devices[0], True)

from doctr import transforms as T
from doctr.datasets import DataLoader, DetectionDataset
from doctr.models import detection
from doctr.utils.metrics import LocalizationConfusion
from utils import EarlyStopper, load_backbone, plot_recorder, plot_samples


def record_lr(
    model: tf.keras.Model,
    train_loader: DataLoader,
    batch_transforms,
    optimizer,
    start_lr: float = 1e-7,
    end_lr: float = 1,
    num_it: int = 100,
    amp: bool = False,
):
    """Gridsearch the optimal learning rate for the training.
    Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py
    """
    if num_it > len(train_loader):
        raise ValueError("the value of `num_it` needs to be lower than the number of available batches")

    # Update param groups & LR
    gamma = (end_lr / start_lr) ** (1 / (num_it - 1))
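    # Geometric sweep: lr_i = start_lr * gamma**i, so the LR ramps exactly from
    # start_lr to end_lr over the num_it recorded iterations.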
    optimizer.learning_rate = start_lr

    lr_recorder = [start_lr * gamma**idx for idx in range(num_it)]
    loss_recorder = []

    for batch_idx, (images, targets) in enumerate(train_loader):
        images = batch_transforms(images)

        # Forward, Backward & update
        with tf.GradientTape() as tape:
            train_loss = model(images, targets, training=True)["loss"]
        grads = tape.gradient(train_loss, model.trainable_weights)

        if amp:
            grads = optimizer.get_unscaled_gradients(grads)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        optimizer.learning_rate = optimizer.learning_rate * gamma

        # Record
        train_loss = train_loss.numpy()
        if np.any(np.isnan(train_loss)):
            if batch_idx == 0:
                raise ValueError("loss value is NaN or inf.")
            else:
                break
        loss_recorder.append(train_loss.mean())
        # Stop after the number of iterations
        if batch_idx + 1 == num_it:
            break

    return lr_recorder[: len(loss_recorder)], loss_recorder


def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False):
    train_iter = iter(train_loader)
    # Iterate over the batches of the dataset
    pbar = tqdm(train_iter, position=1)
    for images, targets in pbar:
        images = batch_transforms(images)

        with tf.GradientTape() as tape:
            train_loss = model(images, targets, training=True)["loss"]
        grads = tape.gradient(train_loss, model.trainable_weights)
        if amp:
            grads = optimizer.get_unscaled_gradients(grads)
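        # Under AMP, the LossScaleOptimizer scales the loss up to keep float16
        # gradients from underflowing, so they must be unscaled before the update.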
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        pbar.set_description(f"Training loss: {train_loss.numpy():.6}")


def evaluate(model, val_loader, batch_transforms, val_metric):
    # Reset val metric
    val_metric.reset()
    # Validation loop
    val_loss, batch_cnt = 0, 0
    val_iter = iter(val_loader)
    for images, targets in tqdm(val_iter):
        images = batch_transforms(images)
        out = model(images, targets, training=False, return_preds=True)
        # Compute metric
        loc_preds = out["preds"]
        for target, loc_pred in zip(targets, loc_preds):
            for boxes_gt, boxes_pred in zip(target.values(), loc_pred.values()):
                if args.rotation and args.eval_straight:
                    # Convert pred to boxes [xmin, ymin, xmax, ymax] N, 4, 2 --> N, 4
                    boxes_pred = np.concatenate((boxes_pred.min(axis=1), boxes_pred.max(axis=1)), axis=-1)
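                    # (per-polygon min over the 4 corners yields xmin/ymin and the max
                    # yields xmax/ymax, i.e. the tightest axis-aligned enclosing box)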
                val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :4])

        val_loss += out["loss"].numpy()
        batch_cnt += 1

    val_loss /= batch_cnt
    recall, precision, mean_iou = val_metric.summary()
    return val_loss, recall, precision, mean_iou


def main(args):
    print(args)

    if args.push_to_hub:
        login_to_hub()

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    system_available_memory = int(psutil.virtual_memory().available / 1024**3)

    # AMP
    if args.amp:
        mixed_precision.set_global_policy("mixed_float16")

    st = time.time()
    val_set = DetectionDataset(
        img_folder=os.path.join(args.val_path, "images"),
        label_path=os.path.join(args.val_path, "labels.json"),
        sample_transforms=T.SampleCompose(
            (
                [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
                if not args.rotation or args.eval_straight
                else []
            )
            + (
                [
                    T.Resize(args.input_size, preserve_aspect_ratio=True),  # This does not pad
                    T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
                    T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
                ]
                if args.rotation and not args.eval_straight
                else []
            )
        ),
        use_polygons=args.rotation and not args.eval_straight,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
    )
    print(
        f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
        f"{val_loader.num_batches} batches)"
    )
    with open(os.path.join(args.val_path, "labels.json"), "rb") as f:
        val_hash = hashlib.sha256(f.read()).hexdigest()
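    # Hashing labels.json fingerprints the exact dataset revision; the digest is
    # logged in the run config below so results stay traceable to their data.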

    batch_transforms = T.Compose([
        T.Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)),
    ])

    # Load doctr model
    model = detection.__dict__[args.arch](
        pretrained=args.pretrained,
        input_shape=(args.input_size, args.input_size, 3),
        assume_straight_pages=not args.rotation,
        class_names=val_set.class_names,
    )

    # Resume weights
    if isinstance(args.resume, str):
        model.load_weights(args.resume)

    if isinstance(args.pretrained_backbone, str):
        print("Loading backbone weights.")
        model = load_backbone(model, args.pretrained_backbone)
        print("Done.")

    # Metrics
    val_metric = LocalizationConfusion(
        use_polygons=args.rotation and not args.eval_straight,
        mask_shape=(args.input_size, args.input_size),
        use_broadcasting=system_available_memory > 62,
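        # use_broadcasting presumably trades a large intermediate tensor for a fully
        # vectorized IoU computation, hence the ~62 GB free-RAM gate above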
    )
    if args.test_only:
        print("Running evaluation")
        val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric)
        print(
            f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | "
            f"Mean IoU: {mean_iou:.2%})"
        )
        return

    st = time.time()
    # Load both train and val data generators
    train_set = DetectionDataset(
        img_folder=os.path.join(args.train_path, "images"),
        label_path=os.path.join(args.train_path, "labels.json"),
        img_transforms=T.Compose([
            # Augmentations
            T.RandomApply(T.ColorInversion(), 0.1),
            T.RandomJpegQuality(60),
            T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
            T.RandomApply(T.RandomShadow(), 0.4),
            T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.3),
            T.RandomSaturation(0.3),
            T.RandomContrast(0.3),
            T.RandomBrightness(0.3),
            T.RandomApply(T.ToGray(num_output_channels=3), 0.1),
        ]),
        sample_transforms=T.SampleCompose(
            (
                [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
                if not args.rotation
                else []
            )
            + (
                [
                    T.Resize(args.input_size, preserve_aspect_ratio=True),  # This does not pad
                    T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
                    T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
                ]
                if args.rotation
                else []
            )
        ),
        use_polygons=args.rotation,
    )
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
    )
    print(
        f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
        f"{train_loader.num_batches} batches)"
    )
    with open(os.path.join(args.train_path, "labels.json"), "rb") as f:
        train_hash = hashlib.sha256(f.read()).hexdigest()

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target)
        return

    # Optimizer
    scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
        args.lr,
        decay_steps=args.epochs * len(train_loader),
        decay_rate=1 / (25e4),  # final lr as a fraction of initial lr
        staircase=False,
        name="ExponentialDecay",
    )
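    # With staircase=False the decay is continuous: lr(step) = args.lr * (1 / 25e4) ** (step / decay_steps),
    # so the LR ends the run at 1/250000 of its initial value (0.001 -> 4e-9 with the defaults).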
    optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler, beta_1=0.95, beta_2=0.99, epsilon=1e-6, clipnorm=5)
    if args.amp:
        optimizer = mixed_precision.LossScaleOptimizer(optimizer)
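        # the wrapper adds dynamic loss scaling on top of the mixed_float16 policy set earlier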
    # LR Finder
    if args.find_lr:
        lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp)
        plot_recorder(lrs, losses)
        return

    # Tensorboard to monitor training
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

    config = {
        "learning_rate": args.lr,
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "architecture": args.arch,
        "input_size": args.input_size,
        "optimizer": optimizer.name,
        "framework": "tensorflow",
        "scheduler": scheduler.name,
        "train_hash": train_hash,
        "val_hash": val_hash,
        "pretrained": args.pretrained,
        "rotation": args.rotation,
    }

    # W&B
    if args.wb:
        import wandb

        run = wandb.init(name=exp_name, project="text-detection", config=config)

    # ClearML
    if args.clearml:
        from clearml import Task

        task = Task.init(project_name="docTR/text-detection", task_name=exp_name, reuse_last_task_id=False)
        task.upload_artifact("config", config)

    if args.freeze_backbone:
        for layer in model.feat_extractor.layers:
            layer.trainable = False

    min_loss = np.inf
    if args.early_stop:
        early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta)
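        # EarlyStopper is defined in the local utils module; presumably it signals a stop
        # once val_loss has not improved by min_delta for `patience` consecutive epochs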

    # Training loop
    for epoch in range(args.epochs):
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, args.amp)
        # Validation loop at the end of each epoch
        val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric)
        if val_loss < min_loss:
            print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
            model.save_weights(f"./{exp_name}/weights")
            min_loss = val_loss
        log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
        if any(val is None for val in (recall, precision, mean_iou)):
            log_msg += "(Undefined metric value, caused by empty GTs or predictions)"
        else:
            log_msg += f"(Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%})"
        print(log_msg)
        # W&B
        if args.wb:
            wandb.log({
                "val_loss": val_loss,
                "recall": recall,
                "precision": precision,
                "mean_iou": mean_iou,
            })

        # ClearML
        if args.clearml:
            from clearml import Logger

            logger = Logger.current_logger()
            logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch)
            logger.report_scalar(title="Precision Recall", series="recall", value=recall, iteration=epoch)
            logger.report_scalar(title="Precision Recall", series="precision", value=precision, iteration=epoch)
            logger.report_scalar(title="Mean IoU", series="mean_iou", value=mean_iou, iteration=epoch)
        if args.early_stop and early_stopper.early_stop(val_loss):
            print("Training halted early due to reaching patience limit.")
            break
    if args.wb:
        run.finish()

    if args.push_to_hub:
        push_to_hf_hub(model, exp_name, task="detection", run_config=args)


def parse_args():
    import argparse

    parser = argparse.ArgumentParser(
        description="DocTR training script for text detection (TensorFlow)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument("arch", type=str, help="text-detection model to train")
    parser.add_argument("--train_path", type=str, required=True, help="path to training data folder")
    parser.add_argument("--val_path", type=str, help="path to validation data folder")
    parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
    parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
    parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training")
    parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam)")
    parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading")
    parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint")
    parser.add_argument("--pretrained-backbone", type=str, default=None, help="Path to your backbone weights")
    parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop")
    parser.add_argument(
        "--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning"
    )
    parser.add_argument(
        "--show-samples", dest="show_samples", action="store_true", help="Display unnormalized training samples"
    )
    parser.add_argument("--wb", dest="wb", action="store_true", help="Log to Weights & Biases")
    parser.add_argument("--clearml", dest="clearml", action="store_true", help="Log to ClearML")
    parser.add_argument("--push-to-hub", dest="push_to_hub", action="store_true", help="Push to Huggingface Hub")
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        action="store_true",
        help="Load pretrained parameters before starting the training",
    )
    parser.add_argument("--rotation", dest="rotation", action="store_true", help="train with rotated documents")
    parser.add_argument(
        "--eval-straight",
        action="store_true",
        help="metrics evaluation with straight boxes instead of polygons to save time + memory",
    )
    parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
    parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR")
    parser.add_argument("--early-stop", action="store_true", help="Enable early stopping")
    parser.add_argument("--early-stop-epochs", type=int, default=5, help="Patience for early stopping")
    parser.add_argument("--early-stop-delta", type=float, default=0.01, help="Minimum Delta for early stopping")
    args = parser.parse_args()

    return args


if __name__ == "__main__":
    args = parse_args()
    main(args)
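
# Example invocation (illustrative only; the architecture name and paths are placeholders):
#   python tensorflow_train.py db_resnet50 --train_path ./data/train --val_path ./data/val --epochs 5 -b 4 --amp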