Commit b265c62 · Parent: 50bd1fc
fixed config
Files changed:
- requirements.txt (+9 / -0)
- train_byol.py (+18 / -31)
- train_cross_classifier.py (+22 / -35)
requirements.txt
ADDED
@@ -0,0 +1,9 @@
+torch~=2.6.0
+numpy~=2.2.3
+torchvision~=0.21.0
+pillow~=11.1.0
+streamlit~=1.42.0
+wandb~=0.19.6
+tqdm~=4.67.1
+typing_extensions~=4.12.2
+matplotlib~=3.10.0
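All nine dependencies are pinned with compatible-release specifiers (~=), so torch~=2.6.0, for instance, accepts any 2.6.x patch release but not 2.7. Assuming a standard Python environment for the Space, the pinned set installs with:

pip install -r requirements.txt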
train_byol.py
CHANGED
@@ -261,41 +261,28 @@ def main(config: dict):
 
 
 if __name__ == "__main__":
-
-
-
-
-
-
-
-
-
-
-
-    # config = {
-    #     "batch_size": args.batch_size,
-    #     "lr": args.lr,
-    #     "num_epochs": args.num_epochs,
-    #     "num_train_samples": args.num_train_samples,
-    #     "num_val_samples": args.num_val_samples,
-    #     "shape_params": {
-    #         "random_intensity": bool(args.random_intensity)
-    #     },
-    #     "early_stopping_patience": args.early_stopping_patience,
-    #     "save_path": args.save_path
-    # }
+    parser = argparse.ArgumentParser(description="Train BYOL model")
+    parser.add_argument("--batch_size", type=int, default=512)
+    parser.add_argument("--lr", type=float, default=5e-4)
+    parser.add_argument("--num_epochs", type=int, default=15)
+    parser.add_argument("--num_train_samples", type=int, default=100000)
+    parser.add_argument("--num_val_samples", type=int, default=10000)
+    parser.add_argument("--random_intensity", type=int, default=1)
+    parser.add_argument("--early_stopping_patience", type=int, default=3)
+    parser.add_argument("--save_path", type=str, default="best_byol.pth")
+    args = parser.parse_args()
 
     config = {
-        "batch_size":
-        "lr":
-        "num_epochs":
-        "num_train_samples":
-        "num_val_samples":
+        "batch_size": args.batch_size,
+        "lr": args.lr,
+        "num_epochs": args.num_epochs,
+        "num_train_samples": args.num_train_samples,
+        "num_val_samples": args.num_val_samples,
         "shape_params": {
-            "random_intensity":
+            "random_intensity": bool(args.random_intensity)
         },
-        "early_stopping_patience":
-        "save_path":
+        "early_stopping_patience": args.early_stopping_patience,
+        "save_path": args.save_path
     }
 
     main(config)
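With this change the BYOL training hyperparameters come from argparse flags instead of a hard-coded dict (the file is assumed to already import argparse, since the diff adds no import). An illustrative invocation that simply spells out the declared defaults:

python train_byol.py --batch_size 512 --lr 5e-4 --num_epochs 15 --num_train_samples 100000 --num_val_samples 10000 --random_intensity 1 --early_stopping_patience 3 --save_path best_byol.pth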
train_cross_classifier.py
CHANGED
@@ -250,43 +250,30 @@ def main(config):
 
 if __name__ == "__main__":
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-    # config = {
-    #     "path_to_encoder": args.path_to_encoder,
-    #     "batch_size": args.batch_size,
-    #     "lr": args.lr,
-    #     "weight_decay": args.weight_decay,
-    #     "step_size": args.step_size,
-    #     "gamma": args.gamma,
-    #     "num_epochs": args.num_epochs,
-    #     "num_train_samples": args.num_train_samples,
-    #     "num_val_samples": args.num_val_samples,
-    #     "save_path": args.save_path,
-    # }
+    parser = argparse.ArgumentParser(description="Train classifier model")
+    parser.add_argument("--path_to_encoder", type=str, default="best_byol.pth")
+    parser.add_argument("--batch_size", type=int, default=256)
+    parser.add_argument("--lr", type=float, default=8e-5)
+    parser.add_argument("--weight_decay", type=float, default=1e-4)
+    parser.add_argument("--step_size", type=int, default=10)
+    parser.add_argument("--gamma", type=float, default=0.1)
+    parser.add_argument("--num_epochs", type=int, default=10)
+    parser.add_argument("--num_train_samples", type=int, default=10000)
+    parser.add_argument("--num_val_samples", type=int, default=2000)
+    parser.add_argument("--save_path", type=str, default="best_attention_classifier.pth")
+    args = parser.parse_args()
 
     config = {
-        "path_to_encoder":
-        "batch_size":
-        "lr":
-        "weight_decay":
-        "step_size":
-        "gamma":
-        "num_epochs":
-        "num_train_samples":
-        "num_val_samples":
-        "save_path":
+        "path_to_encoder": args.path_to_encoder,
+        "batch_size": args.batch_size,
+        "lr": args.lr,
+        "weight_decay": args.weight_decay,
+        "step_size": args.step_size,
+        "gamma": args.gamma,
+        "num_epochs": args.num_epochs,
+        "num_train_samples": args.num_train_samples,
+        "num_val_samples": args.num_val_samples,
+        "save_path": args.save_path,
     }
 
     if "shape_params" not in config:
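The classifier script is converted the same way, and its --path_to_encoder flag points at the checkpoint written by train_byol.py. A sketch of the resulting two-step run, assuming both scripts are launched from the repository root and argparse is already imported in each file:

python train_byol.py --save_path best_byol.pth
python train_cross_classifier.py --path_to_encoder best_byol.pth --batch_size 256 --lr 8e-5 --num_epochs 10 --save_path best_attention_classifier.pth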