alexandraroze committed on
Commit
b265c62
·
1 Parent(s): 50bd1fc

fixed config

Browse files
Files changed (3) hide show
  1. requirements.txt +9 -0
  2. train_byol.py +18 -31
  3. train_cross_classifier.py +22 -35
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch~=2.6.0
2
+ numpy~=2.2.3
3
+ torchvision~=0.21.0
4
+ pillow~=11.1.0
5
+ streamlit~=1.42.0
6
+ wandb~=0.19.6
7
+ tqdm~=4.67.1
8
+ typing_extensions~=4.12.2
9
+ matplotlib~=3.10.0
train_byol.py CHANGED
@@ -261,41 +261,28 @@ def main(config: dict):
261
 
262
 
263
  if __name__ == "__main__":
264
- # parser = argparse.ArgumentParser(description="Train BYOL model")
265
- # parser.add_argument("--batch_size", type=int, default=512)
266
- # parser.add_argument("--lr", type=float, default=5e-4)
267
- # parser.add_argument("--num_epochs", type=int, default=15)
268
- # parser.add_argument("--num_train_samples", type=int, default=100000)
269
- # parser.add_argument("--num_val_samples", type=int, default=10000)
270
- # parser.add_argument("--random_intensity", type=int, default=1)
271
- # parser.add_argument("--early_stopping_patience", type=int, default=3)
272
- # parser.add_argument("--save_path", type=str, default="best_byol.pth")
273
- # args = parser.parse_args()
274
-
275
- # config = {
276
- # "batch_size": args.batch_size,
277
- # "lr": args.lr,
278
- # "num_epochs": args.num_epochs,
279
- # "num_train_samples": args.num_train_samples,
280
- # "num_val_samples": args.num_val_samples,
281
- # "shape_params": {
282
- # "random_intensity": bool(args.random_intensity)
283
- # },
284
- # "early_stopping_patience": args.early_stopping_patience,
285
- # "save_path": args.save_path
286
- # }
287
 
288
  config = {
289
- "batch_size": 1024,
290
- "lr": 1e-3,
291
- "num_epochs": 15,
292
- "num_train_samples": 100000,
293
- "num_val_samples": 10000,
294
  "shape_params": {
295
- "random_intensity": True
296
  },
297
- "early_stopping_patience": 3,
298
- "save_path": "best_byol.pth"
299
  }
300
 
301
  main(config)
 
261
 
262
 
263
  if __name__ == "__main__":
264
+ parser = argparse.ArgumentParser(description="Train BYOL model")
265
+ parser.add_argument("--batch_size", type=int, default=512)
266
+ parser.add_argument("--lr", type=float, default=5e-4)
267
+ parser.add_argument("--num_epochs", type=int, default=15)
268
+ parser.add_argument("--num_train_samples", type=int, default=100000)
269
+ parser.add_argument("--num_val_samples", type=int, default=10000)
270
+ parser.add_argument("--random_intensity", type=int, default=1)
271
+ parser.add_argument("--early_stopping_patience", type=int, default=3)
272
+ parser.add_argument("--save_path", type=str, default="best_byol.pth")
273
+ args = parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
  config = {
276
+ "batch_size": args.batch_size,
277
+ "lr": args.lr,
278
+ "num_epochs": args.num_epochs,
279
+ "num_train_samples": args.num_train_samples,
280
+ "num_val_samples": args.num_val_samples,
281
  "shape_params": {
282
+ "random_intensity": bool(args.random_intensity)
283
  },
284
+ "early_stopping_patience": args.early_stopping_patience,
285
+ "save_path": args.save_path
286
  }
287
 
288
  main(config)
train_cross_classifier.py CHANGED
@@ -250,43 +250,30 @@ def main(config):
250
 
251
  if __name__ == "__main__":
252
 
253
- # parser = argparse.ArgumentParser(description="Train classifier model")
254
- # parser.add_argument("--path_to_encoder", type=str, default="best_byol.pth")
255
- # parser.add_argument("--batch_size", type=int, default=256)
256
- # parser.add_argument("--lr", type=float, default=8e-5)
257
- # parser.add_argument("--weight_decay", type=float, default=1e-4)
258
- # parser.add_argument("--step_size", type=int, default=10)
259
- # parser.add_argument("--gamma", type=float, default=0.1)
260
- # parser.add_argument("--num_epochs", type=int, default=10)
261
- # parser.add_argument("--num_train_samples", type=int, default=10000)
262
- # parser.add_argument("--num_val_samples", type=int, default=2000)
263
- # parser.add_argument("--save_path", type=str, default="best_attention_classifier.pth")
264
- # args = parser.parse_args()
265
-
266
- # config = {
267
- # "path_to_encoder": args.path_to_encoder,
268
- # "batch_size": args.batch_size,
269
- # "lr": args.lr,
270
- # "weight_decay": args.weight_decay,
271
- # "step_size": args.step_size,
272
- # "gamma": args.gamma,
273
- # "num_epochs": args.num_epochs,
274
- # "num_train_samples": args.num_train_samples,
275
- # "num_val_samples": args.num_val_samples,
276
- # "save_path": args.save_path,
277
- # }
278
 
279
  config = {
280
- "path_to_encoder": "best_byol.pth",
281
- "batch_size": 256,
282
- "lr": 8e-5,
283
- "weight_decay": 1e-4,
284
- "step_size": 10,
285
- "gamma": 0.1,
286
- "num_epochs": 10,
287
- "num_train_samples": 10000,
288
- "num_val_samples": 2000,
289
- "save_path": "best_attention_classifier.pth",
290
  }
291
 
292
  if "shape_params" not in config:
 
250
 
251
  if __name__ == "__main__":
252
 
253
+ parser = argparse.ArgumentParser(description="Train classifier model")
254
+ parser.add_argument("--path_to_encoder", type=str, default="best_byol.pth")
255
+ parser.add_argument("--batch_size", type=int, default=256)
256
+ parser.add_argument("--lr", type=float, default=8e-5)
257
+ parser.add_argument("--weight_decay", type=float, default=1e-4)
258
+ parser.add_argument("--step_size", type=int, default=10)
259
+ parser.add_argument("--gamma", type=float, default=0.1)
260
+ parser.add_argument("--num_epochs", type=int, default=10)
261
+ parser.add_argument("--num_train_samples", type=int, default=10000)
262
+ parser.add_argument("--num_val_samples", type=int, default=2000)
263
+ parser.add_argument("--save_path", type=str, default="best_attention_classifier.pth")
264
+ args = parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
  config = {
267
+ "path_to_encoder": args.path_to_encoder,
268
+ "batch_size": args.batch_size,
269
+ "lr": args.lr,
270
+ "weight_decay": args.weight_decay,
271
+ "step_size": args.step_size,
272
+ "gamma": args.gamma,
273
+ "num_epochs": args.num_epochs,
274
+ "num_train_samples": args.num_train_samples,
275
+ "num_val_samples": args.num_val_samples,
276
+ "save_path": args.save_path,
277
  }
278
 
279
  if "shape_params" not in config: