Pranjal2041
commited on
Commit
·
970a7a2
1
Parent(s):
031b2c5
Initial demo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- .gitignore +2 -0
- DenseMammogram/.gitignore +2 -0
- DenseMammogram/README.md +54 -0
- DenseMammogram/advanced_config.py +36 -0
- DenseMammogram/advanced_logger.py +48 -0
- DenseMammogram/all_graphs.py +156 -0
- DenseMammogram/auc_by_pranjal.py +120 -0
- DenseMammogram/dataloaders.py +259 -0
- DenseMammogram/detection/README.md +81 -0
- DenseMammogram/detection/coco_eval.py +191 -0
- DenseMammogram/detection/coco_utils.py +249 -0
- DenseMammogram/detection/engine.py +276 -0
- DenseMammogram/detection/group_by_aspect_ratio.py +196 -0
- DenseMammogram/detection/presets.py +47 -0
- DenseMammogram/detection/train.py +269 -0
- DenseMammogram/detection/transforms.py +283 -0
- DenseMammogram/detection/utils.py +282 -0
- DenseMammogram/ensemble_boxes/__init__.py +9 -0
- DenseMammogram/ensemble_boxes/ensemble_boxes_nms.py +249 -0
- DenseMammogram/ensemble_boxes/ensemble_boxes_nmw.py +202 -0
- DenseMammogram/ensemble_boxes/ensemble_boxes_wbf.py +269 -0
- DenseMammogram/ensemble_boxes/ensemble_boxes_wbf_3d.py +222 -0
- DenseMammogram/experimenter.py +213 -0
- DenseMammogram/froc_by_pranjal.py +236 -0
- DenseMammogram/geenerate_aiims.py +61 -0
- DenseMammogram/geenerate_ddsm_preds.py +61 -0
- DenseMammogram/geenerate_inbreast_preds.py +57 -0
- DenseMammogram/geenerate_irch.py +62 -0
- DenseMammogram/merge_predictions.py +152 -0
- DenseMammogram/model_utils.py +83 -0
- DenseMammogram/models.py +201 -0
- DenseMammogram/plot_froc.py +43 -0
- DenseMammogram/requirements.txt +11 -0
- DenseMammogram/train_bilateral.py +47 -0
- DenseMammogram/train_frcnn.py +34 -0
- DenseMammogram/utils.py +41 -0
- app.py +117 -0
- img_out1.jpg +3 -0
- img_out2.jpg +3 -0
- model.py +57 -0
- pretrained_models/AIIMS_C1/frcnn_models/frcnn_model.pth +3 -0
- pretrained_models/AIIMS_C2/frcnn_models/frcnn_model.pth +3 -0
- pretrained_models/AIIMS_C3/frcnn_models/frcnn_model.pth +3 -0
- pretrained_models/AIIMS_C4/frcnn_models/frcnn_model.pth +3 -0
- pretrained_models/AIIMS_T1/frcnn_models/frcnn_model.pth +3 -0
- pretrained_models/AIIMS_T2/frcnn_models/frcnn_model.pth +3 -0
- pretrained_models/BILATERAL/bilateral_models/bilateral_model.pth +3 -0
- pretrained_models/frcnn/frcnn_models/frcnn_model.pth +3 -0
- requirements.txt +12 -0
.gitattributes
CHANGED
@@ -6,6 +6,7 @@
|
|
6 |
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
*.h5 filter=lfs diff=lfs merge=lfs -text
|
|
|
9 |
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
@@ -32,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
6 |
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
10 |
*.joblib filter=lfs diff=lfs merge=lfs -text
|
11 |
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
12 |
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
pretrained_models filter=lfs diff=lfs merge=lfs -text
|
37 |
+
sample_images filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
Demo.ipynb
|
2 |
+
__pycache__
|
DenseMammogram/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
pretrained_models
|
2 |
+
__pycache__
|
DenseMammogram/README.md
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Deep Learning for Detection of Iso-Sense, Obscure Masses in Mammographically Dense Breasts
|
2 |
+
[![report](https://img.shields.io/badge/arxiv-report-red)](https://arxiv.org/abs/) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/)
|
3 |
+
|
4 |
+
## Introduction
|
5 |
+
Deep Learning for Detection of Iso-Sense, Obscure Masses in Mammographically Dense Breasts is a paper on object detection method for finding malignant masses in breast mammograms. Our model is particularly useful for dense breasts and iso-dense and obscure masses. In this paper we have included code and pretrained weights for the paper along with all the scripts to replicate numbers in the paper(Our private dataset is not included).
|
6 |
+
|
7 |
+
## Getting Started
|
8 |
+
|
9 |
+
|
10 |
+
First clone the repo:
|
11 |
+
```bash
|
12 |
+
git clone https://github.com/Pranjal2041/DenseMammograms.git
|
13 |
+
```
|
14 |
+
|
15 |
+
Next setup the enviornment using `conda` or `virtualenv`:
|
16 |
+
```bash
|
17 |
+
1. conda create -n densebreast python=3.7
|
18 |
+
conda activate densebreast
|
19 |
+
pip install -r requirements.txt
|
20 |
+
|
21 |
+
or
|
22 |
+
|
23 |
+
2. python -m venv densebreast
|
24 |
+
source densebreast/bin/activate
|
25 |
+
pip install -r requirements.txt
|
26 |
+
```
|
27 |
+
|
28 |
+
## Pretrained Weights
|
29 |
+
|
30 |
+
You can download the pretrained models from this [url](https://csciitd-my.sharepoint.com/:f:/g/personal/cs5190443_iitd_ac_in/ElTbduIuI49EougSH05Tb4IBhbc5gXCrlok0X_xvAI196g?e=Ss2eS1) in the current directory.
|
31 |
+
<br>
|
32 |
+
|
33 |
+
## Running the Code
|
34 |
+
|
35 |
+
To generate predictions and FROC graphs using the pretrained models, run:
|
36 |
+
`python all_graphs.py`
|
37 |
+
|
38 |
+
For running individual models on other datasets, geenerate_{dataset}_preds.py have been provided.
|
39 |
+
For example to run predictions on inbreast, run:
|
40 |
+
`python geenerate_inbreast_preds.py`
|
41 |
+
|
42 |
+
|
43 |
+
## Demo
|
44 |
+
|
45 |
+
You can either use **Google Colab Demo** or **Huggingface demo**
|
46 |
+
|
47 |
+
## Citation
|
48 |
+
|
49 |
+
Details Coming Soon!
|
50 |
+
|
51 |
+
## License
|
52 |
+
|
53 |
+
TODO: Add License
|
54 |
+
|
DenseMammogram/advanced_config.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
class AdvancedConfig:
|
5 |
+
|
6 |
+
def save(self, file):
|
7 |
+
os.makedirs(os.path.split(file)[0], exist_ok=True)
|
8 |
+
json.dump(self.config, open(file, 'w'), indent=4)
|
9 |
+
|
10 |
+
def read_cfg(self, file):
|
11 |
+
# Its a json file with comments
|
12 |
+
new_lines = []
|
13 |
+
for line in open(file).readlines():
|
14 |
+
if line.find("#")!=-1:
|
15 |
+
new_lines.append(line[:line.find("#")])
|
16 |
+
else:
|
17 |
+
new_lines.append(line)
|
18 |
+
return json.loads('\n'.join(new_lines))
|
19 |
+
|
20 |
+
|
21 |
+
def merge_config(self, cfg_dict, base_dict):
|
22 |
+
for key in cfg_dict:
|
23 |
+
if key not in base_dict:
|
24 |
+
# Strange, raise an error
|
25 |
+
raise Exception(f'Key {key} not found in base config')
|
26 |
+
if isinstance(cfg_dict[key], dict):
|
27 |
+
base_dict[key] = self.merge_config(cfg_dict[key], base_dict[key])
|
28 |
+
else:
|
29 |
+
base_dict[key] = cfg_dict[key]
|
30 |
+
return base_dict
|
31 |
+
|
32 |
+
def __init__(self, file, base_file = 'configs/default.cfg') -> None:
|
33 |
+
self.default_config = self.read_cfg(base_file)
|
34 |
+
self.new_config = self.read_cfg(file)
|
35 |
+
self.config = self.merge_config(self.new_config, self.default_config)
|
36 |
+
|
DenseMammogram/advanced_logger.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# An Advanced Logger class which writes data in a well formatted manner
|
2 |
+
# to files based on different priorities.
|
3 |
+
|
4 |
+
from enum import Enum
|
5 |
+
import os
|
6 |
+
import datetime
|
7 |
+
import time
|
8 |
+
|
9 |
+
class LogPriority(Enum):
|
10 |
+
"""
|
11 |
+
Enum class for different log priorities.
|
12 |
+
"""
|
13 |
+
LOW = 0
|
14 |
+
MEDIUM = 1
|
15 |
+
HIGH = 2
|
16 |
+
STATS = 3
|
17 |
+
|
18 |
+
class AdvancedLogger:
|
19 |
+
|
20 |
+
def __init__(self, base_dir):
|
21 |
+
self.base_dir = base_dir
|
22 |
+
self.files = []
|
23 |
+
self.file_names = []
|
24 |
+
for p in LogPriority:
|
25 |
+
self.file_names.append(os.path.join(self.base_dir, f'Log_{p.name}' + '.log'))
|
26 |
+
self.files.append(open(self.file_names[-1], 'w'))
|
27 |
+
self.last_log_time = -1
|
28 |
+
|
29 |
+
def flush(self):
|
30 |
+
for f in self.files:
|
31 |
+
f.close()
|
32 |
+
for i in range(len(self.files)):
|
33 |
+
self.files[i] = open(self.file_names[i], 'a')
|
34 |
+
|
35 |
+
def log(self, *args, priority = LogPriority.LOW):
|
36 |
+
to_log = ' '.join(map(str, args))
|
37 |
+
if priority.value <= LogPriority.MEDIUM.value:
|
38 |
+
# Add current time to to_log
|
39 |
+
now = datetime.datetime.now()
|
40 |
+
to_log = f'[{now.strftime("%H:%M:%S")}]: {to_log}'
|
41 |
+
print(to_log)
|
42 |
+
for p in range(priority.value+1):
|
43 |
+
self.files[p].write(to_log + '\n')
|
44 |
+
|
45 |
+
# If time - last_log_time is greater than 10s or Priority is HIGH or above close the file and re-open in append mode
|
46 |
+
if time.time() - self.last_log_time > 10 or priority.value >= LogPriority.HIGH.value:
|
47 |
+
self.flush()
|
48 |
+
self.last_log_time = time.time()
|
DenseMammogram/all_graphs.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from os.path import join
|
3 |
+
from merge_predictions import get_image_dict, apply_merge
|
4 |
+
from froc_by_pranjal import calc_froc_from_dict, pretty_print_fps
|
5 |
+
import numpy as np
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
|
8 |
+
|
9 |
+
OUT_DIR = 'euro_results_auto'
|
10 |
+
numbers_dir = os.path.join(OUT_DIR, 'numbers')
|
11 |
+
graphs_dir = os.path.join(OUT_DIR, 'graphs')
|
12 |
+
|
13 |
+
BASE_FOLDER = '../bilateral_new/MammoDatasets'
|
14 |
+
|
15 |
+
MIN_CLIP_FPI = 0.02
|
16 |
+
def plot_froc(input_files, save_file, TITLE = 'FRCNN vs BILATERAL FROC', SHOW = False, CLIP_FPI = 1.2):
|
17 |
+
for file in input_files:
|
18 |
+
lines = open(file).readlines()
|
19 |
+
x = np.array([float(line.split()[0]) for line in lines])
|
20 |
+
y = np.array([float(line.split()[1]) for line in lines])
|
21 |
+
y = y[x<CLIP_FPI]
|
22 |
+
x = x[x<CLIP_FPI]
|
23 |
+
y = y[MIN_CLIP_FPI<x]
|
24 |
+
x = x[MIN_CLIP_FPI<x]
|
25 |
+
plt.plot(x, y, label = input_files[file])
|
26 |
+
plt.legend()
|
27 |
+
|
28 |
+
plt.title(TITLE)
|
29 |
+
plt.xlabel('Average False Positive Per Image')
|
30 |
+
plt.ylabel('Sensitivity')
|
31 |
+
|
32 |
+
if SHOW:
|
33 |
+
plt.show()
|
34 |
+
plt.savefig(save_file)
|
35 |
+
plt.clf()
|
36 |
+
|
37 |
+
|
38 |
+
dsets = [('AIIMS_highres_reliable', 'AIIMS'), ('IRCHVal', 'IRCHVal')]
|
39 |
+
dsets = dsets[1:]
|
40 |
+
for dset in dsets:
|
41 |
+
test_splits = ['test_2', 'test_dense', 'test_iso'][::-1]
|
42 |
+
for test_split in test_splits:
|
43 |
+
main_dataset = join(BASE_FOLDER, dset[0], test_split)
|
44 |
+
|
45 |
+
contrast_datasets = [join(BASE_FOLDER,f'{dset[1]}_C{i+1}',test_split) for i in range(4)]
|
46 |
+
threshold_datasets = [join(BASE_FOLDER,f'{dset[1]}_T{i+1}',test_split) for i in range(2)]
|
47 |
+
frcnn_preds = 'preds_frcnn_frcnn'
|
48 |
+
contrast_preds = [
|
49 |
+
'preds_frcnn_AIIMS_C1',
|
50 |
+
'preds_frcnn_AIIMS_C2',
|
51 |
+
'preds_frcnn_AIIMS_C3',
|
52 |
+
'preds_frcnn_AIIMS_C4',
|
53 |
+
]
|
54 |
+
bilateral_preds = 'preds_bilateral_BILATERAL'
|
55 |
+
threshold_preds = [
|
56 |
+
'preds_frcnn_AIIMS_T1',
|
57 |
+
'preds_frcnn_AIIMS_T2',
|
58 |
+
]
|
59 |
+
|
60 |
+
input_files = []
|
61 |
+
dataset_paths = [join(main_dataset, '{0}', frcnn_preds)]
|
62 |
+
dataset_paths +=[join(dset, '{0}', preds) for (dset,preds) in zip(contrast_datasets, contrast_preds)]
|
63 |
+
dataset_paths +=[join(dset, '{0}', preds) for (dset,preds) in zip(threshold_datasets, threshold_preds)]
|
64 |
+
dataset_paths +=[join(main_dataset, '{0}', bilateral_preds)]
|
65 |
+
|
66 |
+
|
67 |
+
CONFIGS = {
|
68 |
+
'Baseline' : ('Baseline Model', [0]),
|
69 |
+
'Bilateral' : ('Bilateral Model', [7]),
|
70 |
+
'Contrast' : ('CABD Model', [0,1,2,3,4]),
|
71 |
+
'Threshold' : ('TI Model', [0,5,6]),
|
72 |
+
'Proposed' : ('Proposed Model', [1,2,3,4,5,6,7])
|
73 |
+
}
|
74 |
+
|
75 |
+
# Now handle the directories
|
76 |
+
num_dir = os.path.join(numbers_dir, dset[1], test_split)
|
77 |
+
os.makedirs(num_dir, exist_ok=True)
|
78 |
+
|
79 |
+
|
80 |
+
for config in CONFIGS:
|
81 |
+
title = CONFIGS[config][0]
|
82 |
+
allowed = CONFIGS[config][1]
|
83 |
+
|
84 |
+
weight_map = {
|
85 |
+
0 : 1.,
|
86 |
+
1 : 1,
|
87 |
+
2 : 1.,
|
88 |
+
3 : 1.,
|
89 |
+
4 : .5, # C4
|
90 |
+
5 : 0.5,
|
91 |
+
6 : 0.5,
|
92 |
+
7 : 1
|
93 |
+
}
|
94 |
+
|
95 |
+
weights = [weight_map[x] for x in allowed]
|
96 |
+
|
97 |
+
# generate the required mp dicts
|
98 |
+
def c2_manp(preds):
|
99 |
+
preds = list(filter(lambda x: x[0]>0.85,preds)) # keep preds lower than 0.6 confidence
|
100 |
+
return preds
|
101 |
+
|
102 |
+
def c3_manp(preds):
|
103 |
+
preds = list(filter(lambda x: x[0]>0.85,preds)) # keep preds lower than 0.6 confidence
|
104 |
+
return preds
|
105 |
+
|
106 |
+
def t1_manp(preds):
|
107 |
+
preds = list(filter(lambda x: x[0]>0.6,preds)) # keep preds lower than 0.6 confidence
|
108 |
+
return preds
|
109 |
+
|
110 |
+
t2_manp = t1_manp
|
111 |
+
mp_dict = {
|
112 |
+
f'{dset[1]}_C2' : c2_manp,
|
113 |
+
f'{dset[1]}_C3' : c3_manp,
|
114 |
+
f'{dset[1]}_T1' : t1_manp,
|
115 |
+
f'{dset[1]}_T2' : t2_manp,
|
116 |
+
f'{dset[1]}_C4' : c3_manp
|
117 |
+
}
|
118 |
+
|
119 |
+
image_dict = get_image_dict(dataset_paths, allowed = allowed, USE_ACR = False, acr_cat = None, mp_dict = mp_dict)
|
120 |
+
image_dict = apply_merge(image_dict, METHOD = 'nms', weights= weights, conf_type='absent_model_aware_avg')
|
121 |
+
|
122 |
+
|
123 |
+
senses, fps = calc_froc_from_dict(image_dict, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3,1.], save_to = os.path.join(num_dir, f'{title}.txt'))
|
124 |
+
|
125 |
+
|
126 |
+
# Lets plot now
|
127 |
+
|
128 |
+
GRAPHS = [
|
129 |
+
('Bilateral','Baseline'),
|
130 |
+
('Contrast','Baseline'),
|
131 |
+
('Threshold','Baseline'),
|
132 |
+
('Proposed','Baseline'),
|
133 |
+
('Proposed', 'Bilateral'),
|
134 |
+
('Proposed', 'Contrast'),
|
135 |
+
('Proposed', 'Threshold'),
|
136 |
+
]
|
137 |
+
|
138 |
+
|
139 |
+
# Now handle the directories
|
140 |
+
graph_dir = os.path.join(graphs_dir, dset[1], test_split)
|
141 |
+
os.makedirs(graph_dir, exist_ok=True)
|
142 |
+
|
143 |
+
for graph in GRAPHS:
|
144 |
+
if graph[0] not in CONFIGS or graph[1] not in CONFIGS: continue
|
145 |
+
file_name1 = f'{CONFIGS[graph[0]][0]}.txt'
|
146 |
+
file_name2 = f'{CONFIGS[graph[1]][0]}.txt'
|
147 |
+
|
148 |
+
title1 = CONFIGS[graph[0]][0]
|
149 |
+
title2 = CONFIGS[graph[1]][0]
|
150 |
+
|
151 |
+
plot_froc({
|
152 |
+
join(num_dir, file_name1): title1,
|
153 |
+
join(num_dir, file_name2) : title2,
|
154 |
+
}, join(graph_dir,f'{title1}_vs_{title2}.png'),f'{title1} vs {title2} FROC', CLIP_FPI = 0.3 if dset[0] == 'IRCHVal' else 0.8)
|
155 |
+
|
156 |
+
|
DenseMammogram/auc_by_pranjal.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from os.path import join
|
3 |
+
import glob
|
4 |
+
from sklearn.metrics import roc_auc_score, roc_curve
|
5 |
+
import sys
|
6 |
+
|
7 |
+
def file_to_score(file):
|
8 |
+
try:
|
9 |
+
content = open(file, 'r').readlines()
|
10 |
+
st = 0
|
11 |
+
if len(content) == 0:
|
12 |
+
# Empty File Should Return []
|
13 |
+
return 0.
|
14 |
+
if content[0].split()[0].isalpha():
|
15 |
+
st = 1
|
16 |
+
return max([float(line.split()[st]) for line in content])
|
17 |
+
except FileNotFoundError:
|
18 |
+
print(f'No Corresponding Box Found for file {file}, using [] as preds')
|
19 |
+
return []
|
20 |
+
except Exception as e:
|
21 |
+
print('Some Error',e)
|
22 |
+
return []
|
23 |
+
|
24 |
+
# Create the image dict
|
25 |
+
def generate_image_dict(preds_folder_name='preds_42',
|
26 |
+
root_fol='/home/krithika_1/densebreeast_datasets/AIIMS_C1',
|
27 |
+
mal_path=None, ben_path=None, gt_path=None,
|
28 |
+
mal_img_path = None, ben_img_path = None
|
29 |
+
):
|
30 |
+
|
31 |
+
mal_path = join(root_fol, mal_path) if mal_path else join(
|
32 |
+
root_fol, 'mal', preds_folder_name)
|
33 |
+
ben_path = join(root_fol, ben_path) if ben_path else join(
|
34 |
+
root_fol, 'ben', preds_folder_name)
|
35 |
+
mal_img_path = join(root_fol, mal_img_path) if mal_img_path else join(
|
36 |
+
root_fol, 'mal', 'images')
|
37 |
+
ben_img_path = join(root_fol, ben_img_path) if ben_img_path else join(
|
38 |
+
root_fol, 'ben', 'images')
|
39 |
+
gt_path = join(root_fol, gt_path) if gt_path else join(
|
40 |
+
root_fol, 'mal', 'gt')
|
41 |
+
|
42 |
+
|
43 |
+
'''
|
44 |
+
image_dict structure:
|
45 |
+
'image_name(without txt/png)' : {'gt' : [[...]], 'preds' : score}
|
46 |
+
'''
|
47 |
+
image_dict = dict()
|
48 |
+
|
49 |
+
# GT Might be sightly different from images, therefore we will index gts based on
|
50 |
+
# the images folder instead.
|
51 |
+
for file in os.listdir(mal_img_path):
|
52 |
+
# for file in glob.glob(join(gt_path, '*.txt')):
|
53 |
+
if not file.endswith('.png'):
|
54 |
+
continue
|
55 |
+
file = file[:-4] + '.txt'
|
56 |
+
file = join(gt_path, file)
|
57 |
+
key = os.path.split(file)[-1][:-4]
|
58 |
+
image_dict[key] = dict()
|
59 |
+
image_dict[key]['gt'] = 1.
|
60 |
+
image_dict[key]['preds'] = 0.
|
61 |
+
|
62 |
+
for file in glob.glob(join(mal_path, '*.txt')):
|
63 |
+
key = os.path.split(file)[-1][:-4]
|
64 |
+
assert key in image_dict
|
65 |
+
image_dict[key]['preds'] = file_to_score(file)
|
66 |
+
|
67 |
+
for file in os.listdir(ben_img_path):
|
68 |
+
# for file in glob.glob(join(ben_path, '*.txt')):
|
69 |
+
if not file.endswith('.png'):
|
70 |
+
continue
|
71 |
+
|
72 |
+
file = file[:-4] + '.txt'
|
73 |
+
file = join(ben_path, file)
|
74 |
+
key = os.path.split(file)[-1][:-4]
|
75 |
+
# if key == 'Calc-Test_P_00353_LEFT_CC' or key == 'Calc-Training_P_00600_LEFT_CC':
|
76 |
+
# continue
|
77 |
+
if key in image_dict:
|
78 |
+
print(key)
|
79 |
+
print('SHIT')
|
80 |
+
continue
|
81 |
+
# assert key not in image_dict
|
82 |
+
image_dict[key] = dict()
|
83 |
+
image_dict[key]['preds'] = file_to_score(file)
|
84 |
+
image_dict[key]['gt'] = 0.
|
85 |
+
return image_dict
|
86 |
+
|
87 |
+
def get_auc_score_from_imdict(image_dict):
|
88 |
+
keys = list(image_dict.keys())
|
89 |
+
y = [image_dict[k]['gt']for k in keys]
|
90 |
+
preds = [image_dict[k]['preds']for k in keys]
|
91 |
+
return roc_auc_score(y, preds)
|
92 |
+
|
93 |
+
def get_accuracy_from_imdict(image_dict, thresh = 0.3):
|
94 |
+
keys = list(image_dict.keys())
|
95 |
+
ys = [image_dict[k]['gt']for k in keys]
|
96 |
+
preds = [image_dict[k]['preds']for k in keys]
|
97 |
+
acc = 0
|
98 |
+
for y,pred in zip(ys,preds):
|
99 |
+
if pred < thresh and y == 0.:
|
100 |
+
acc+=1
|
101 |
+
elif pred > thresh and y == 1.:
|
102 |
+
acc+=1
|
103 |
+
return acc/len(preds)
|
104 |
+
|
105 |
+
|
106 |
+
def get_auc_score(preds_image_folder, root_fol, retAcc = False, acc_thresh = 0.3):
|
107 |
+
im_dict = generate_image_dict(preds_image_folder, root_fol = root_fol)
|
108 |
+
if retAcc:
|
109 |
+
return get_auc_score_from_imdict(im_dict), get_accuracy_from_imdict(im_dict, acc_thresh)
|
110 |
+
else:
|
111 |
+
return get_auc_score_from_imdict(im_dict)
|
112 |
+
|
113 |
+
if __name__ == '__main__':
|
114 |
+
seed = '42' if len(sys.argv)== 1 else sys.argv[1]
|
115 |
+
|
116 |
+
root_fol = '../bilateral_new/MammoDatasets/AIIMS_highres_reliable/test'
|
117 |
+
|
118 |
+
auc_score = get_auc_score(f'preds_{seed}',root_fol)
|
119 |
+
print(f'ROC AUC Score: {auc_score}')
|
120 |
+
|
DenseMammogram/dataloaders.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Get the dataloaders
|
2 |
+
# There are only two types of dataloaders, viz. VanillaFRCNN and BilaterialFRCNN
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import cv2
|
6 |
+
import torchvision.transforms as T
|
7 |
+
import detection.transforms as transforms
|
8 |
+
from torch.utils.data import Dataset,DataLoader
|
9 |
+
import detection.utils as utils
|
10 |
+
import os
|
11 |
+
from tqdm import tqdm
|
12 |
+
import pandas as pd
|
13 |
+
from os.path import join
|
14 |
+
# VanillaFRCNN DataLoaders
|
15 |
+
|
16 |
+
class FRCNNDataset(Dataset):
|
17 |
+
def __init__(self,inputs,transform):
|
18 |
+
self.transform = transform
|
19 |
+
self.dataset_dicts = inputs
|
20 |
+
|
21 |
+
def __len__(self):
|
22 |
+
return len(self.dataset_dicts)
|
23 |
+
|
24 |
+
|
25 |
+
def __getitem__(self,index: int):
|
26 |
+
# Select the sample
|
27 |
+
record = self.dataset_dicts[index]
|
28 |
+
# Load input and target
|
29 |
+
img = cv2.imread(record['file_name'])
|
30 |
+
|
31 |
+
target = {k:torch.tensor(v) for k,v in record.items() if k != 'file_name'}
|
32 |
+
if self.transform is not None:
|
33 |
+
img = T.ToPILImage()(img)
|
34 |
+
img,target = self.transform(img,target)
|
35 |
+
|
36 |
+
return img,target
|
37 |
+
|
38 |
+
def xml_to_dicts(paths):
|
39 |
+
dataset_dicts = []
|
40 |
+
i=1
|
41 |
+
for path in paths:
|
42 |
+
for image in tqdm(os.listdir(os.path.join(path,'mal/images/'))):
|
43 |
+
xmlfile = os.path.join(path,'mal/gt/',image[:-4]+'.txt')
|
44 |
+
if(not os.path.exists(xmlfile)):
|
45 |
+
continue
|
46 |
+
img = cv2.imread(os.path.join(path,'mal/images/',image))
|
47 |
+
record = {}
|
48 |
+
record['file_name'] = os.path.join(path , 'mal/images/',image)
|
49 |
+
record['image_id'] = i
|
50 |
+
i+=1
|
51 |
+
record['width'] = img.shape[1]
|
52 |
+
record['height'] = img.shape[0]
|
53 |
+
objs = []
|
54 |
+
boxes = []
|
55 |
+
labels = []
|
56 |
+
area = []
|
57 |
+
iscrowd = []
|
58 |
+
f = open(xmlfile,'r')
|
59 |
+
for line in f.readlines():
|
60 |
+
box = list(map(int,map(float,line.split()[1:])))
|
61 |
+
boxes.append(box)
|
62 |
+
labels.append(1)
|
63 |
+
area.append((box[2]-box[0])*(box[3]-box[1]))
|
64 |
+
iscrowd.append(False)
|
65 |
+
f.close()
|
66 |
+
record["boxes"] = boxes
|
67 |
+
record["labels"] = labels
|
68 |
+
record["area"] = area
|
69 |
+
record["iscrowd"] = iscrowd
|
70 |
+
if(len(boxes)>0):
|
71 |
+
dataset_dicts.append(record)
|
72 |
+
for image in tqdm(os.listdir(os.path.join(path,'ben/images/'))):
|
73 |
+
img = cv2.imread(os.path.join(path,'ben/images/',image))
|
74 |
+
record = {}
|
75 |
+
record['file_name'] = os.path.join(path, 'ben/images/',image)
|
76 |
+
record['image_id'] = i
|
77 |
+
i+=1
|
78 |
+
record['width'] = img.shape[1]
|
79 |
+
record['height'] = img.shape[0]
|
80 |
+
record['boxes'] = torch.tensor([[0,0,img.shape[1],img.shape[0]]])
|
81 |
+
record['labels'] = torch.tensor([0])
|
82 |
+
record['area'] = [img.shape[1]*img.shape[0]]
|
83 |
+
record["iscrowd"] = [False]
|
84 |
+
dataset_dicts.append(record)
|
85 |
+
return dataset_dicts
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
def get_FRCNN_dataloaders(cfg, batch_size = 2, data_dir = '../bilateral_new',):
|
90 |
+
transform_test = transforms.Compose([transforms.ToTensor()])
|
91 |
+
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ToTensor()])
|
92 |
+
train_paths = [join(data_dir,cfg['AIIMS_DATA'],cfg['AIIMS_TRAIN_SPLIT']),join(data_dir,cfg['DDSM_DATA'],cfg['DDSM_TRAIN_SPLIT']),]
|
93 |
+
val_aiims_path = [join(data_dir,cfg['AIIMS_DATA'],cfg['AIIMS_VAL_SPLIT'])]
|
94 |
+
train_data = FRCNNDataset(xml_to_dicts(train_paths),transform_train)
|
95 |
+
test_aiims = FRCNNDataset(xml_to_dicts(val_aiims_path),transform_test)
|
96 |
+
|
97 |
+
train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True,drop_last=True,num_workers=4,collate_fn = utils.collate_fn)
|
98 |
+
test_aiims_loader = DataLoader(test_aiims,batch_size=batch_size,shuffle=True,drop_last=True,num_workers=4,collate_fn = utils.collate_fn)
|
99 |
+
#test_ddsm_loader = DataLoader(test_ddsm,batch_size=2,shuffle=True,drop_last=True,num_workers=5,collate_fn = utils.collate_fn)
|
100 |
+
|
101 |
+
return train_loader, test_aiims_loader
|
102 |
+
|
103 |
+
# BilaterialFRCNN DataLoaders
|
104 |
+
|
105 |
+
def get_direction(dset,file_name):
|
106 |
+
# 1 if right else -1
|
107 |
+
if dset == 'aiims' or dset == 'ddsm':
|
108 |
+
file_name = file_name.lower()
|
109 |
+
r = file_name.find('right')
|
110 |
+
l = file_name.find('left')
|
111 |
+
if l == r and l == -1:
|
112 |
+
raise Exception(f'Unidentifiable Direction {file_name}')
|
113 |
+
if l!=-1 and r!=-1:
|
114 |
+
raise Exception(f'Unidentifiable Direction {file_name}')
|
115 |
+
return 1 if r!=-1 else -1
|
116 |
+
if dset == 'inbreast':
|
117 |
+
dir =file_name.split('_')[3]
|
118 |
+
if dir == 'R': return 1
|
119 |
+
if dir == 'L': return -1
|
120 |
+
raise Exception(f'Unidentifiable Direction {file_name}')
|
121 |
+
if dset == 'irch':
|
122 |
+
r = file_name.find('_R ')
|
123 |
+
l = file_name.find('_L ')
|
124 |
+
if l == r and l == -1:
|
125 |
+
raise Exception(f'Unidentifiable Direction {file_name}')
|
126 |
+
if l!=-1 and r!=-1:
|
127 |
+
raise Exception(f'Unidentifiable Direction {file_name}')
|
128 |
+
return 1 if r!=-1 else -1
|
129 |
+
|
130 |
+
|
131 |
+
class BilateralDataset(torch.utils.data.Dataset):
|
132 |
+
|
133 |
+
def __init__(self,inputs,transform,dset):
|
134 |
+
self.transform = transform
|
135 |
+
self.dataset_dicts = inputs
|
136 |
+
self.dset = dset
|
137 |
+
|
138 |
+
def __len__(self):
|
139 |
+
return len(self.dataset_dicts)
|
140 |
+
|
141 |
+
|
142 |
+
def __getitem__(self,index: int):
|
143 |
+
# Select the sample
|
144 |
+
record = self.dataset_dicts[index]
|
145 |
+
# Load input and target
|
146 |
+
img1 = cv2.imread(record['file_name'])
|
147 |
+
img2 = cv2.imread(record['file_2'])
|
148 |
+
|
149 |
+
target = {k:torch.tensor(v) for k,v in record.items() if k != 'file_name' and k!='file_2'}
|
150 |
+
if self.transform is not None:
|
151 |
+
img1 = T.ToPILImage()(img1)
|
152 |
+
img2 = T.ToPILImage()(img2)
|
153 |
+
if(get_direction(self.dset,record['file_name'].split('/')[-1])==1):
|
154 |
+
img1,target = transforms.RandomHorizontalFlip(1.0)(img1,target)
|
155 |
+
else:
|
156 |
+
img2,_ = transforms.RandomHorizontalFlip(1.0)(img2)
|
157 |
+
img1,target = self.transform(img1,target)
|
158 |
+
img2,target = self.transform(img2,target)
|
159 |
+
|
160 |
+
images = [img1,img2]
|
161 |
+
return images,target
|
162 |
+
|
163 |
+
|
164 |
+
def xml_to_dicts_bilateral(paths,cor_dicts):
|
165 |
+
dataset_dicts = []
|
166 |
+
i=1
|
167 |
+
for path,cor_dict in zip(paths,cor_dicts):
|
168 |
+
for image in tqdm(os.listdir(os.path.join(path,'mal/images/'))):
|
169 |
+
if(not os.path.join(path,'mal/images/',image) in cor_dict):
|
170 |
+
continue
|
171 |
+
if(not os.path.isfile(cor_dict[os.path.join(path,'mal/images/',image)])):
|
172 |
+
continue
|
173 |
+
xmlfile = os.path.join(path,'mal/gt/',image[:-4]+'.txt')
|
174 |
+
if(not os.path.exists(xmlfile)):
|
175 |
+
continue
|
176 |
+
img = cv2.imread(os.path.join(path,'mal/images/',image))
|
177 |
+
|
178 |
+
record = {}
|
179 |
+
record['file_name'] = os.path.join(path , 'mal/images/',image)
|
180 |
+
record['file_2'] = cor_dict[os.path.join(path,'mal/images/',image)]
|
181 |
+
record['image_id'] = i
|
182 |
+
i+=1
|
183 |
+
record['width'] = img.shape[1]
|
184 |
+
record['height'] = img.shape[0]
|
185 |
+
objs = []
|
186 |
+
boxes = []
|
187 |
+
labels = []
|
188 |
+
area = []
|
189 |
+
iscrowd = []
|
190 |
+
f = open(xmlfile,'r')
|
191 |
+
for line in f.readlines():
|
192 |
+
box = list(map(int,map(float,line.split()[1:])))
|
193 |
+
boxes.append(box)
|
194 |
+
labels.append(1)
|
195 |
+
area.append((box[2]-box[0])*(box[3]-box[1]))
|
196 |
+
iscrowd.append(False)
|
197 |
+
|
198 |
+
f.close()
|
199 |
+
record["boxes"] = boxes
|
200 |
+
record["labels"] = labels
|
201 |
+
record["area"] = area
|
202 |
+
record["iscrowd"] = iscrowd
|
203 |
+
if(len(boxes)>0):
|
204 |
+
dataset_dicts.append(record)
|
205 |
+
|
206 |
+
for image in tqdm(os.listdir(os.path.join(path,'ben/images/'))):
|
207 |
+
if(not os.path.join(path,'ben/images/',image) in cor_dict):
|
208 |
+
continue
|
209 |
+
if(not os.path.isfile(cor_dict[os.path.join(path,'ben/images/',image)])):
|
210 |
+
continue
|
211 |
+
img = cv2.imread(os.path.join(path,'ben/images/',image))
|
212 |
+
|
213 |
+
record = {}
|
214 |
+
record['file_name'] = os.path.join(path , 'ben/images/',image)
|
215 |
+
record['file_2'] = cor_dict[os.path.join(path,'ben/images/',image)]
|
216 |
+
img2 = cv2.imread(cor_dict[os.path.join(path,'ben/images/',image)])
|
217 |
+
record['image_id'] = i
|
218 |
+
i+=1
|
219 |
+
record['width'] = img.shape[1]
|
220 |
+
record['height'] = img.shape[0]
|
221 |
+
|
222 |
+
record["boxes"] = torch.tensor([[0,0,min(img.shape[1],img2.shape[1]),min(img.shape[0],img2.shape[0])]])
|
223 |
+
record['labels'] = torch.tensor([0])
|
224 |
+
record['area'] = [ min(img.shape[1],img2.shape[1]) *min(img.shape[0],img2.shape[0])]
|
225 |
+
record["iscrowd"] = [False]
|
226 |
+
if(len(boxes)>0):
|
227 |
+
dataset_dicts.append(record)
|
228 |
+
|
229 |
+
return dataset_dicts
|
230 |
+
|
231 |
+
|
232 |
+
|
233 |
+
def get_dict(data_dir, filename):
|
234 |
+
df = pd.read_csv(filename, header=None, sep=r'\s+', quotechar='"').to_numpy()
|
235 |
+
cor_dict = dict()
|
236 |
+
for a in df:
|
237 |
+
if(a[0]==a[1]):
|
238 |
+
continue
|
239 |
+
cor_dict[a[0]] = a[1]
|
240 |
+
# print(cor_dict)
|
241 |
+
cor_dict = {join(data_dir,k):join(data_dir,v) for k,v in cor_dict.items()}
|
242 |
+
return cor_dict
|
243 |
+
|
244 |
+
def get_bilateral_dataloaders(cfg, batch_size = 1, data_dir = '../bilateral_new'):
|
245 |
+
transform_test = transforms.Compose([transforms.ToTensor()])
|
246 |
+
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ToTensor()])
|
247 |
+
train_paths = [join(data_dir,cfg['AIIMS_DATA'],cfg['AIIMS_TRAIN_SPLIT']),join(data_dir,cfg['DDSM_DATA'],cfg['DDSM_TRAIN_SPLIT']),]
|
248 |
+
val_aiims_path = [join(data_dir,cfg['AIIMS_DATA'],cfg['AIIMS_VAL_SPLIT'])]
|
249 |
+
cor_lists_train = [get_dict(data_dir,join(data_dir,cfg['AIIMS_CORRS_LIST'])),get_dict(data_dir,join(data_dir,cfg['DDSM_CORRS_LIST']))]
|
250 |
+
cor_lists_val = [get_dict(data_dir,join(data_dir,cfg['AIIMS_CORRS_LIST']))]
|
251 |
+
cor_lists_train = [get_dict(data_dir,join(data_dir,cfg['AIIMS_CORRS_LIST']))]
|
252 |
+
train_data = BilateralDataset(xml_to_dicts_bilateral(train_paths,cor_lists_train),transform_test,'aiims')
|
253 |
+
val_aiims = BilateralDataset(xml_to_dicts_bilateral(val_aiims_path,cor_lists_val),transform_test,'aiims')
|
254 |
+
|
255 |
+
train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True,drop_last=True,num_workers=4,collate_fn = utils.collate_fn)
|
256 |
+
val_aiims_loader = DataLoader(val_aiims,batch_size=batch_size,shuffle=True,drop_last=True,num_workers=4,collate_fn = utils.collate_fn)
|
257 |
+
#test_ddsm_loader = DataLoader(test_ddsm,batch_size=2,shuffle=True,drop_last=True,num_workers=5,collate_fn = utils.collate_fn)
|
258 |
+
|
259 |
+
return train_loader, val_aiims_loader
|
DenseMammogram/detection/README.md
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Object detection reference training scripts
|
2 |
+
|
3 |
+
This folder contains reference training scripts for object detection.
|
4 |
+
They serve as a log of how to train specific models, to provide baseline
|
5 |
+
training and evaluation scripts to quickly bootstrap research.
|
6 |
+
|
7 |
+
To execute the example commands below you must install the following:
|
8 |
+
|
9 |
+
```
|
10 |
+
cython
|
11 |
+
pycocotools
|
12 |
+
matplotlib
|
13 |
+
```
|
14 |
+
|
15 |
+
You must modify the following flags:
|
16 |
+
|
17 |
+
`--data-path=/path/to/coco/dataset`
|
18 |
+
|
19 |
+
`--nproc_per_node=<number_of_gpus_available>`
|
20 |
+
|
21 |
+
Except otherwise noted, all models have been trained on 8x V100 GPUs.
|
22 |
+
|
23 |
+
### Faster R-CNN ResNet-50 FPN
|
24 |
+
```
|
25 |
+
torchrun --nproc_per_node=8 train.py\
|
26 |
+
--dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
|
27 |
+
--lr-steps 16 22 --aspect-ratio-group-factor 3
|
28 |
+
```
|
29 |
+
|
30 |
+
### Faster R-CNN MobileNetV3-Large FPN
|
31 |
+
```
|
32 |
+
torchrun --nproc_per_node=8 train.py\
|
33 |
+
--dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\
|
34 |
+
--lr-steps 16 22 --aspect-ratio-group-factor 3
|
35 |
+
```
|
36 |
+
|
37 |
+
### Faster R-CNN MobileNetV3-Large 320 FPN
|
38 |
+
```
|
39 |
+
torchrun --nproc_per_node=8 train.py\
|
40 |
+
--dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\
|
41 |
+
--lr-steps 16 22 --aspect-ratio-group-factor 3
|
42 |
+
```
|
43 |
+
|
44 |
+
### RetinaNet
|
45 |
+
```
|
46 |
+
torchrun --nproc_per_node=8 train.py\
|
47 |
+
--dataset coco --model retinanet_resnet50_fpn --epochs 26\
|
48 |
+
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
|
49 |
+
```
|
50 |
+
|
51 |
+
### SSD300 VGG16
|
52 |
+
```
|
53 |
+
torchrun --nproc_per_node=8 train.py\
|
54 |
+
--dataset coco --model ssd300_vgg16 --epochs 120\
|
55 |
+
--lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\
|
56 |
+
--weight-decay 0.0005 --data-augmentation ssd
|
57 |
+
```
|
58 |
+
|
59 |
+
### SSDlite320 MobileNetV3-Large
|
60 |
+
```
|
61 |
+
torchrun --nproc_per_node=8 train.py\
|
62 |
+
--dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\
|
63 |
+
--aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\
|
64 |
+
--weight-decay 0.00004 --data-augmentation ssdlite
|
65 |
+
```
|
66 |
+
|
67 |
+
|
68 |
+
### Mask R-CNN
|
69 |
+
```
|
70 |
+
torchrun --nproc_per_node=8 train.py\
|
71 |
+
--dataset coco --model maskrcnn_resnet50_fpn --epochs 26\
|
72 |
+
--lr-steps 16 22 --aspect-ratio-group-factor 3
|
73 |
+
```
|
74 |
+
|
75 |
+
|
76 |
+
### Keypoint R-CNN
|
77 |
+
```
|
78 |
+
torchrun --nproc_per_node=8 train.py\
|
79 |
+
--dataset coco_kp --model keypointrcnn_resnet50_fpn --epochs 46\
|
80 |
+
--lr-steps 36 43 --aspect-ratio-group-factor 3
|
81 |
+
```
|
DenseMammogram/detection/coco_eval.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import io
|
3 |
+
from contextlib import redirect_stdout
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import pycocotools.mask as mask_util
|
7 |
+
import torch
|
8 |
+
import detection.utils as utils
|
9 |
+
from pycocotools.coco import COCO
|
10 |
+
from pycocotools.cocoeval import COCOeval
|
11 |
+
|
12 |
+
|
13 |
+
class CocoEvaluator:
|
14 |
+
def __init__(self, coco_gt, iou_types):
|
15 |
+
assert isinstance(iou_types, (list, tuple))
|
16 |
+
coco_gt = copy.deepcopy(coco_gt)
|
17 |
+
self.coco_gt = coco_gt
|
18 |
+
|
19 |
+
self.iou_types = iou_types
|
20 |
+
self.coco_eval = {}
|
21 |
+
for iou_type in iou_types:
|
22 |
+
self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
|
23 |
+
|
24 |
+
self.img_ids = []
|
25 |
+
self.eval_imgs = {k: [] for k in iou_types}
|
26 |
+
|
27 |
+
def update(self, predictions):
|
28 |
+
img_ids = list(np.unique(list(predictions.keys())))
|
29 |
+
self.img_ids.extend(img_ids)
|
30 |
+
|
31 |
+
for iou_type in self.iou_types:
|
32 |
+
results = self.prepare(predictions, iou_type)
|
33 |
+
with redirect_stdout(io.StringIO()):
|
34 |
+
coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
|
35 |
+
coco_eval = self.coco_eval[iou_type]
|
36 |
+
|
37 |
+
coco_eval.cocoDt = coco_dt
|
38 |
+
coco_eval.params.imgIds = list(img_ids)
|
39 |
+
img_ids, eval_imgs = evaluate(coco_eval)
|
40 |
+
|
41 |
+
self.eval_imgs[iou_type].append(eval_imgs)
|
42 |
+
|
43 |
+
def synchronize_between_processes(self):
|
44 |
+
for iou_type in self.iou_types:
|
45 |
+
self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
|
46 |
+
create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
|
47 |
+
|
48 |
+
def accumulate(self):
|
49 |
+
for coco_eval in self.coco_eval.values():
|
50 |
+
coco_eval.accumulate()
|
51 |
+
|
52 |
+
def summarize(self):
|
53 |
+
for iou_type, coco_eval in self.coco_eval.items():
|
54 |
+
print(f"IoU metric: {iou_type}")
|
55 |
+
coco_eval.summarize()
|
56 |
+
|
57 |
+
def prepare(self, predictions, iou_type):
|
58 |
+
if iou_type == "bbox":
|
59 |
+
return self.prepare_for_coco_detection(predictions)
|
60 |
+
if iou_type == "segm":
|
61 |
+
return self.prepare_for_coco_segmentation(predictions)
|
62 |
+
if iou_type == "keypoints":
|
63 |
+
return self.prepare_for_coco_keypoint(predictions)
|
64 |
+
raise ValueError(f"Unknown iou type {iou_type}")
|
65 |
+
|
66 |
+
def prepare_for_coco_detection(self, predictions):
|
67 |
+
coco_results = []
|
68 |
+
for original_id, prediction in predictions.items():
|
69 |
+
if len(prediction) == 0:
|
70 |
+
continue
|
71 |
+
|
72 |
+
boxes = prediction["boxes"]
|
73 |
+
boxes = convert_to_xywh(boxes).tolist()
|
74 |
+
scores = prediction["scores"].tolist()
|
75 |
+
labels = prediction["labels"].tolist()
|
76 |
+
|
77 |
+
coco_results.extend(
|
78 |
+
[
|
79 |
+
{
|
80 |
+
"image_id": original_id,
|
81 |
+
"category_id": labels[k],
|
82 |
+
"bbox": box,
|
83 |
+
"score": scores[k],
|
84 |
+
}
|
85 |
+
for k, box in enumerate(boxes)
|
86 |
+
]
|
87 |
+
)
|
88 |
+
return coco_results
|
89 |
+
|
90 |
+
def prepare_for_coco_segmentation(self, predictions):
|
91 |
+
coco_results = []
|
92 |
+
for original_id, prediction in predictions.items():
|
93 |
+
if len(prediction) == 0:
|
94 |
+
continue
|
95 |
+
|
96 |
+
scores = prediction["scores"]
|
97 |
+
labels = prediction["labels"]
|
98 |
+
masks = prediction["masks"]
|
99 |
+
|
100 |
+
masks = masks > 0.5
|
101 |
+
|
102 |
+
scores = prediction["scores"].tolist()
|
103 |
+
labels = prediction["labels"].tolist()
|
104 |
+
|
105 |
+
rles = [
|
106 |
+
mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] for mask in masks
|
107 |
+
]
|
108 |
+
for rle in rles:
|
109 |
+
rle["counts"] = rle["counts"].decode("utf-8")
|
110 |
+
|
111 |
+
coco_results.extend(
|
112 |
+
[
|
113 |
+
{
|
114 |
+
"image_id": original_id,
|
115 |
+
"category_id": labels[k],
|
116 |
+
"segmentation": rle,
|
117 |
+
"score": scores[k],
|
118 |
+
}
|
119 |
+
for k, rle in enumerate(rles)
|
120 |
+
]
|
121 |
+
)
|
122 |
+
return coco_results
|
123 |
+
|
124 |
+
def prepare_for_coco_keypoint(self, predictions):
|
125 |
+
coco_results = []
|
126 |
+
for original_id, prediction in predictions.items():
|
127 |
+
if len(prediction) == 0:
|
128 |
+
continue
|
129 |
+
|
130 |
+
boxes = prediction["boxes"]
|
131 |
+
boxes = convert_to_xywh(boxes).tolist()
|
132 |
+
scores = prediction["scores"].tolist()
|
133 |
+
labels = prediction["labels"].tolist()
|
134 |
+
keypoints = prediction["keypoints"]
|
135 |
+
keypoints = keypoints.flatten(start_dim=1).tolist()
|
136 |
+
|
137 |
+
coco_results.extend(
|
138 |
+
[
|
139 |
+
{
|
140 |
+
"image_id": original_id,
|
141 |
+
"category_id": labels[k],
|
142 |
+
"keypoints": keypoint,
|
143 |
+
"score": scores[k],
|
144 |
+
}
|
145 |
+
for k, keypoint in enumerate(keypoints)
|
146 |
+
]
|
147 |
+
)
|
148 |
+
return coco_results
|
149 |
+
|
150 |
+
|
151 |
+
def convert_to_xywh(boxes):
|
152 |
+
xmin, ymin, xmax, ymax = boxes.unbind(1)
|
153 |
+
return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
|
154 |
+
|
155 |
+
|
156 |
+
def merge(img_ids, eval_imgs):
|
157 |
+
all_img_ids = utils.all_gather(img_ids)
|
158 |
+
all_eval_imgs = utils.all_gather(eval_imgs)
|
159 |
+
|
160 |
+
merged_img_ids = []
|
161 |
+
for p in all_img_ids:
|
162 |
+
merged_img_ids.extend(p)
|
163 |
+
|
164 |
+
merged_eval_imgs = []
|
165 |
+
for p in all_eval_imgs:
|
166 |
+
merged_eval_imgs.append(p)
|
167 |
+
|
168 |
+
merged_img_ids = np.array(merged_img_ids)
|
169 |
+
merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
|
170 |
+
|
171 |
+
# keep only unique (and in sorted order) images
|
172 |
+
merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
|
173 |
+
merged_eval_imgs = merged_eval_imgs[..., idx]
|
174 |
+
|
175 |
+
return merged_img_ids, merged_eval_imgs
|
176 |
+
|
177 |
+
|
178 |
+
def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
|
179 |
+
img_ids, eval_imgs = merge(img_ids, eval_imgs)
|
180 |
+
img_ids = list(img_ids)
|
181 |
+
eval_imgs = list(eval_imgs.flatten())
|
182 |
+
|
183 |
+
coco_eval.evalImgs = eval_imgs
|
184 |
+
coco_eval.params.imgIds = img_ids
|
185 |
+
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
|
186 |
+
|
187 |
+
|
188 |
+
def evaluate(imgs):
|
189 |
+
with redirect_stdout(io.StringIO()):
|
190 |
+
imgs.evaluate()
|
191 |
+
return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))
|
DenseMammogram/detection/coco_utils.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.utils.data
|
6 |
+
import torchvision
|
7 |
+
import detection.transforms as T
|
8 |
+
from pycocotools import mask as coco_mask
|
9 |
+
from pycocotools.coco import COCO
|
10 |
+
|
11 |
+
|
12 |
+
class FilterAndRemapCocoCategories:
|
13 |
+
def __init__(self, categories, remap=True):
|
14 |
+
self.categories = categories
|
15 |
+
self.remap = remap
|
16 |
+
|
17 |
+
def __call__(self, image, target):
|
18 |
+
anno = target["annotations"]
|
19 |
+
anno = [obj for obj in anno if obj["category_id"] in self.categories]
|
20 |
+
if not self.remap:
|
21 |
+
target["annotations"] = anno
|
22 |
+
return image, target
|
23 |
+
anno = copy.deepcopy(anno)
|
24 |
+
for obj in anno:
|
25 |
+
obj["category_id"] = self.categories.index(obj["category_id"])
|
26 |
+
target["annotations"] = anno
|
27 |
+
return image, target
|
28 |
+
|
29 |
+
|
30 |
+
def convert_coco_poly_to_mask(segmentations, height, width):
|
31 |
+
masks = []
|
32 |
+
for polygons in segmentations:
|
33 |
+
rles = coco_mask.frPyObjects(polygons, height, width)
|
34 |
+
mask = coco_mask.decode(rles)
|
35 |
+
if len(mask.shape) < 3:
|
36 |
+
mask = mask[..., None]
|
37 |
+
mask = torch.as_tensor(mask, dtype=torch.uint8)
|
38 |
+
mask = mask.any(dim=2)
|
39 |
+
masks.append(mask)
|
40 |
+
if masks:
|
41 |
+
masks = torch.stack(masks, dim=0)
|
42 |
+
else:
|
43 |
+
masks = torch.zeros((0, height, width), dtype=torch.uint8)
|
44 |
+
return masks
|
45 |
+
|
46 |
+
|
47 |
+
class ConvertCocoPolysToMask:
|
48 |
+
def __call__(self, image, target):
|
49 |
+
w, h = image.size
|
50 |
+
|
51 |
+
image_id = target["image_id"]
|
52 |
+
image_id = torch.tensor([image_id])
|
53 |
+
|
54 |
+
anno = target["annotations"]
|
55 |
+
|
56 |
+
anno = [obj for obj in anno if obj["iscrowd"] == 0]
|
57 |
+
|
58 |
+
boxes = [obj["bbox"] for obj in anno]
|
59 |
+
# guard against no boxes via resizing
|
60 |
+
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
|
61 |
+
boxes[:, 2:] += boxes[:, :2]
|
62 |
+
boxes[:, 0::2].clamp_(min=0, max=w)
|
63 |
+
boxes[:, 1::2].clamp_(min=0, max=h)
|
64 |
+
|
65 |
+
classes = [obj["category_id"] for obj in anno]
|
66 |
+
classes = torch.tensor(classes, dtype=torch.int64)
|
67 |
+
|
68 |
+
segmentations = [obj["segmentation"] for obj in anno]
|
69 |
+
masks = convert_coco_poly_to_mask(segmentations, h, w)
|
70 |
+
|
71 |
+
keypoints = None
|
72 |
+
if anno and "keypoints" in anno[0]:
|
73 |
+
keypoints = [obj["keypoints"] for obj in anno]
|
74 |
+
keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
|
75 |
+
num_keypoints = keypoints.shape[0]
|
76 |
+
if num_keypoints:
|
77 |
+
keypoints = keypoints.view(num_keypoints, -1, 3)
|
78 |
+
|
79 |
+
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
|
80 |
+
boxes = boxes[keep]
|
81 |
+
classes = classes[keep]
|
82 |
+
masks = masks[keep]
|
83 |
+
if keypoints is not None:
|
84 |
+
keypoints = keypoints[keep]
|
85 |
+
|
86 |
+
target = {}
|
87 |
+
target["boxes"] = boxes
|
88 |
+
target["labels"] = classes
|
89 |
+
target["masks"] = masks
|
90 |
+
target["image_id"] = image_id
|
91 |
+
if keypoints is not None:
|
92 |
+
target["keypoints"] = keypoints
|
93 |
+
|
94 |
+
# for conversion to coco api
|
95 |
+
area = torch.tensor([obj["area"] for obj in anno])
|
96 |
+
iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
|
97 |
+
target["area"] = area
|
98 |
+
target["iscrowd"] = iscrowd
|
99 |
+
|
100 |
+
return image, target
|
101 |
+
|
102 |
+
|
103 |
+
def _coco_remove_images_without_annotations(dataset, cat_list=None):
|
104 |
+
def _has_only_empty_bbox(anno):
|
105 |
+
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
|
106 |
+
|
107 |
+
def _count_visible_keypoints(anno):
|
108 |
+
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
|
109 |
+
|
110 |
+
min_keypoints_per_image = 10
|
111 |
+
|
112 |
+
def _has_valid_annotation(anno):
|
113 |
+
# if it's empty, there is no annotation
|
114 |
+
if len(anno) == 0:
|
115 |
+
return False
|
116 |
+
# if all boxes have close to zero area, there is no annotation
|
117 |
+
if _has_only_empty_bbox(anno):
|
118 |
+
return False
|
119 |
+
# keypoints task have a slight different critera for considering
|
120 |
+
# if an annotation is valid
|
121 |
+
if "keypoints" not in anno[0]:
|
122 |
+
return True
|
123 |
+
# for keypoint detection tasks, only consider valid images those
|
124 |
+
# containing at least min_keypoints_per_image
|
125 |
+
if _count_visible_keypoints(anno) >= min_keypoints_per_image:
|
126 |
+
return True
|
127 |
+
return False
|
128 |
+
|
129 |
+
assert isinstance(dataset, torchvision.datasets.CocoDetection)
|
130 |
+
ids = []
|
131 |
+
for ds_idx, img_id in enumerate(dataset.ids):
|
132 |
+
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
|
133 |
+
anno = dataset.coco.loadAnns(ann_ids)
|
134 |
+
if cat_list:
|
135 |
+
anno = [obj for obj in anno if obj["category_id"] in cat_list]
|
136 |
+
if _has_valid_annotation(anno):
|
137 |
+
ids.append(ds_idx)
|
138 |
+
|
139 |
+
dataset = torch.utils.data.Subset(dataset, ids)
|
140 |
+
return dataset
|
141 |
+
|
142 |
+
|
143 |
+
def convert_to_coco_api(ds):
|
144 |
+
coco_ds = COCO()
|
145 |
+
# annotation IDs need to start at 1, not 0, see torchvision issue #1530
|
146 |
+
ann_id = 1
|
147 |
+
dataset = {"images": [], "categories": [], "annotations": []}
|
148 |
+
categories = set()
|
149 |
+
for img_idx in range(len(ds)):
|
150 |
+
# find better way to get target
|
151 |
+
# targets = ds.get_annotations(img_idx)
|
152 |
+
img, targets = ds[img_idx]
|
153 |
+
image_id = targets["image_id"].item()
|
154 |
+
img_dict = {}
|
155 |
+
img_dict["id"] = image_id
|
156 |
+
img_dict["height"] = img.shape[-2]
|
157 |
+
img_dict["width"] = img.shape[-1]
|
158 |
+
dataset["images"].append(img_dict)
|
159 |
+
bboxes = targets["boxes"]
|
160 |
+
bboxes[:, 2:] -= bboxes[:, :2]
|
161 |
+
bboxes = bboxes.tolist()
|
162 |
+
labels = targets["labels"].tolist()
|
163 |
+
areas = targets["area"].tolist()
|
164 |
+
iscrowd = targets["iscrowd"].tolist()
|
165 |
+
if "masks" in targets:
|
166 |
+
masks = targets["masks"]
|
167 |
+
# make masks Fortran contiguous for coco_mask
|
168 |
+
masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
|
169 |
+
if "keypoints" in targets:
|
170 |
+
keypoints = targets["keypoints"]
|
171 |
+
keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
|
172 |
+
num_objs = len(bboxes)
|
173 |
+
for i in range(num_objs):
|
174 |
+
ann = {}
|
175 |
+
ann["image_id"] = image_id
|
176 |
+
ann["bbox"] = bboxes[i]
|
177 |
+
ann["category_id"] = labels[i]
|
178 |
+
categories.add(labels[i])
|
179 |
+
ann["area"] = areas[i]
|
180 |
+
ann["iscrowd"] = iscrowd[i]
|
181 |
+
ann["id"] = ann_id
|
182 |
+
if "masks" in targets:
|
183 |
+
ann["segmentation"] = coco_mask.encode(masks[i].numpy())
|
184 |
+
if "keypoints" in targets:
|
185 |
+
ann["keypoints"] = keypoints[i]
|
186 |
+
ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
|
187 |
+
dataset["annotations"].append(ann)
|
188 |
+
ann_id += 1
|
189 |
+
dataset["categories"] = [{"id": i} for i in sorted(categories)]
|
190 |
+
coco_ds.dataset = dataset
|
191 |
+
coco_ds.createIndex()
|
192 |
+
return coco_ds
|
193 |
+
|
194 |
+
|
195 |
+
def get_coco_api_from_dataset(dataset):
|
196 |
+
for _ in range(10):
|
197 |
+
if isinstance(dataset, torchvision.datasets.CocoDetection):
|
198 |
+
break
|
199 |
+
if isinstance(dataset, torch.utils.data.Subset):
|
200 |
+
dataset = dataset.dataset
|
201 |
+
if isinstance(dataset, torchvision.datasets.CocoDetection):
|
202 |
+
return dataset.coco
|
203 |
+
return convert_to_coco_api(dataset)
|
204 |
+
|
205 |
+
|
206 |
+
class CocoDetection(torchvision.datasets.CocoDetection):
|
207 |
+
def __init__(self, img_folder, ann_file, transforms):
|
208 |
+
super().__init__(img_folder, ann_file)
|
209 |
+
self._transforms = transforms
|
210 |
+
|
211 |
+
def __getitem__(self, idx):
|
212 |
+
img, target = super().__getitem__(idx)
|
213 |
+
image_id = self.ids[idx]
|
214 |
+
target = dict(image_id=image_id, annotations=target)
|
215 |
+
if self._transforms is not None:
|
216 |
+
img, target = self._transforms(img, target)
|
217 |
+
return img, target
|
218 |
+
|
219 |
+
|
220 |
+
def get_coco(root, image_set, transforms, mode="instances"):
|
221 |
+
anno_file_template = "{}_{}2017.json"
|
222 |
+
PATHS = {
|
223 |
+
"train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
|
224 |
+
"val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
|
225 |
+
# "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
|
226 |
+
}
|
227 |
+
|
228 |
+
t = [ConvertCocoPolysToMask()]
|
229 |
+
|
230 |
+
if transforms is not None:
|
231 |
+
t.append(transforms)
|
232 |
+
transforms = T.Compose(t)
|
233 |
+
|
234 |
+
img_folder, ann_file = PATHS[image_set]
|
235 |
+
img_folder = os.path.join(root, img_folder)
|
236 |
+
ann_file = os.path.join(root, ann_file)
|
237 |
+
|
238 |
+
dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
|
239 |
+
|
240 |
+
if image_set == "train":
|
241 |
+
dataset = _coco_remove_images_without_annotations(dataset)
|
242 |
+
|
243 |
+
# dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
|
244 |
+
|
245 |
+
return dataset
|
246 |
+
|
247 |
+
|
248 |
+
def get_coco_kp(root, image_set, transforms):
|
249 |
+
return get_coco(root, image_set, transforms, mode="person_keypoints")
|
DenseMammogram/detection/engine.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import sys
|
3 |
+
import time
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torchvision.models.detection.mask_rcnn
|
7 |
+
import detection.utils as utils
|
8 |
+
from detection.coco_eval import CocoEvaluator
|
9 |
+
from detection.coco_utils import get_coco_api_from_dataset
|
10 |
+
from tqdm import tqdm
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
|
14 |
+
sys.path.append("..")
|
15 |
+
from utils import AverageMeter
|
16 |
+
from advanced_logger import LogPriority
|
17 |
+
|
18 |
+
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
|
19 |
+
model.train()
|
20 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
21 |
+
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
|
22 |
+
header = f"Epoch: [{epoch}]"
|
23 |
+
|
24 |
+
lr_scheduler = None
|
25 |
+
if epoch == 0:
|
26 |
+
warmup_factor = 1.0 / 1000
|
27 |
+
warmup_iters = min(1000, len(data_loader) - 1)
|
28 |
+
|
29 |
+
lr_scheduler = torch.optim.lr_scheduler.LinearLR(
|
30 |
+
optimizer, start_factor=warmup_factor, total_iters=warmup_iters
|
31 |
+
)
|
32 |
+
#for batch_idx,(images, targets) in enumerate(tqdm(data_loader)):
|
33 |
+
for images, targets in metric_logger.log_every(data_loader, print_freq, header):
|
34 |
+
#print(images.shape)
|
35 |
+
images = list(image.to(device) if len(image)>2 else [image[0].to(device),image[1].to(device)] for image in images)
|
36 |
+
#print(len(images))
|
37 |
+
#print(images[0].shape)
|
38 |
+
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
|
39 |
+
with torch.cuda.amp.autocast(enabled=scaler is not None):
|
40 |
+
loss_dict = model(images, targets)
|
41 |
+
losses = sum(loss for loss in loss_dict.values())
|
42 |
+
|
43 |
+
# reduce losses over all GPUs for logging purposes
|
44 |
+
loss_dict_reduced = utils.reduce_dict(loss_dict)
|
45 |
+
losses_reduced = sum(loss for loss in loss_dict_reduced.values())
|
46 |
+
|
47 |
+
loss_value = losses_reduced.item()
|
48 |
+
|
49 |
+
if not math.isfinite(loss_value):
|
50 |
+
print(f"Loss is {loss_value}, stopping training")
|
51 |
+
print(loss_dict_reduced)
|
52 |
+
sys.exit(1)
|
53 |
+
|
54 |
+
optimizer.zero_grad()
|
55 |
+
if scaler is not None:
|
56 |
+
scaler.scale(losses).backward()
|
57 |
+
scaler.step(optimizer)
|
58 |
+
scaler.update()
|
59 |
+
else:
|
60 |
+
losses.backward()
|
61 |
+
optimizer.step()
|
62 |
+
|
63 |
+
if lr_scheduler is not None:
|
64 |
+
lr_scheduler.step()
|
65 |
+
|
66 |
+
#if(batch_idx%20==0):
|
67 |
+
# print('epoch {} batch {} : {}'.format(epoch,batch_idx,losses_reduced))
|
68 |
+
|
69 |
+
metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
|
70 |
+
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
71 |
+
|
72 |
+
return metric_logger
|
73 |
+
|
74 |
+
|
75 |
+
def train_one_epoch_simplified(model, optimizer, data_loader, device, epoch, experimenter,optimizer_backbone=None):
|
76 |
+
|
77 |
+
model.train()
|
78 |
+
lr_scheduler = None
|
79 |
+
lr_scheduler_backbone = None
|
80 |
+
if epoch == 0:
|
81 |
+
warmup_factor = 1.0 / 1000
|
82 |
+
warmup_iters = min(1000, len(data_loader) - 1)
|
83 |
+
|
84 |
+
lr_scheduler = torch.optim.lr_scheduler.LinearLR(
|
85 |
+
optimizer, start_factor=warmup_factor, total_iters=warmup_iters
|
86 |
+
)
|
87 |
+
if(optimizer_backbone is not None):
|
88 |
+
lr_scheduler_backbone = torch.optim.lr_scheduler.LinearLR(optimizer_backbone, start_factor=warmup_factor, total_iters=warmup_iters)
|
89 |
+
|
90 |
+
|
91 |
+
loss_meter = AverageMeter()
|
92 |
+
|
93 |
+
for step, (images, targets) in enumerate(tqdm(data_loader)):
|
94 |
+
|
95 |
+
optimizer.zero_grad()
|
96 |
+
if(optimizer_backbone is not None):
|
97 |
+
optimizer_backbone.zero_grad()
|
98 |
+
|
99 |
+
images = list(image.to(device) if len(image)>2 else [image[0].to(device),image[1].to(device)] for image in images)
|
100 |
+
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
|
101 |
+
loss_dict = model(images, targets)
|
102 |
+
losses = sum(loss for loss in loss_dict.values())
|
103 |
+
|
104 |
+
|
105 |
+
if not math.isfinite(losses.item()):
|
106 |
+
print(f"Loss is {losses.item()}, stopping training")
|
107 |
+
print(loss_dict)
|
108 |
+
experimenter.log(f"Loss is {losses.item()}, stopping training")
|
109 |
+
sys.exit(1)
|
110 |
+
|
111 |
+
losses.backward()
|
112 |
+
loss_meter.update(losses.item())
|
113 |
+
optimizer.step()
|
114 |
+
if optimizer_backbone is not None:
|
115 |
+
optimizer_backbone.step()
|
116 |
+
if lr_scheduler is not None:
|
117 |
+
lr_scheduler.step()
|
118 |
+
if lr_scheduler_backbone is not None:
|
119 |
+
lr_scheduler_backbone.step()
|
120 |
+
|
121 |
+
if (step+1)%10 == 0:
|
122 |
+
experimenter.log('Loss after {} steps: {}'.format(step+1, loss_meter.avg))
|
123 |
+
if epoch == 0 and (step+1)%50 == 0:
|
124 |
+
experimenter.log('LR after {} steps: {}'.format(step+1, optimizer.param_groups[0]['lr']))
|
125 |
+
|
126 |
+
def _get_iou_types(model):
|
127 |
+
model_without_ddp = model
|
128 |
+
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
129 |
+
model_without_ddp = model.module
|
130 |
+
iou_types = ["bbox"]
|
131 |
+
if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
|
132 |
+
iou_types.append("segm")
|
133 |
+
if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
|
134 |
+
iou_types.append("keypoints")
|
135 |
+
return iou_types
|
136 |
+
|
137 |
+
|
138 |
+
@torch.inference_mode()
|
139 |
+
def evaluate(model, data_loader, device):
|
140 |
+
n_threads = torch.get_num_threads()
|
141 |
+
# FIXME remove this and make paste_masks_in_image run on the GPU
|
142 |
+
torch.set_num_threads(1)
|
143 |
+
cpu_device = torch.device("cpu")
|
144 |
+
model.eval()
|
145 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
146 |
+
header = "Test:"
|
147 |
+
|
148 |
+
coco = get_coco_api_from_dataset(data_loader.dataset)
|
149 |
+
iou_types = _get_iou_types(model)
|
150 |
+
coco_evaluator = CocoEvaluator(coco, iou_types)
|
151 |
+
|
152 |
+
for images, targets in metric_logger.log_every(data_loader, 100, header):
|
153 |
+
images = list(img.to(device) for img in images)
|
154 |
+
|
155 |
+
if torch.cuda.is_available():
|
156 |
+
torch.cuda.synchronize()
|
157 |
+
model_time = time.time()
|
158 |
+
outputs = model(images)
|
159 |
+
|
160 |
+
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
|
161 |
+
model_time = time.time() - model_time
|
162 |
+
|
163 |
+
res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
|
164 |
+
evaluator_time = time.time()
|
165 |
+
coco_evaluator.update(res)
|
166 |
+
evaluator_time = time.time() - evaluator_time
|
167 |
+
metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
|
168 |
+
|
169 |
+
# gather the stats from all processes
|
170 |
+
metric_logger.synchronize_between_processes()
|
171 |
+
print("Averaged stats:", metric_logger)
|
172 |
+
coco_evaluator.synchronize_between_processes()
|
173 |
+
|
174 |
+
# accumulate predictions from all images
|
175 |
+
coco_evaluator.accumulate()
|
176 |
+
coco_evaluator.summarize()
|
177 |
+
torch.set_num_threads(n_threads)
|
178 |
+
return coco_evaluator
|
179 |
+
|
180 |
+
|
181 |
+
def coco_summ(coco_eval, experimenter):
|
182 |
+
self = coco_eval
|
183 |
+
def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ):
|
184 |
+
p = self.params
|
185 |
+
iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'
|
186 |
+
titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
|
187 |
+
typeStr = '(AP)' if ap==1 else '(AR)'
|
188 |
+
iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
|
189 |
+
if iouThr is None else '{:0.2f}'.format(iouThr)
|
190 |
+
|
191 |
+
aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
|
192 |
+
mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
|
193 |
+
if ap == 1:
|
194 |
+
# dimension of precision: [TxRxKxAxM]
|
195 |
+
s = self.eval['precision']
|
196 |
+
# IoU
|
197 |
+
if iouThr is not None:
|
198 |
+
t = np.where(iouThr == p.iouThrs)[0]
|
199 |
+
s = s[t]
|
200 |
+
s = s[:,:,:,aind,mind]
|
201 |
+
else:
|
202 |
+
# dimension of recall: [TxKxAxM]
|
203 |
+
s = self.eval['recall']
|
204 |
+
if iouThr is not None:
|
205 |
+
t = np.where(iouThr == p.iouThrs)[0]
|
206 |
+
s = s[t]
|
207 |
+
s = s[:,:,aind,mind]
|
208 |
+
if len(s[s>-1])==0:
|
209 |
+
mean_s = -1
|
210 |
+
else:
|
211 |
+
mean_s = np.mean(s[s>-1])
|
212 |
+
experimenter.log(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s), priority = LogPriority.MEDIUM)
|
213 |
+
return mean_s
|
214 |
+
def _summarizeDets():
|
215 |
+
stats = np.zeros((12,))
|
216 |
+
stats[0] = _summarize(1)
|
217 |
+
stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
|
218 |
+
stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
|
219 |
+
stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
|
220 |
+
stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
|
221 |
+
stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
|
222 |
+
stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
|
223 |
+
stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
|
224 |
+
stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
|
225 |
+
stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
|
226 |
+
stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
|
227 |
+
stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
|
228 |
+
return stats
|
229 |
+
_summarizeDets()
|
230 |
+
|
231 |
+
@torch.inference_mode()
|
232 |
+
def evaluate_simplified(model, data_loader, device, experimenter):
|
233 |
+
cpu_device = torch.device("cpu")
|
234 |
+
model.eval()
|
235 |
+
experimenter.log('Evaluating Validation Parameters')
|
236 |
+
|
237 |
+
coco = get_coco_api_from_dataset(data_loader.dataset)
|
238 |
+
iou_types = _get_iou_types(model)
|
239 |
+
coco_evaluator = CocoEvaluator(coco, iou_types)
|
240 |
+
|
241 |
+
for images, targets in data_loader:
|
242 |
+
images = list(img.to(device) for img in images)
|
243 |
+
|
244 |
+
if torch.cuda.is_available():
|
245 |
+
torch.cuda.synchronize()
|
246 |
+
outputs = model(images)
|
247 |
+
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
|
248 |
+
res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
|
249 |
+
coco_evaluator.update(res)
|
250 |
+
|
251 |
+
# gather the stats from all processes
|
252 |
+
coco_evaluator.synchronize_between_processes()
|
253 |
+
|
254 |
+
# accumulate predictions from all images
|
255 |
+
coco_evaluator.accumulate()
|
256 |
+
|
257 |
+
# Debug and see what all info it has
|
258 |
+
# coco_evaluator.summarize()
|
259 |
+
for iou_type, coco_eval in coco_evaluator.coco_eval.items():
|
260 |
+
print(f"IoU metric: {iou_type}")
|
261 |
+
coco_summ(coco_eval, experimenter)
|
262 |
+
|
263 |
+
return coco_evaluator
|
264 |
+
|
265 |
+
def evaluate_loss(model, device, val_loader, experimenter=None):
|
266 |
+
model.train()
|
267 |
+
#experimenter.log('Evaluating Validation Loss')
|
268 |
+
with torch.no_grad():
|
269 |
+
loss_meter = AverageMeter()
|
270 |
+
for images, targets in tqdm(val_loader):
|
271 |
+
images = list(image.to(device) if len(image)>2 else [image[0].to(device),image[1].to(device)] for image in images)
|
272 |
+
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
|
273 |
+
loss_dict = model(images, targets)
|
274 |
+
losses = sum(loss for loss in loss_dict.values())
|
275 |
+
loss_meter.update(losses.item())
|
276 |
+
return loss_meter.avg
|
DenseMammogram/detection/group_by_aspect_ratio.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import bisect
|
2 |
+
import copy
|
3 |
+
import math
|
4 |
+
from collections import defaultdict
|
5 |
+
from itertools import repeat, chain
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.utils.data
|
10 |
+
import torchvision
|
11 |
+
from PIL import Image
|
12 |
+
from torch.utils.data.sampler import BatchSampler, Sampler
|
13 |
+
from torch.utils.model_zoo import tqdm
|
14 |
+
|
15 |
+
|
16 |
+
def _repeat_to_at_least(iterable, n):
|
17 |
+
repeat_times = math.ceil(n / len(iterable))
|
18 |
+
repeated = chain.from_iterable(repeat(iterable, repeat_times))
|
19 |
+
return list(repeated)
|
20 |
+
|
21 |
+
|
22 |
+
class GroupedBatchSampler(BatchSampler):
|
23 |
+
"""
|
24 |
+
Wraps another sampler to yield a mini-batch of indices.
|
25 |
+
It enforces that the batch only contain elements from the same group.
|
26 |
+
It also tries to provide mini-batches which follows an ordering which is
|
27 |
+
as close as possible to the ordering from the original sampler.
|
28 |
+
Args:
|
29 |
+
sampler (Sampler): Base sampler.
|
30 |
+
group_ids (list[int]): If the sampler produces indices in range [0, N),
|
31 |
+
`group_ids` must be a list of `N` ints which contains the group id of each sample.
|
32 |
+
The group ids must be a continuous set of integers starting from
|
33 |
+
0, i.e. they must be in the range [0, num_groups).
|
34 |
+
batch_size (int): Size of mini-batch.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self, sampler, group_ids, batch_size):
|
38 |
+
if not isinstance(sampler, Sampler):
|
39 |
+
raise ValueError(f"sampler should be an instance of torch.utils.data.Sampler, but got sampler={sampler}")
|
40 |
+
self.sampler = sampler
|
41 |
+
self.group_ids = group_ids
|
42 |
+
self.batch_size = batch_size
|
43 |
+
|
44 |
+
def __iter__(self):
|
45 |
+
buffer_per_group = defaultdict(list)
|
46 |
+
samples_per_group = defaultdict(list)
|
47 |
+
|
48 |
+
num_batches = 0
|
49 |
+
for idx in self.sampler:
|
50 |
+
group_id = self.group_ids[idx]
|
51 |
+
buffer_per_group[group_id].append(idx)
|
52 |
+
samples_per_group[group_id].append(idx)
|
53 |
+
if len(buffer_per_group[group_id]) == self.batch_size:
|
54 |
+
yield buffer_per_group[group_id]
|
55 |
+
num_batches += 1
|
56 |
+
del buffer_per_group[group_id]
|
57 |
+
assert len(buffer_per_group[group_id]) < self.batch_size
|
58 |
+
|
59 |
+
# now we have run out of elements that satisfy
|
60 |
+
# the group criteria, let's return the remaining
|
61 |
+
# elements so that the size of the sampler is
|
62 |
+
# deterministic
|
63 |
+
expected_num_batches = len(self)
|
64 |
+
num_remaining = expected_num_batches - num_batches
|
65 |
+
if num_remaining > 0:
|
66 |
+
# for the remaining batches, take first the buffers with largest number
|
67 |
+
# of elements
|
68 |
+
for group_id, _ in sorted(buffer_per_group.items(), key=lambda x: len(x[1]), reverse=True):
|
69 |
+
remaining = self.batch_size - len(buffer_per_group[group_id])
|
70 |
+
samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining)
|
71 |
+
buffer_per_group[group_id].extend(samples_from_group_id[:remaining])
|
72 |
+
assert len(buffer_per_group[group_id]) == self.batch_size
|
73 |
+
yield buffer_per_group[group_id]
|
74 |
+
num_remaining -= 1
|
75 |
+
if num_remaining == 0:
|
76 |
+
break
|
77 |
+
assert num_remaining == 0
|
78 |
+
|
79 |
+
def __len__(self):
|
80 |
+
return len(self.sampler) // self.batch_size
|
81 |
+
|
82 |
+
|
83 |
+
def _compute_aspect_ratios_slow(dataset, indices=None):
|
84 |
+
print(
|
85 |
+
"Your dataset doesn't support the fast path for "
|
86 |
+
"computing the aspect ratios, so will iterate over "
|
87 |
+
"the full dataset and load every image instead. "
|
88 |
+
"This might take some time..."
|
89 |
+
)
|
90 |
+
if indices is None:
|
91 |
+
indices = range(len(dataset))
|
92 |
+
|
93 |
+
class SubsetSampler(Sampler):
|
94 |
+
def __init__(self, indices):
|
95 |
+
self.indices = indices
|
96 |
+
|
97 |
+
def __iter__(self):
|
98 |
+
return iter(self.indices)
|
99 |
+
|
100 |
+
def __len__(self):
|
101 |
+
return len(self.indices)
|
102 |
+
|
103 |
+
sampler = SubsetSampler(indices)
|
104 |
+
data_loader = torch.utils.data.DataLoader(
|
105 |
+
dataset,
|
106 |
+
batch_size=1,
|
107 |
+
sampler=sampler,
|
108 |
+
num_workers=14, # you might want to increase it for faster processing
|
109 |
+
collate_fn=lambda x: x[0],
|
110 |
+
)
|
111 |
+
aspect_ratios = []
|
112 |
+
with tqdm(total=len(dataset)) as pbar:
|
113 |
+
for _i, (img, _) in enumerate(data_loader):
|
114 |
+
pbar.update(1)
|
115 |
+
height, width = img.shape[-2:]
|
116 |
+
aspect_ratio = float(width) / float(height)
|
117 |
+
aspect_ratios.append(aspect_ratio)
|
118 |
+
return aspect_ratios
|
119 |
+
|
120 |
+
|
121 |
+
def _compute_aspect_ratios_custom_dataset(dataset, indices=None):
|
122 |
+
if indices is None:
|
123 |
+
indices = range(len(dataset))
|
124 |
+
aspect_ratios = []
|
125 |
+
for i in indices:
|
126 |
+
height, width = dataset.get_height_and_width(i)
|
127 |
+
aspect_ratio = float(width) / float(height)
|
128 |
+
aspect_ratios.append(aspect_ratio)
|
129 |
+
return aspect_ratios
|
130 |
+
|
131 |
+
|
132 |
+
def _compute_aspect_ratios_coco_dataset(dataset, indices=None):
|
133 |
+
if indices is None:
|
134 |
+
indices = range(len(dataset))
|
135 |
+
aspect_ratios = []
|
136 |
+
for i in indices:
|
137 |
+
img_info = dataset.coco.imgs[dataset.ids[i]]
|
138 |
+
aspect_ratio = float(img_info["width"]) / float(img_info["height"])
|
139 |
+
aspect_ratios.append(aspect_ratio)
|
140 |
+
return aspect_ratios
|
141 |
+
|
142 |
+
|
143 |
+
def _compute_aspect_ratios_voc_dataset(dataset, indices=None):
|
144 |
+
if indices is None:
|
145 |
+
indices = range(len(dataset))
|
146 |
+
aspect_ratios = []
|
147 |
+
for i in indices:
|
148 |
+
# this doesn't load the data into memory, because PIL loads it lazily
|
149 |
+
width, height = Image.open(dataset.images[i]).size
|
150 |
+
aspect_ratio = float(width) / float(height)
|
151 |
+
aspect_ratios.append(aspect_ratio)
|
152 |
+
return aspect_ratios
|
153 |
+
|
154 |
+
|
155 |
+
def _compute_aspect_ratios_subset_dataset(dataset, indices=None):
|
156 |
+
if indices is None:
|
157 |
+
indices = range(len(dataset))
|
158 |
+
|
159 |
+
ds_indices = [dataset.indices[i] for i in indices]
|
160 |
+
return compute_aspect_ratios(dataset.dataset, ds_indices)
|
161 |
+
|
162 |
+
|
163 |
+
def compute_aspect_ratios(dataset, indices=None):
|
164 |
+
if hasattr(dataset, "get_height_and_width"):
|
165 |
+
return _compute_aspect_ratios_custom_dataset(dataset, indices)
|
166 |
+
|
167 |
+
if isinstance(dataset, torchvision.datasets.CocoDetection):
|
168 |
+
return _compute_aspect_ratios_coco_dataset(dataset, indices)
|
169 |
+
|
170 |
+
if isinstance(dataset, torchvision.datasets.VOCDetection):
|
171 |
+
return _compute_aspect_ratios_voc_dataset(dataset, indices)
|
172 |
+
|
173 |
+
if isinstance(dataset, torch.utils.data.Subset):
|
174 |
+
return _compute_aspect_ratios_subset_dataset(dataset, indices)
|
175 |
+
|
176 |
+
# slow path
|
177 |
+
return _compute_aspect_ratios_slow(dataset, indices)
|
178 |
+
|
179 |
+
|
180 |
+
def _quantize(x, bins):
|
181 |
+
bins = copy.deepcopy(bins)
|
182 |
+
bins = sorted(bins)
|
183 |
+
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
|
184 |
+
return quantized
|
185 |
+
|
186 |
+
|
187 |
+
def create_aspect_ratio_groups(dataset, k=0):
|
188 |
+
aspect_ratios = compute_aspect_ratios(dataset)
|
189 |
+
bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0]
|
190 |
+
groups = _quantize(aspect_ratios, bins)
|
191 |
+
# count number of elements per group
|
192 |
+
counts = np.unique(groups, return_counts=True)[1]
|
193 |
+
fbins = [0] + bins + [np.inf]
|
194 |
+
print(f"Using {fbins} as bins for aspect ratio quantization")
|
195 |
+
print(f"Count of instances per bin: {counts}")
|
196 |
+
return groups
|
DenseMammogram/detection/presets.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import detection.transforms as T
|
3 |
+
|
4 |
+
|
5 |
+
class DetectionPresetTrain:
|
6 |
+
def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
|
7 |
+
if data_augmentation == "hflip":
|
8 |
+
self.transforms = T.Compose(
|
9 |
+
[
|
10 |
+
T.RandomHorizontalFlip(p=hflip_prob),
|
11 |
+
T.PILToTensor(),
|
12 |
+
T.ConvertImageDtype(torch.float),
|
13 |
+
]
|
14 |
+
)
|
15 |
+
elif data_augmentation == "ssd":
|
16 |
+
self.transforms = T.Compose(
|
17 |
+
[
|
18 |
+
T.RandomPhotometricDistort(),
|
19 |
+
T.RandomZoomOut(fill=list(mean)),
|
20 |
+
T.RandomIoUCrop(),
|
21 |
+
T.RandomHorizontalFlip(p=hflip_prob),
|
22 |
+
T.PILToTensor(),
|
23 |
+
T.ConvertImageDtype(torch.float),
|
24 |
+
]
|
25 |
+
)
|
26 |
+
elif data_augmentation == "ssdlite":
|
27 |
+
self.transforms = T.Compose(
|
28 |
+
[
|
29 |
+
T.RandomIoUCrop(),
|
30 |
+
T.RandomHorizontalFlip(p=hflip_prob),
|
31 |
+
T.PILToTensor(),
|
32 |
+
T.ConvertImageDtype(torch.float),
|
33 |
+
]
|
34 |
+
)
|
35 |
+
else:
|
36 |
+
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
|
37 |
+
|
38 |
+
def __call__(self, img, target):
|
39 |
+
return self.transforms(img, target)
|
40 |
+
|
41 |
+
|
42 |
+
class DetectionPresetEval:
|
43 |
+
def __init__(self):
|
44 |
+
self.transforms = T.ToTensor()
|
45 |
+
|
46 |
+
def __call__(self, img, target):
|
47 |
+
return self.transforms(img, target)
|
DenseMammogram/detection/train.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
r"""PyTorch Detection Training.
|
2 |
+
|
3 |
+
To run in a multi-gpu environment, use the distributed launcher::
|
4 |
+
|
5 |
+
python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \
|
6 |
+
train.py ... --world-size $NGPU
|
7 |
+
|
8 |
+
The default hyperparameters are tuned for training on 8 gpus and 2 images per gpu.
|
9 |
+
--lr 0.02 --batch-size 2 --world-size 8
|
10 |
+
If you use different number of gpus, the learning rate should be changed to 0.02/8*$NGPU.
|
11 |
+
|
12 |
+
On top of that, for training Faster/Mask R-CNN, the default hyperparameters are
|
13 |
+
--epochs 26 --lr-steps 16 22 --aspect-ratio-group-factor 3
|
14 |
+
|
15 |
+
Also, if you train Keypoint R-CNN, the default hyperparameters are
|
16 |
+
--epochs 46 --lr-steps 36 43 --aspect-ratio-group-factor 3
|
17 |
+
Because the number of images is smaller in the person keypoint subset of COCO,
|
18 |
+
the number of epochs should be adapted so that we have the same number of iterations.
|
19 |
+
"""
|
20 |
+
import datetime
|
21 |
+
import os
|
22 |
+
import time
|
23 |
+
|
24 |
+
import detection.presets
|
25 |
+
import torch
|
26 |
+
import torch.utils.data
|
27 |
+
import torchvision
|
28 |
+
import torchvision.models.detection
|
29 |
+
import torchvision.models.detection.mask_rcnn
|
30 |
+
import detection.utils as utils
|
31 |
+
from detection.coco_utils import get_coco, get_coco_kp
|
32 |
+
from detection.engine import train_one_epoch, evaluate
|
33 |
+
from detection.group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
|
34 |
+
|
35 |
+
|
36 |
+
try:
|
37 |
+
from torchvision.prototype import models as PM
|
38 |
+
except ImportError:
|
39 |
+
PM = None
|
40 |
+
|
41 |
+
|
42 |
+
def get_dataset(name, image_set, transform, data_path):
|
43 |
+
paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)}
|
44 |
+
p, ds_fn, num_classes = paths[name]
|
45 |
+
|
46 |
+
ds = ds_fn(p, image_set=image_set, transforms=transform)
|
47 |
+
return ds, num_classes
|
48 |
+
|
49 |
+
|
50 |
+
def get_transform(train, args):
|
51 |
+
if train:
|
52 |
+
return presets.DetectionPresetTrain(args.data_augmentation)
|
53 |
+
elif not args.weights:
|
54 |
+
return presets.DetectionPresetEval()
|
55 |
+
else:
|
56 |
+
weights = PM.get_weight(args.weights)
|
57 |
+
return weights.transforms()
|
58 |
+
|
59 |
+
|
60 |
+
def get_args_parser(add_help=True):
|
61 |
+
import argparse
|
62 |
+
|
63 |
+
parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)
|
64 |
+
|
65 |
+
parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
|
66 |
+
parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
|
67 |
+
parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name")
|
68 |
+
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
|
69 |
+
parser.add_argument(
|
70 |
+
"-b", "--batch-size", default=2, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
|
71 |
+
)
|
72 |
+
parser.add_argument("--epochs", default=26, type=int, metavar="N", help="number of total epochs to run")
|
73 |
+
parser.add_argument(
|
74 |
+
"-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
|
75 |
+
)
|
76 |
+
parser.add_argument(
|
77 |
+
"--lr",
|
78 |
+
default=0.02,
|
79 |
+
type=float,
|
80 |
+
help="initial learning rate, 0.02 is the default value for training on 8 gpus and 2 images_per_gpu",
|
81 |
+
)
|
82 |
+
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
|
83 |
+
parser.add_argument(
|
84 |
+
"--wd",
|
85 |
+
"--weight-decay",
|
86 |
+
default=1e-4,
|
87 |
+
type=float,
|
88 |
+
metavar="W",
|
89 |
+
help="weight decay (default: 1e-4)",
|
90 |
+
dest="weight_decay",
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)"
|
94 |
+
)
|
95 |
+
parser.add_argument(
|
96 |
+
"--lr-step-size", default=8, type=int, help="decrease lr every step-size epochs (multisteplr scheduler only)"
|
97 |
+
)
|
98 |
+
parser.add_argument(
|
99 |
+
"--lr-steps",
|
100 |
+
default=[16, 22],
|
101 |
+
nargs="+",
|
102 |
+
type=int,
|
103 |
+
help="decrease lr every step-size epochs (multisteplr scheduler only)",
|
104 |
+
)
|
105 |
+
parser.add_argument(
|
106 |
+
"--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma (multisteplr scheduler only)"
|
107 |
+
)
|
108 |
+
parser.add_argument("--print-freq", default=20, type=int, help="print frequency")
|
109 |
+
parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
|
110 |
+
parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
|
111 |
+
parser.add_argument("--start_epoch", default=0, type=int, help="start epoch")
|
112 |
+
parser.add_argument("--aspect-ratio-group-factor", default=3, type=int)
|
113 |
+
parser.add_argument("--rpn-score-thresh", default=None, type=float, help="rpn score threshold for faster-rcnn")
|
114 |
+
parser.add_argument(
|
115 |
+
"--trainable-backbone-layers", default=None, type=int, help="number of trainable layers of backbone"
|
116 |
+
)
|
117 |
+
parser.add_argument(
|
118 |
+
"--data-augmentation", default="hflip", type=str, help="data augmentation policy (default: hflip)"
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"--sync-bn",
|
122 |
+
dest="sync_bn",
|
123 |
+
help="Use sync batch norm",
|
124 |
+
action="store_true",
|
125 |
+
)
|
126 |
+
parser.add_argument(
|
127 |
+
"--test-only",
|
128 |
+
dest="test_only",
|
129 |
+
help="Only test the model",
|
130 |
+
action="store_true",
|
131 |
+
)
|
132 |
+
parser.add_argument(
|
133 |
+
"--pretrained",
|
134 |
+
dest="pretrained",
|
135 |
+
help="Use pre-trained models from the modelzoo",
|
136 |
+
action="store_true",
|
137 |
+
)
|
138 |
+
|
139 |
+
# distributed training parameters
|
140 |
+
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
|
141 |
+
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
|
142 |
+
|
143 |
+
# Prototype models only
|
144 |
+
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
|
145 |
+
|
146 |
+
# Mixed precision training parameters
|
147 |
+
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
|
148 |
+
|
149 |
+
return parser
|
150 |
+
|
151 |
+
|
152 |
+
def main(args):
|
153 |
+
if args.weights and PM is None:
|
154 |
+
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
|
155 |
+
if args.output_dir:
|
156 |
+
utils.mkdir(args.output_dir)
|
157 |
+
|
158 |
+
utils.init_distributed_mode(args)
|
159 |
+
print(args)
|
160 |
+
|
161 |
+
device = torch.device(args.device)
|
162 |
+
|
163 |
+
# Data loading code
|
164 |
+
print("Loading data")
|
165 |
+
|
166 |
+
dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path)
|
167 |
+
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path)
|
168 |
+
|
169 |
+
print("Creating data loaders")
|
170 |
+
if args.distributed:
|
171 |
+
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
172 |
+
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
|
173 |
+
else:
|
174 |
+
train_sampler = torch.utils.data.RandomSampler(dataset)
|
175 |
+
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
176 |
+
|
177 |
+
if args.aspect_ratio_group_factor >= 0:
|
178 |
+
group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
|
179 |
+
train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
|
180 |
+
else:
|
181 |
+
train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)
|
182 |
+
|
183 |
+
data_loader = torch.utils.data.DataLoader(
|
184 |
+
dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
|
185 |
+
)
|
186 |
+
|
187 |
+
data_loader_test = torch.utils.data.DataLoader(
|
188 |
+
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
|
189 |
+
)
|
190 |
+
|
191 |
+
print("Creating model")
|
192 |
+
kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
|
193 |
+
if "rcnn" in args.model:
|
194 |
+
if args.rpn_score_thresh is not None:
|
195 |
+
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
|
196 |
+
if not args.weights:
|
197 |
+
model = torchvision.models.detection.__dict__[args.model](
|
198 |
+
pretrained=args.pretrained, num_classes=num_classes, **kwargs
|
199 |
+
)
|
200 |
+
else:
|
201 |
+
model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
|
202 |
+
model.to(device)
|
203 |
+
if args.distributed and args.sync_bn:
|
204 |
+
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
205 |
+
|
206 |
+
model_without_ddp = model
|
207 |
+
if args.distributed:
|
208 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
209 |
+
model_without_ddp = model.module
|
210 |
+
|
211 |
+
params = [p for p in model.parameters() if p.requires_grad]
|
212 |
+
optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
|
213 |
+
|
214 |
+
scaler = torch.cuda.amp.GradScaler() if args.amp else None
|
215 |
+
|
216 |
+
args.lr_scheduler = args.lr_scheduler.lower()
|
217 |
+
if args.lr_scheduler == "multisteplr":
|
218 |
+
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
|
219 |
+
elif args.lr_scheduler == "cosineannealinglr":
|
220 |
+
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
|
221 |
+
else:
|
222 |
+
raise RuntimeError(
|
223 |
+
f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
|
224 |
+
)
|
225 |
+
|
226 |
+
if args.resume:
|
227 |
+
checkpoint = torch.load(args.resume, map_location="cpu")
|
228 |
+
model_without_ddp.load_state_dict(checkpoint["model"])
|
229 |
+
optimizer.load_state_dict(checkpoint["optimizer"])
|
230 |
+
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
|
231 |
+
args.start_epoch = checkpoint["epoch"] + 1
|
232 |
+
if args.amp:
|
233 |
+
scaler.load_state_dict(checkpoint["scaler"])
|
234 |
+
|
235 |
+
if args.test_only:
|
236 |
+
evaluate(model, data_loader_test, device=device)
|
237 |
+
return
|
238 |
+
|
239 |
+
print("Start training")
|
240 |
+
start_time = time.time()
|
241 |
+
for epoch in range(args.start_epoch, args.epochs):
|
242 |
+
if args.distributed:
|
243 |
+
train_sampler.set_epoch(epoch)
|
244 |
+
train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, scaler)
|
245 |
+
lr_scheduler.step()
|
246 |
+
if args.output_dir:
|
247 |
+
checkpoint = {
|
248 |
+
"model": model_without_ddp.state_dict(),
|
249 |
+
"optimizer": optimizer.state_dict(),
|
250 |
+
"lr_scheduler": lr_scheduler.state_dict(),
|
251 |
+
"args": args,
|
252 |
+
"epoch": epoch,
|
253 |
+
}
|
254 |
+
if args.amp:
|
255 |
+
checkpoint["scaler"] = scaler.state_dict()
|
256 |
+
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
|
257 |
+
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
|
258 |
+
|
259 |
+
# evaluate after every epoch
|
260 |
+
evaluate(model, data_loader_test, device=device)
|
261 |
+
|
262 |
+
total_time = time.time() - start_time
|
263 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
264 |
+
print(f"Training time {total_time_str}")
|
265 |
+
|
266 |
+
|
267 |
+
if __name__ == "__main__":
|
268 |
+
args = get_args_parser().parse_args()
|
269 |
+
main(args)
|
DenseMammogram/detection/transforms.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple, Dict, Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torchvision
|
5 |
+
from torch import nn, Tensor
|
6 |
+
from torchvision.transforms import functional as F
|
7 |
+
from torchvision.transforms import transforms as T
|
8 |
+
|
9 |
+
|
10 |
+
def _flip_coco_person_keypoints(kps, width):
|
11 |
+
flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
|
12 |
+
flipped_data = kps[:, flip_inds]
|
13 |
+
flipped_data[..., 0] = width - flipped_data[..., 0]
|
14 |
+
# Maintain COCO convention that if visibility == 0, then x, y = 0
|
15 |
+
inds = flipped_data[..., 2] == 0
|
16 |
+
flipped_data[inds] = 0
|
17 |
+
return flipped_data
|
18 |
+
|
19 |
+
|
20 |
+
class Compose:
|
21 |
+
def __init__(self, transforms):
|
22 |
+
self.transforms = transforms
|
23 |
+
|
24 |
+
def __call__(self, image, target):
|
25 |
+
for t in self.transforms:
|
26 |
+
image, target = t(image, target)
|
27 |
+
return image, target
|
28 |
+
|
29 |
+
|
30 |
+
class RandomHorizontalFlip(T.RandomHorizontalFlip):
|
31 |
+
def forward(
|
32 |
+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
33 |
+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
34 |
+
if torch.rand(1) < self.p:
|
35 |
+
image = F.hflip(image)
|
36 |
+
if target is not None:
|
37 |
+
width, _ = F.get_image_size(image)
|
38 |
+
target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
|
39 |
+
if "masks" in target:
|
40 |
+
target["masks"] = target["masks"].flip(-1)
|
41 |
+
if "keypoints" in target:
|
42 |
+
keypoints = target["keypoints"]
|
43 |
+
keypoints = _flip_coco_person_keypoints(keypoints, width)
|
44 |
+
target["keypoints"] = keypoints
|
45 |
+
return image, target
|
46 |
+
|
47 |
+
|
48 |
+
class ToTensor(nn.Module):
|
49 |
+
def forward(
|
50 |
+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
51 |
+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
52 |
+
image = F.pil_to_tensor(image)
|
53 |
+
image = F.convert_image_dtype(image)
|
54 |
+
return image, target
|
55 |
+
|
56 |
+
|
57 |
+
class PILToTensor(nn.Module):
|
58 |
+
def forward(
|
59 |
+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
60 |
+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
61 |
+
image = F.pil_to_tensor(image)
|
62 |
+
return image, target
|
63 |
+
|
64 |
+
|
65 |
+
class ConvertImageDtype(nn.Module):
|
66 |
+
def __init__(self, dtype: torch.dtype) -> None:
|
67 |
+
super().__init__()
|
68 |
+
self.dtype = dtype
|
69 |
+
|
70 |
+
def forward(
|
71 |
+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
72 |
+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
73 |
+
image = F.convert_image_dtype(image, self.dtype)
|
74 |
+
return image, target
|
75 |
+
|
76 |
+
|
77 |
+
class RandomIoUCrop(nn.Module):
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
min_scale: float = 0.3,
|
81 |
+
max_scale: float = 1.0,
|
82 |
+
min_aspect_ratio: float = 0.5,
|
83 |
+
max_aspect_ratio: float = 2.0,
|
84 |
+
sampler_options: Optional[List[float]] = None,
|
85 |
+
trials: int = 40,
|
86 |
+
):
|
87 |
+
super().__init__()
|
88 |
+
# Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
|
89 |
+
self.min_scale = min_scale
|
90 |
+
self.max_scale = max_scale
|
91 |
+
self.min_aspect_ratio = min_aspect_ratio
|
92 |
+
self.max_aspect_ratio = max_aspect_ratio
|
93 |
+
if sampler_options is None:
|
94 |
+
sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
|
95 |
+
self.options = sampler_options
|
96 |
+
self.trials = trials
|
97 |
+
|
98 |
+
def forward(
|
99 |
+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
100 |
+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
101 |
+
if target is None:
|
102 |
+
raise ValueError("The targets can't be None for this transform.")
|
103 |
+
|
104 |
+
if isinstance(image, torch.Tensor):
|
105 |
+
if image.ndimension() not in {2, 3}:
|
106 |
+
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
|
107 |
+
elif image.ndimension() == 2:
|
108 |
+
image = image.unsqueeze(0)
|
109 |
+
|
110 |
+
orig_w, orig_h = F.get_image_size(image)
|
111 |
+
|
112 |
+
while True:
|
113 |
+
# sample an option
|
114 |
+
idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
|
115 |
+
min_jaccard_overlap = self.options[idx]
|
116 |
+
if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option
|
117 |
+
return image, target
|
118 |
+
|
119 |
+
for _ in range(self.trials):
|
120 |
+
# check the aspect ratio limitations
|
121 |
+
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
|
122 |
+
new_w = int(orig_w * r[0])
|
123 |
+
new_h = int(orig_h * r[1])
|
124 |
+
aspect_ratio = new_w / new_h
|
125 |
+
if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
|
126 |
+
continue
|
127 |
+
|
128 |
+
# check for 0 area crops
|
129 |
+
r = torch.rand(2)
|
130 |
+
left = int((orig_w - new_w) * r[0])
|
131 |
+
top = int((orig_h - new_h) * r[1])
|
132 |
+
right = left + new_w
|
133 |
+
bottom = top + new_h
|
134 |
+
if left == right or top == bottom:
|
135 |
+
continue
|
136 |
+
|
137 |
+
# check for any valid boxes with centers within the crop area
|
138 |
+
cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2])
|
139 |
+
cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3])
|
140 |
+
is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
|
141 |
+
if not is_within_crop_area.any():
|
142 |
+
continue
|
143 |
+
|
144 |
+
# check at least 1 box with jaccard limitations
|
145 |
+
boxes = target["boxes"][is_within_crop_area]
|
146 |
+
ious = torchvision.ops.boxes.box_iou(
|
147 |
+
boxes, torch.tensor([[left, top, right, bottom]], dtype=boxes.dtype, device=boxes.device)
|
148 |
+
)
|
149 |
+
if ious.max() < min_jaccard_overlap:
|
150 |
+
continue
|
151 |
+
|
152 |
+
# keep only valid boxes and perform cropping
|
153 |
+
target["boxes"] = boxes
|
154 |
+
target["labels"] = target["labels"][is_within_crop_area]
|
155 |
+
target["boxes"][:, 0::2] -= left
|
156 |
+
target["boxes"][:, 1::2] -= top
|
157 |
+
target["boxes"][:, 0::2].clamp_(min=0, max=new_w)
|
158 |
+
target["boxes"][:, 1::2].clamp_(min=0, max=new_h)
|
159 |
+
image = F.crop(image, top, left, new_h, new_w)
|
160 |
+
|
161 |
+
return image, target
|
162 |
+
|
163 |
+
|
164 |
+
class RandomZoomOut(nn.Module):
|
165 |
+
def __init__(
|
166 |
+
self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
|
167 |
+
):
|
168 |
+
super().__init__()
|
169 |
+
if fill is None:
|
170 |
+
fill = [0.0, 0.0, 0.0]
|
171 |
+
self.fill = fill
|
172 |
+
self.side_range = side_range
|
173 |
+
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
|
174 |
+
raise ValueError(f"Invalid canvas side range provided {side_range}.")
|
175 |
+
self.p = p
|
176 |
+
|
177 |
+
@torch.jit.unused
|
178 |
+
def _get_fill_value(self, is_pil):
|
179 |
+
# type: (bool) -> int
|
180 |
+
# We fake the type to make it work on JIT
|
181 |
+
return tuple(int(x) for x in self.fill) if is_pil else 0
|
182 |
+
|
183 |
+
def forward(
|
184 |
+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
185 |
+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
186 |
+
if isinstance(image, torch.Tensor):
|
187 |
+
if image.ndimension() not in {2, 3}:
|
188 |
+
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
|
189 |
+
elif image.ndimension() == 2:
|
190 |
+
image = image.unsqueeze(0)
|
191 |
+
|
192 |
+
if torch.rand(1) < self.p:
|
193 |
+
return image, target
|
194 |
+
|
195 |
+
orig_w, orig_h = F.get_image_size(image)
|
196 |
+
|
197 |
+
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
|
198 |
+
canvas_width = int(orig_w * r)
|
199 |
+
canvas_height = int(orig_h * r)
|
200 |
+
|
201 |
+
r = torch.rand(2)
|
202 |
+
left = int((canvas_width - orig_w) * r[0])
|
203 |
+
top = int((canvas_height - orig_h) * r[1])
|
204 |
+
right = canvas_width - (left + orig_w)
|
205 |
+
bottom = canvas_height - (top + orig_h)
|
206 |
+
|
207 |
+
if torch.jit.is_scripting():
|
208 |
+
fill = 0
|
209 |
+
else:
|
210 |
+
fill = self._get_fill_value(F._is_pil_image(image))
|
211 |
+
|
212 |
+
image = F.pad(image, [left, top, right, bottom], fill=fill)
|
213 |
+
if isinstance(image, torch.Tensor):
|
214 |
+
v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
|
215 |
+
image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h) :, :] = image[
|
216 |
+
..., :, (left + orig_w) :
|
217 |
+
] = v
|
218 |
+
|
219 |
+
if target is not None:
|
220 |
+
target["boxes"][:, 0::2] += left
|
221 |
+
target["boxes"][:, 1::2] += top
|
222 |
+
|
223 |
+
return image, target
|
224 |
+
|
225 |
+
|
226 |
+
class RandomPhotometricDistort(nn.Module):
|
227 |
+
def __init__(
|
228 |
+
self,
|
229 |
+
contrast: Tuple[float] = (0.5, 1.5),
|
230 |
+
saturation: Tuple[float] = (0.5, 1.5),
|
231 |
+
hue: Tuple[float] = (-0.05, 0.05),
|
232 |
+
brightness: Tuple[float] = (0.875, 1.125),
|
233 |
+
p: float = 0.5,
|
234 |
+
):
|
235 |
+
super().__init__()
|
236 |
+
self._brightness = T.ColorJitter(brightness=brightness)
|
237 |
+
self._contrast = T.ColorJitter(contrast=contrast)
|
238 |
+
self._hue = T.ColorJitter(hue=hue)
|
239 |
+
self._saturation = T.ColorJitter(saturation=saturation)
|
240 |
+
self.p = p
|
241 |
+
|
242 |
+
def forward(
|
243 |
+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
244 |
+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
245 |
+
if isinstance(image, torch.Tensor):
|
246 |
+
if image.ndimension() not in {2, 3}:
|
247 |
+
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
|
248 |
+
elif image.ndimension() == 2:
|
249 |
+
image = image.unsqueeze(0)
|
250 |
+
|
251 |
+
r = torch.rand(7)
|
252 |
+
|
253 |
+
if r[0] < self.p:
|
254 |
+
image = self._brightness(image)
|
255 |
+
|
256 |
+
contrast_before = r[1] < 0.5
|
257 |
+
if contrast_before:
|
258 |
+
if r[2] < self.p:
|
259 |
+
image = self._contrast(image)
|
260 |
+
|
261 |
+
if r[3] < self.p:
|
262 |
+
image = self._saturation(image)
|
263 |
+
|
264 |
+
if r[4] < self.p:
|
265 |
+
image = self._hue(image)
|
266 |
+
|
267 |
+
if not contrast_before:
|
268 |
+
if r[5] < self.p:
|
269 |
+
image = self._contrast(image)
|
270 |
+
|
271 |
+
if r[6] < self.p:
|
272 |
+
channels = F.get_image_num_channels(image)
|
273 |
+
permutation = torch.randperm(channels)
|
274 |
+
|
275 |
+
is_pil = F._is_pil_image(image)
|
276 |
+
if is_pil:
|
277 |
+
image = F.pil_to_tensor(image)
|
278 |
+
image = F.convert_image_dtype(image)
|
279 |
+
image = image[..., permutation, :, :]
|
280 |
+
if is_pil:
|
281 |
+
image = F.to_pil_image(image)
|
282 |
+
|
283 |
+
return image, target
|
DenseMammogram/detection/utils.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import errno
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
from collections import defaultdict, deque
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.distributed as dist
|
9 |
+
|
10 |
+
|
11 |
+
class SmoothedValue:
|
12 |
+
"""Track a series of values and provide access to smoothed values over a
|
13 |
+
window or the global series average.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, window_size=20, fmt=None):
|
17 |
+
if fmt is None:
|
18 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
19 |
+
self.deque = deque(maxlen=window_size)
|
20 |
+
self.total = 0.0
|
21 |
+
self.count = 0
|
22 |
+
self.fmt = fmt
|
23 |
+
|
24 |
+
def update(self, value, n=1):
|
25 |
+
self.deque.append(value)
|
26 |
+
self.count += n
|
27 |
+
self.total += value * n
|
28 |
+
|
29 |
+
def synchronize_between_processes(self):
|
30 |
+
"""
|
31 |
+
Warning: does not synchronize the deque!
|
32 |
+
"""
|
33 |
+
if not is_dist_avail_and_initialized():
|
34 |
+
return
|
35 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
36 |
+
dist.barrier()
|
37 |
+
dist.all_reduce(t)
|
38 |
+
t = t.tolist()
|
39 |
+
self.count = int(t[0])
|
40 |
+
self.total = t[1]
|
41 |
+
|
42 |
+
@property
|
43 |
+
def median(self):
|
44 |
+
d = torch.tensor(list(self.deque))
|
45 |
+
return d.median().item()
|
46 |
+
|
47 |
+
@property
|
48 |
+
def avg(self):
|
49 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
50 |
+
return d.mean().item()
|
51 |
+
|
52 |
+
@property
|
53 |
+
def global_avg(self):
|
54 |
+
return self.total / self.count
|
55 |
+
|
56 |
+
@property
|
57 |
+
def max(self):
|
58 |
+
return max(self.deque)
|
59 |
+
|
60 |
+
@property
|
61 |
+
def value(self):
|
62 |
+
return self.deque[-1]
|
63 |
+
|
64 |
+
def __str__(self):
|
65 |
+
return self.fmt.format(
|
66 |
+
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
67 |
+
)
|
68 |
+
|
69 |
+
|
70 |
+
def all_gather(data):
|
71 |
+
"""
|
72 |
+
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
73 |
+
Args:
|
74 |
+
data: any picklable object
|
75 |
+
Returns:
|
76 |
+
list[data]: list of data gathered from each rank
|
77 |
+
"""
|
78 |
+
world_size = get_world_size()
|
79 |
+
if world_size == 1:
|
80 |
+
return [data]
|
81 |
+
data_list = [None] * world_size
|
82 |
+
dist.all_gather_object(data_list, data)
|
83 |
+
return data_list
|
84 |
+
|
85 |
+
|
86 |
+
def reduce_dict(input_dict, average=True):
|
87 |
+
"""
|
88 |
+
Args:
|
89 |
+
input_dict (dict): all the values will be reduced
|
90 |
+
average (bool): whether to do average or sum
|
91 |
+
Reduce the values in the dictionary from all processes so that all processes
|
92 |
+
have the averaged results. Returns a dict with the same fields as
|
93 |
+
input_dict, after reduction.
|
94 |
+
"""
|
95 |
+
world_size = get_world_size()
|
96 |
+
if world_size < 2:
|
97 |
+
return input_dict
|
98 |
+
with torch.inference_mode():
|
99 |
+
names = []
|
100 |
+
values = []
|
101 |
+
# sort the keys so that they are consistent across processes
|
102 |
+
for k in sorted(input_dict.keys()):
|
103 |
+
names.append(k)
|
104 |
+
values.append(input_dict[k])
|
105 |
+
values = torch.stack(values, dim=0)
|
106 |
+
dist.all_reduce(values)
|
107 |
+
if average:
|
108 |
+
values /= world_size
|
109 |
+
reduced_dict = {k: v for k, v in zip(names, values)}
|
110 |
+
return reduced_dict
|
111 |
+
|
112 |
+
|
113 |
+
class MetricLogger:
|
114 |
+
def __init__(self, delimiter="\t"):
|
115 |
+
self.meters = defaultdict(SmoothedValue)
|
116 |
+
self.delimiter = delimiter
|
117 |
+
|
118 |
+
def update(self, **kwargs):
|
119 |
+
for k, v in kwargs.items():
|
120 |
+
if isinstance(v, torch.Tensor):
|
121 |
+
v = v.item()
|
122 |
+
assert isinstance(v, (float, int))
|
123 |
+
self.meters[k].update(v)
|
124 |
+
|
125 |
+
def __getattr__(self, attr):
|
126 |
+
if attr in self.meters:
|
127 |
+
return self.meters[attr]
|
128 |
+
if attr in self.__dict__:
|
129 |
+
return self.__dict__[attr]
|
130 |
+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
|
131 |
+
|
132 |
+
def __str__(self):
|
133 |
+
loss_str = []
|
134 |
+
for name, meter in self.meters.items():
|
135 |
+
loss_str.append(f"{name}: {str(meter)}")
|
136 |
+
return self.delimiter.join(loss_str)
|
137 |
+
|
138 |
+
def synchronize_between_processes(self):
|
139 |
+
for meter in self.meters.values():
|
140 |
+
meter.synchronize_between_processes()
|
141 |
+
|
142 |
+
def add_meter(self, name, meter):
|
143 |
+
self.meters[name] = meter
|
144 |
+
|
145 |
+
def log_every(self, iterable, print_freq, header=None):
|
146 |
+
i = 0
|
147 |
+
if not header:
|
148 |
+
header = ""
|
149 |
+
start_time = time.time()
|
150 |
+
end = time.time()
|
151 |
+
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
152 |
+
data_time = SmoothedValue(fmt="{avg:.4f}")
|
153 |
+
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
154 |
+
if torch.cuda.is_available():
|
155 |
+
log_msg = self.delimiter.join(
|
156 |
+
[
|
157 |
+
header,
|
158 |
+
"[{0" + space_fmt + "}/{1}]",
|
159 |
+
"eta: {eta}",
|
160 |
+
"{meters}",
|
161 |
+
"time: {time}",
|
162 |
+
"data: {data}",
|
163 |
+
"max mem: {memory:.0f}",
|
164 |
+
]
|
165 |
+
)
|
166 |
+
else:
|
167 |
+
log_msg = self.delimiter.join(
|
168 |
+
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
|
169 |
+
)
|
170 |
+
MB = 1024.0 * 1024.0
|
171 |
+
for obj in iterable:
|
172 |
+
data_time.update(time.time() - end)
|
173 |
+
yield obj
|
174 |
+
iter_time.update(time.time() - end)
|
175 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
176 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
177 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
178 |
+
if torch.cuda.is_available():
|
179 |
+
print(
|
180 |
+
log_msg.format(
|
181 |
+
i,
|
182 |
+
len(iterable),
|
183 |
+
eta=eta_string,
|
184 |
+
meters=str(self),
|
185 |
+
time=str(iter_time),
|
186 |
+
data=str(data_time),
|
187 |
+
memory=torch.cuda.max_memory_allocated() / MB,
|
188 |
+
)
|
189 |
+
)
|
190 |
+
else:
|
191 |
+
print(
|
192 |
+
log_msg.format(
|
193 |
+
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
|
194 |
+
)
|
195 |
+
)
|
196 |
+
i += 1
|
197 |
+
end = time.time()
|
198 |
+
total_time = time.time() - start_time
|
199 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
200 |
+
print(f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)")
|
201 |
+
|
202 |
+
|
203 |
+
def collate_fn(batch):
|
204 |
+
return tuple(zip(*batch))
|
205 |
+
|
206 |
+
|
207 |
+
def mkdir(path):
|
208 |
+
try:
|
209 |
+
os.makedirs(path)
|
210 |
+
except OSError as e:
|
211 |
+
if e.errno != errno.EEXIST:
|
212 |
+
raise
|
213 |
+
|
214 |
+
|
215 |
+
def setup_for_distributed(is_master):
|
216 |
+
"""
|
217 |
+
This function disables printing when not in master process
|
218 |
+
"""
|
219 |
+
import builtins as __builtin__
|
220 |
+
|
221 |
+
builtin_print = __builtin__.print
|
222 |
+
|
223 |
+
def print(*args, **kwargs):
|
224 |
+
force = kwargs.pop("force", False)
|
225 |
+
if is_master or force:
|
226 |
+
builtin_print(*args, **kwargs)
|
227 |
+
|
228 |
+
__builtin__.print = print
|
229 |
+
|
230 |
+
|
231 |
+
def is_dist_avail_and_initialized():
|
232 |
+
if not dist.is_available():
|
233 |
+
return False
|
234 |
+
if not dist.is_initialized():
|
235 |
+
return False
|
236 |
+
return True
|
237 |
+
|
238 |
+
|
239 |
+
def get_world_size():
|
240 |
+
if not is_dist_avail_and_initialized():
|
241 |
+
return 1
|
242 |
+
return dist.get_world_size()
|
243 |
+
|
244 |
+
|
245 |
+
def get_rank():
|
246 |
+
if not is_dist_avail_and_initialized():
|
247 |
+
return 0
|
248 |
+
return dist.get_rank()
|
249 |
+
|
250 |
+
|
251 |
+
def is_main_process():
|
252 |
+
return get_rank() == 0
|
253 |
+
|
254 |
+
|
255 |
+
def save_on_master(*args, **kwargs):
|
256 |
+
if is_main_process():
|
257 |
+
torch.save(*args, **kwargs)
|
258 |
+
|
259 |
+
|
260 |
+
def init_distributed_mode(args):
|
261 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
262 |
+
args.rank = int(os.environ["RANK"])
|
263 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
264 |
+
args.gpu = int(os.environ["LOCAL_RANK"])
|
265 |
+
elif "SLURM_PROCID" in os.environ:
|
266 |
+
args.rank = int(os.environ["SLURM_PROCID"])
|
267 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
268 |
+
else:
|
269 |
+
print("Not using distributed mode")
|
270 |
+
args.distributed = False
|
271 |
+
return
|
272 |
+
|
273 |
+
args.distributed = True
|
274 |
+
|
275 |
+
torch.cuda.set_device(args.gpu)
|
276 |
+
args.dist_backend = "nccl"
|
277 |
+
print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
|
278 |
+
torch.distributed.init_process_group(
|
279 |
+
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
280 |
+
)
|
281 |
+
torch.distributed.barrier()
|
282 |
+
setup_for_distributed(args.rank == 0)
|
DenseMammogram/ensemble_boxes/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
__author__ = 'ZFTurbo: https://kaggle.com/zfturbo'
|
3 |
+
|
4 |
+
from .ensemble_boxes_wbf import weighted_boxes_fusion
|
5 |
+
from .ensemble_boxes_nmw import non_maximum_weighted
|
6 |
+
from .ensemble_boxes_nms import nms_method
|
7 |
+
from .ensemble_boxes_nms import nms
|
8 |
+
from .ensemble_boxes_nms import soft_nms
|
9 |
+
from .ensemble_boxes_wbf_3d import weighted_boxes_fusion_3d
|
DenseMammogram/ensemble_boxes/ensemble_boxes_nms.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
__author__ = 'ZFTurbo: https://kaggle.com/zfturbo'
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from numba import jit
|
6 |
+
|
7 |
+
|
8 |
+
def prepare_boxes(boxes, scores, labels):
|
9 |
+
result_boxes = boxes.copy()
|
10 |
+
|
11 |
+
cond = (result_boxes < 0)
|
12 |
+
cond_sum = cond.astype(np.int32).sum()
|
13 |
+
if cond_sum > 0:
|
14 |
+
print('Warning. Fixed {} boxes coordinates < 0'.format(cond_sum))
|
15 |
+
result_boxes[cond] = 0
|
16 |
+
|
17 |
+
cond = (result_boxes > 1)
|
18 |
+
cond_sum = cond.astype(np.int32).sum()
|
19 |
+
if cond_sum > 0:
|
20 |
+
print('Warning. Fixed {} boxes coordinates > 1. Check that your boxes was normalized at [0, 1]'.format(cond_sum))
|
21 |
+
result_boxes[cond] = 1
|
22 |
+
|
23 |
+
boxes1 = result_boxes.copy()
|
24 |
+
result_boxes[:, 0] = np.min(boxes1[:, [0, 2]], axis=1)
|
25 |
+
result_boxes[:, 2] = np.max(boxes1[:, [0, 2]], axis=1)
|
26 |
+
result_boxes[:, 1] = np.min(boxes1[:, [1, 3]], axis=1)
|
27 |
+
result_boxes[:, 3] = np.max(boxes1[:, [1, 3]], axis=1)
|
28 |
+
|
29 |
+
area = (result_boxes[:, 2] - result_boxes[:, 0]) * (result_boxes[:, 3] - result_boxes[:, 1])
|
30 |
+
cond = (area == 0)
|
31 |
+
cond_sum = cond.astype(np.int32).sum()
|
32 |
+
if cond_sum > 0:
|
33 |
+
print('Warning. Removed {} boxes with zero area!'.format(cond_sum))
|
34 |
+
result_boxes = result_boxes[area > 0]
|
35 |
+
scores = scores[area > 0]
|
36 |
+
labels = labels[area > 0]
|
37 |
+
|
38 |
+
return result_boxes, scores, labels
|
39 |
+
|
40 |
+
|
41 |
+
def cpu_soft_nms_float(dets, sc, Nt, sigma, thresh, method):
|
42 |
+
"""
|
43 |
+
Based on: https://github.com/DocF/Soft-NMS/blob/master/soft_nms.py
|
44 |
+
It's different from original soft-NMS because we have float coordinates on range [0; 1]
|
45 |
+
|
46 |
+
:param dets: boxes format [x1, y1, x2, y2]
|
47 |
+
:param sc: scores for boxes
|
48 |
+
:param Nt: required iou
|
49 |
+
:param sigma:
|
50 |
+
:param thresh:
|
51 |
+
:param method: 1 - linear soft-NMS, 2 - gaussian soft-NMS, 3 - standard NMS
|
52 |
+
:return: index of boxes to keep
|
53 |
+
"""
|
54 |
+
|
55 |
+
# indexes concatenate boxes with the last column
|
56 |
+
N = dets.shape[0]
|
57 |
+
indexes = np.array([np.arange(N)])
|
58 |
+
dets = np.concatenate((dets, indexes.T), axis=1)
|
59 |
+
|
60 |
+
# the order of boxes coordinate is [y1, x1, y2, x2]
|
61 |
+
y1 = dets[:, 1]
|
62 |
+
x1 = dets[:, 0]
|
63 |
+
y2 = dets[:, 3]
|
64 |
+
x2 = dets[:, 2]
|
65 |
+
scores = sc
|
66 |
+
areas = (x2 - x1) * (y2 - y1)
|
67 |
+
|
68 |
+
for i in range(N):
|
69 |
+
# intermediate parameters for later parameters exchange
|
70 |
+
tBD = dets[i, :].copy()
|
71 |
+
tscore = scores[i].copy()
|
72 |
+
tarea = areas[i].copy()
|
73 |
+
pos = i + 1
|
74 |
+
|
75 |
+
#
|
76 |
+
if i != N - 1:
|
77 |
+
maxscore = np.max(scores[pos:], axis=0)
|
78 |
+
maxpos = np.argmax(scores[pos:], axis=0)
|
79 |
+
else:
|
80 |
+
maxscore = scores[-1]
|
81 |
+
maxpos = 0
|
82 |
+
if tscore < maxscore:
|
83 |
+
dets[i, :] = dets[maxpos + i + 1, :]
|
84 |
+
dets[maxpos + i + 1, :] = tBD
|
85 |
+
tBD = dets[i, :]
|
86 |
+
|
87 |
+
scores[i] = scores[maxpos + i + 1]
|
88 |
+
scores[maxpos + i + 1] = tscore
|
89 |
+
tscore = scores[i]
|
90 |
+
|
91 |
+
areas[i] = areas[maxpos + i + 1]
|
92 |
+
areas[maxpos + i + 1] = tarea
|
93 |
+
tarea = areas[i]
|
94 |
+
|
95 |
+
# IoU calculate
|
96 |
+
xx1 = np.maximum(dets[i, 1], dets[pos:, 1])
|
97 |
+
yy1 = np.maximum(dets[i, 0], dets[pos:, 0])
|
98 |
+
xx2 = np.minimum(dets[i, 3], dets[pos:, 3])
|
99 |
+
yy2 = np.minimum(dets[i, 2], dets[pos:, 2])
|
100 |
+
|
101 |
+
w = np.maximum(0.0, xx2 - xx1)
|
102 |
+
h = np.maximum(0.0, yy2 - yy1)
|
103 |
+
inter = w * h
|
104 |
+
ovr = inter / (areas[i] + areas[pos:] - inter)
|
105 |
+
|
106 |
+
# Three methods: 1.linear 2.gaussian 3.original NMS
|
107 |
+
if method == 1: # linear
|
108 |
+
weight = np.ones(ovr.shape)
|
109 |
+
weight[ovr > Nt] = weight[ovr > Nt] - ovr[ovr > Nt]
|
110 |
+
elif method == 2: # gaussian
|
111 |
+
weight = np.exp(-(ovr * ovr) / sigma)
|
112 |
+
else: # original NMS
|
113 |
+
weight = np.ones(ovr.shape)
|
114 |
+
weight[ovr > Nt] = 0
|
115 |
+
|
116 |
+
scores[pos:] = weight * scores[pos:]
|
117 |
+
|
118 |
+
# select the boxes and keep the corresponding indexes
|
119 |
+
inds = dets[:, 4][scores > thresh]
|
120 |
+
keep = inds.astype(int)
|
121 |
+
return keep
|
122 |
+
|
123 |
+
|
124 |
+
@jit(nopython=True)
|
125 |
+
def nms_float_fast(dets, scores, thresh):
|
126 |
+
"""
|
127 |
+
# It's different from original nms because we have float coordinates on range [0; 1]
|
128 |
+
:param dets: numpy array of boxes with shape: (N, 5). Order: x1, y1, x2, y2, score. All variables in range [0; 1]
|
129 |
+
:param thresh: IoU value for boxes
|
130 |
+
:return: index of boxes to keep
|
131 |
+
"""
|
132 |
+
x1 = dets[:, 0]
|
133 |
+
y1 = dets[:, 1]
|
134 |
+
x2 = dets[:, 2]
|
135 |
+
y2 = dets[:, 3]
|
136 |
+
|
137 |
+
areas = (x2 - x1) * (y2 - y1)
|
138 |
+
order = scores.argsort()[::-1]
|
139 |
+
|
140 |
+
keep = []
|
141 |
+
while order.size > 0:
|
142 |
+
i = order[0]
|
143 |
+
keep.append(i)
|
144 |
+
xx1 = np.maximum(x1[i], x1[order[1:]])
|
145 |
+
yy1 = np.maximum(y1[i], y1[order[1:]])
|
146 |
+
xx2 = np.minimum(x2[i], x2[order[1:]])
|
147 |
+
yy2 = np.minimum(y2[i], y2[order[1:]])
|
148 |
+
|
149 |
+
w = np.maximum(0.0, xx2 - xx1)
|
150 |
+
h = np.maximum(0.0, yy2 - yy1)
|
151 |
+
inter = w * h
|
152 |
+
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
153 |
+
inds = np.where(ovr <= thresh)[0]
|
154 |
+
order = order[inds + 1]
|
155 |
+
|
156 |
+
return keep
|
157 |
+
|
158 |
+
|
159 |
+
def nms_method(boxes, scores, labels, method=3, iou_thr=0.5, sigma=0.5, thresh=0.001, weights=None):
|
160 |
+
"""
|
161 |
+
:param boxes: list of boxes predictions from each model, each box is 4 numbers.
|
162 |
+
It has 3 dimensions (models_number, model_preds, 4)
|
163 |
+
Order of boxes: x1, y1, x2, y2. We expect float normalized coordinates [0; 1]
|
164 |
+
:param scores: list of scores for each model
|
165 |
+
:param labels: list of labels for each model
|
166 |
+
:param method: 1 - linear soft-NMS, 2 - gaussian soft-NMS, 3 - standard NMS
|
167 |
+
:param iou_thr: IoU value for boxes to be a match
|
168 |
+
:param sigma: Sigma value for SoftNMS
|
169 |
+
:param thresh: threshold for boxes to keep (important for SoftNMS)
|
170 |
+
:param weights: list of weights for each model. Default: None, which means weight == 1 for each model
|
171 |
+
|
172 |
+
:return: boxes: boxes coordinates (Order of boxes: x1, y1, x2, y2).
|
173 |
+
:return: scores: confidence scores
|
174 |
+
:return: labels: boxes labels
|
175 |
+
"""
|
176 |
+
|
177 |
+
# If weights are specified
|
178 |
+
if weights is not None:
|
179 |
+
if len(boxes) != len(weights):
|
180 |
+
print('Incorrect number of weights: {}. Must be: {}. Skip it'.format(len(weights), len(boxes)))
|
181 |
+
else:
|
182 |
+
weights = np.array(weights)
|
183 |
+
for i in range(len(weights)):
|
184 |
+
scores[i] = (np.array(scores[i]) * weights[i]) / weights.sum()
|
185 |
+
|
186 |
+
# We concatenate everything
|
187 |
+
boxes = np.concatenate(boxes)
|
188 |
+
scores = np.concatenate(scores)
|
189 |
+
labels = np.concatenate(labels)
|
190 |
+
|
191 |
+
# Fix coordinates and removed zero area boxes
|
192 |
+
boxes, scores, labels = prepare_boxes(boxes, scores, labels)
|
193 |
+
|
194 |
+
# Run NMS independently for each label
|
195 |
+
unique_labels = np.unique(labels)
|
196 |
+
final_boxes = []
|
197 |
+
final_scores = []
|
198 |
+
final_labels = []
|
199 |
+
for l in unique_labels:
|
200 |
+
condition = (labels == l)
|
201 |
+
boxes_by_label = boxes[condition]
|
202 |
+
scores_by_label = scores[condition]
|
203 |
+
labels_by_label = np.array([l] * len(boxes_by_label))
|
204 |
+
|
205 |
+
if method != 3:
|
206 |
+
keep = cpu_soft_nms_float(boxes_by_label.copy(), scores_by_label.copy(), Nt=iou_thr, sigma=sigma, thresh=thresh, method=method)
|
207 |
+
else:
|
208 |
+
# Use faster function
|
209 |
+
keep = nms_float_fast(boxes_by_label, scores_by_label, thresh=iou_thr)
|
210 |
+
|
211 |
+
final_boxes.append(boxes_by_label[keep])
|
212 |
+
final_scores.append(scores_by_label[keep])
|
213 |
+
final_labels.append(labels_by_label[keep])
|
214 |
+
final_boxes = np.concatenate(final_boxes)
|
215 |
+
final_scores = np.concatenate(final_scores)
|
216 |
+
final_labels = np.concatenate(final_labels)
|
217 |
+
|
218 |
+
return final_boxes, final_scores, final_labels
|
219 |
+
|
220 |
+
|
221 |
+
def nms(boxes, scores, labels, iou_thr=0.5, weights=None):
|
222 |
+
"""
|
223 |
+
Short call for standard NMS
|
224 |
+
|
225 |
+
:param boxes:
|
226 |
+
:param scores:
|
227 |
+
:param labels:
|
228 |
+
:param iou_thr:
|
229 |
+
:param weights:
|
230 |
+
:return:
|
231 |
+
"""
|
232 |
+
return nms_method(boxes, scores, labels, method=3, iou_thr=iou_thr, weights=weights)
|
233 |
+
|
234 |
+
|
235 |
+
def soft_nms(boxes, scores, labels, method=2, iou_thr=0.5, sigma=0.5, thresh=0.001, weights=None):
|
236 |
+
"""
|
237 |
+
Short call for Soft-NMS
|
238 |
+
|
239 |
+
:param boxes:
|
240 |
+
:param scores:
|
241 |
+
:param labels:
|
242 |
+
:param method:
|
243 |
+
:param iou_thr:
|
244 |
+
:param sigma:
|
245 |
+
:param thresh:
|
246 |
+
:param weights:
|
247 |
+
:return:
|
248 |
+
"""
|
249 |
+
return nms_method(boxes, scores, labels, method=method, iou_thr=iou_thr, sigma=sigma, thresh=thresh, weights=weights)
|
DenseMammogram/ensemble_boxes/ensemble_boxes_nmw.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
__author__ = 'ZFTurbo: https://kaggle.com/zfturbo'
|
3 |
+
|
4 |
+
"""
|
5 |
+
Method described in:
|
6 |
+
CAD: Scale Invariant Framework for Real-Time Object Detection
|
7 |
+
http://openaccess.thecvf.com/content_ICCV_2017_workshops/papers/w14/Zhou_CAD_Scale_Invariant_ICCV_2017_paper.pdf
|
8 |
+
"""
|
9 |
+
|
10 |
+
import warnings
|
11 |
+
import numpy as np
|
12 |
+
from numba import jit
|
13 |
+
|
14 |
+
|
15 |
+
@jit(nopython=True)
|
16 |
+
def bb_intersection_over_union(A, B):
|
17 |
+
xA = max(A[0], B[0])
|
18 |
+
yA = max(A[1], B[1])
|
19 |
+
xB = min(A[2], B[2])
|
20 |
+
yB = min(A[3], B[3])
|
21 |
+
|
22 |
+
# compute the area of intersection rectangle
|
23 |
+
interArea = max(0, xB - xA) * max(0, yB - yA)
|
24 |
+
|
25 |
+
if interArea == 0:
|
26 |
+
return 0.0
|
27 |
+
|
28 |
+
# compute the area of both the prediction and ground-truth rectangles
|
29 |
+
boxAArea = (A[2] - A[0]) * (A[3] - A[1])
|
30 |
+
boxBArea = (B[2] - B[0]) * (B[3] - B[1])
|
31 |
+
|
32 |
+
iou = interArea / float(boxAArea + boxBArea - interArea)
|
33 |
+
return iou
|
34 |
+
|
35 |
+
|
36 |
+
def prefilter_boxes(boxes, scores, labels, weights, thr):
|
37 |
+
# Create dict with boxes stored by its label
|
38 |
+
new_boxes = dict()
|
39 |
+
for t in range(len(boxes)):
|
40 |
+
|
41 |
+
if len(boxes[t]) != len(scores[t]):
|
42 |
+
print('Error. Length of boxes arrays not equal to length of scores array: {} != {}'.format(len(boxes[t]),
|
43 |
+
len(scores[t])))
|
44 |
+
exit()
|
45 |
+
|
46 |
+
if len(boxes[t]) != len(labels[t]):
|
47 |
+
print('Error. Length of boxes arrays not equal to length of labels array: {} != {}'.format(len(boxes[t]),
|
48 |
+
len(labels[t])))
|
49 |
+
exit()
|
50 |
+
|
51 |
+
for j in range(len(boxes[t])):
|
52 |
+
score = scores[t][j]
|
53 |
+
if score < thr:
|
54 |
+
continue
|
55 |
+
label = int(labels[t][j])
|
56 |
+
box_part = boxes[t][j]
|
57 |
+
x1 = float(box_part[0])
|
58 |
+
y1 = float(box_part[1])
|
59 |
+
x2 = float(box_part[2])
|
60 |
+
y2 = float(box_part[3])
|
61 |
+
|
62 |
+
# Box data checks
|
63 |
+
if x2 < x1:
|
64 |
+
warnings.warn('X2 < X1 value in box. Swap them.')
|
65 |
+
x1, x2 = x2, x1
|
66 |
+
if y2 < y1:
|
67 |
+
warnings.warn('Y2 < Y1 value in box. Swap them.')
|
68 |
+
y1, y2 = y2, y1
|
69 |
+
if x1 < 0:
|
70 |
+
warnings.warn('X1 < 0 in box. Set it to 0.')
|
71 |
+
x1 = 0
|
72 |
+
if x1 > 1:
|
73 |
+
warnings.warn('X1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
|
74 |
+
x1 = 1
|
75 |
+
if x2 < 0:
|
76 |
+
warnings.warn('X2 < 0 in box. Set it to 0.')
|
77 |
+
x2 = 0
|
78 |
+
if x2 > 1:
|
79 |
+
warnings.warn('X2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
|
80 |
+
x2 = 1
|
81 |
+
if y1 < 0:
|
82 |
+
warnings.warn('Y1 < 0 in box. Set it to 0.')
|
83 |
+
y1 = 0
|
84 |
+
if y1 > 1:
|
85 |
+
warnings.warn('Y1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
|
86 |
+
y1 = 1
|
87 |
+
if y2 < 0:
|
88 |
+
warnings.warn('Y2 < 0 in box. Set it to 0.')
|
89 |
+
y2 = 0
|
90 |
+
if y2 > 1:
|
91 |
+
warnings.warn('Y2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
|
92 |
+
y2 = 1
|
93 |
+
if (x2 - x1) * (y2 - y1) == 0.0:
|
94 |
+
warnings.warn("Zero area box skipped: {}.".format(box_part))
|
95 |
+
continue
|
96 |
+
|
97 |
+
b = [int(label), float(score) * weights[t], x1, y1, x2, y2]
|
98 |
+
if label not in new_boxes:
|
99 |
+
new_boxes[label] = []
|
100 |
+
new_boxes[label].append(b)
|
101 |
+
|
102 |
+
# Sort each list in dict by score and transform it to numpy array
|
103 |
+
for k in new_boxes:
|
104 |
+
current_boxes = np.array(new_boxes[k])
|
105 |
+
new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]]
|
106 |
+
|
107 |
+
return new_boxes
|
108 |
+
|
109 |
+
|
110 |
+
def get_weighted_box(boxes):
|
111 |
+
"""
|
112 |
+
Create weighted box for set of boxes
|
113 |
+
:param boxes: set of boxes to fuse
|
114 |
+
:return: weighted box
|
115 |
+
"""
|
116 |
+
|
117 |
+
box = np.zeros(6, dtype=np.float32)
|
118 |
+
best_box = boxes[0]
|
119 |
+
conf = 0
|
120 |
+
for b in boxes:
|
121 |
+
iou = bb_intersection_over_union(b[2:], best_box[2:])
|
122 |
+
weight = b[1] * iou
|
123 |
+
box[2:] += (weight * b[2:])
|
124 |
+
conf += weight
|
125 |
+
box[0] = best_box[0]
|
126 |
+
box[1] = best_box[1]
|
127 |
+
box[2:] /= conf
|
128 |
+
return box
|
129 |
+
|
130 |
+
|
131 |
+
def find_matching_box(boxes_list, new_box, match_iou):
|
132 |
+
best_iou = match_iou
|
133 |
+
best_index = -1
|
134 |
+
for i in range(len(boxes_list)):
|
135 |
+
box = boxes_list[i]
|
136 |
+
if box[0] != new_box[0]:
|
137 |
+
continue
|
138 |
+
iou = bb_intersection_over_union(box[2:], new_box[2:])
|
139 |
+
if iou > best_iou:
|
140 |
+
best_index = i
|
141 |
+
best_iou = iou
|
142 |
+
|
143 |
+
return best_index, best_iou
|
144 |
+
|
145 |
+
|
146 |
+
def non_maximum_weighted(boxes_list, scores_list, labels_list, weights=None, iou_thr=0.55, skip_box_thr=0.0):
|
147 |
+
'''
|
148 |
+
:param boxes_list: list of boxes predictions from each model, each box is 4 numbers.
|
149 |
+
It has 3 dimensions (models_number, model_preds, 4)
|
150 |
+
Order of boxes: x1, y1, x2, y2. We expect float normalized coordinates [0; 1]
|
151 |
+
:param scores_list: list of scores for each model
|
152 |
+
:param labels_list: list of labels for each model
|
153 |
+
:param weights: list of weights for each model. Default: None, which means weight == 1 for each model
|
154 |
+
:param iou_thr: IoU value for boxes to be a match
|
155 |
+
:param skip_box_thr: exclude boxes with score lower than this variable
|
156 |
+
|
157 |
+
:return: boxes: boxes coordinates (Order of boxes: x1, y1, x2, y2).
|
158 |
+
:return: scores: confidence scores
|
159 |
+
:return: labels: boxes labels
|
160 |
+
'''
|
161 |
+
|
162 |
+
if weights is None:
|
163 |
+
weights = np.ones(len(boxes_list))
|
164 |
+
if len(weights) != len(boxes_list):
|
165 |
+
print('Warning: incorrect number of weights {}. Must be: {}. Set weights equal to 1.'.format(len(weights), len(boxes_list)))
|
166 |
+
weights = np.ones(len(boxes_list))
|
167 |
+
weights = np.array(weights) / max(weights)
|
168 |
+
# for i in range(len(weights)):
|
169 |
+
# scores_list[i] = (np.array(scores_list[i]) * weights[i])
|
170 |
+
|
171 |
+
filtered_boxes = prefilter_boxes(boxes_list, scores_list, labels_list, weights, skip_box_thr)
|
172 |
+
if len(filtered_boxes) == 0:
|
173 |
+
return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,))
|
174 |
+
|
175 |
+
overall_boxes = []
|
176 |
+
for label in filtered_boxes:
|
177 |
+
boxes = filtered_boxes[label]
|
178 |
+
new_boxes = []
|
179 |
+
main_boxes = []
|
180 |
+
|
181 |
+
# Clusterize boxes
|
182 |
+
for j in range(0, len(boxes)):
|
183 |
+
index, best_iou = find_matching_box(main_boxes, boxes[j], iou_thr)
|
184 |
+
if index != -1:
|
185 |
+
new_boxes[index].append(boxes[j].copy())
|
186 |
+
else:
|
187 |
+
new_boxes.append([boxes[j].copy()])
|
188 |
+
main_boxes.append(boxes[j].copy())
|
189 |
+
|
190 |
+
weighted_boxes = []
|
191 |
+
for j in range(0, len(new_boxes)):
|
192 |
+
box = get_weighted_box(new_boxes[j])
|
193 |
+
weighted_boxes.append(box.copy())
|
194 |
+
|
195 |
+
overall_boxes.append(np.array(weighted_boxes))
|
196 |
+
|
197 |
+
overall_boxes = np.concatenate(overall_boxes, axis=0)
|
198 |
+
overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]]
|
199 |
+
boxes = overall_boxes[:, 2:]
|
200 |
+
scores = overall_boxes[:, 1]
|
201 |
+
labels = overall_boxes[:, 0]
|
202 |
+
return boxes, scores, labels
|
DenseMammogram/ensemble_boxes/ensemble_boxes_wbf.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
__author__ = 'ZFTurbo: https://kaggle.com/zfturbo'
|
3 |
+
|
4 |
+
|
5 |
+
import warnings
|
6 |
+
import numpy as np
|
7 |
+
from numba import jit
|
8 |
+
import time
|
9 |
+
|
10 |
+
@jit(nopython=True)
|
11 |
+
def bb_intersection_over_union(A, B) -> float:
|
12 |
+
xA = max(A[0], B[0])
|
13 |
+
yA = max(A[1], B[1])
|
14 |
+
xB = min(A[2], B[2])
|
15 |
+
yB = min(A[3], B[3])
|
16 |
+
|
17 |
+
# compute the area of intersection rectangle
|
18 |
+
interArea = max(0, xB - xA) * max(0, yB - yA)
|
19 |
+
|
20 |
+
if interArea == 0:
|
21 |
+
return 0.0
|
22 |
+
|
23 |
+
# compute the area of both the prediction and ground-truth rectangles
|
24 |
+
boxAArea = (A[2] - A[0]) * (A[3] - A[1])
|
25 |
+
boxBArea = (B[2] - B[0]) * (B[3] - B[1])
|
26 |
+
|
27 |
+
iou = interArea / float(boxAArea + boxBArea - interArea)
|
28 |
+
return iou
|
29 |
+
|
30 |
+
|
31 |
+
def prefilter_boxes(boxes, scores, labels, weights, thr):
|
32 |
+
# Create dict with boxes stored by its label
|
33 |
+
new_boxes = dict()
|
34 |
+
|
35 |
+
for t in range(len(boxes)):
|
36 |
+
|
37 |
+
if len(boxes[t]) != len(scores[t]):
|
38 |
+
print('Error. Length of boxes arrays not equal to length of scores array: {} != {}'.format(len(boxes[t]), len(scores[t])))
|
39 |
+
exit()
|
40 |
+
|
41 |
+
if len(boxes[t]) != len(labels[t]):
|
42 |
+
print('Error. Length of boxes arrays not equal to length of labels array: {} != {}'.format(len(boxes[t]), len(labels[t])))
|
43 |
+
exit()
|
44 |
+
|
45 |
+
for j in range(len(boxes[t])):
|
46 |
+
score = scores[t][j]
|
47 |
+
if score < thr:
|
48 |
+
continue
|
49 |
+
label = int(labels[t][j])
|
50 |
+
box_part = boxes[t][j]
|
51 |
+
x1 = float(box_part[0])
|
52 |
+
y1 = float(box_part[1])
|
53 |
+
x2 = float(box_part[2])
|
54 |
+
y2 = float(box_part[3])
|
55 |
+
|
56 |
+
# Box data checks
|
57 |
+
if x2 < x1:
|
58 |
+
warnings.warn('X2 < X1 value in box. Swap them.')
|
59 |
+
x1, x2 = x2, x1
|
60 |
+
if y2 < y1:
|
61 |
+
warnings.warn('Y2 < Y1 value in box. Swap them.')
|
62 |
+
y1, y2 = y2, y1
|
63 |
+
if x1 < 0:
|
64 |
+
warnings.warn('X1 < 0 in box. Set it to 0.')
|
65 |
+
x1 = 0
|
66 |
+
if x1 > 1:
|
67 |
+
warnings.warn('X1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
|
68 |
+
x1 = 1
|
69 |
+
if x2 < 0:
|
70 |
+
warnings.warn('X2 < 0 in box. Set it to 0.')
|
71 |
+
x2 = 0
|
72 |
+
if x2 > 1:
|
73 |
+
warnings.warn('X2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
|
74 |
+
x2 = 1
|
75 |
+
if y1 < 0:
|
76 |
+
warnings.warn('Y1 < 0 in box. Set it to 0.')
|
77 |
+
y1 = 0
|
78 |
+
if y1 > 1:
|
79 |
+
warnings.warn('Y1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
|
80 |
+
y1 = 1
|
81 |
+
if y2 < 0:
|
82 |
+
warnings.warn('Y2 < 0 in box. Set it to 0.')
|
83 |
+
y2 = 0
|
84 |
+
if y2 > 1:
|
85 |
+
warnings.warn('Y2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
|
86 |
+
y2 = 1
|
87 |
+
if (x2 - x1) * (y2 - y1) == 0.0:
|
88 |
+
warnings.warn("Zero area box skipped: {}.".format(box_part))
|
89 |
+
continue
|
90 |
+
|
91 |
+
# [label, score, weight, model index, x1, y1, x2, y2]
|
92 |
+
b = [int(label), float(score) * weights[t], weights[t], t, x1, y1, x2, y2]
|
93 |
+
if label not in new_boxes:
|
94 |
+
new_boxes[label] = []
|
95 |
+
new_boxes[label].append(b)
|
96 |
+
|
97 |
+
# Sort each list in dict by score and transform it to numpy array
|
98 |
+
for k in new_boxes:
|
99 |
+
current_boxes = np.array(new_boxes[k])
|
100 |
+
new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]]
|
101 |
+
|
102 |
+
return new_boxes
|
103 |
+
|
104 |
+
|
105 |
+
def get_weighted_box(boxes, conf_type='avg'):
|
106 |
+
"""
|
107 |
+
Create weighted box for set of boxes
|
108 |
+
:param boxes: set of boxes to fuse
|
109 |
+
:param conf_type: type of confidence one of 'avg' or 'max'
|
110 |
+
:return: weighted box (label, score, weight, x1, y1, x2, y2)
|
111 |
+
"""
|
112 |
+
|
113 |
+
box = np.zeros(8, dtype=np.float32)
|
114 |
+
conf = 0
|
115 |
+
conf_list = []
|
116 |
+
w = 0
|
117 |
+
for b in boxes:
|
118 |
+
box[4:] += (b[1] * b[4:])
|
119 |
+
conf += b[1]
|
120 |
+
conf_list.append(b[1])
|
121 |
+
w += b[2]
|
122 |
+
box[0] = boxes[0][0]
|
123 |
+
if conf_type == 'avg':
|
124 |
+
box[1] = conf / len(boxes)
|
125 |
+
elif conf_type == 'max':
|
126 |
+
box[1] = np.array(conf_list).max()
|
127 |
+
elif conf_type in ['box_and_model_avg', 'absent_model_aware_avg']:
|
128 |
+
box[1] = conf / len(boxes)
|
129 |
+
box[2] = w
|
130 |
+
box[3] = -1 # model index field is retained for consistensy but is not used.
|
131 |
+
box[4:] /= conf
|
132 |
+
return box
|
133 |
+
|
134 |
+
|
135 |
+
def find_matching_box(boxes_list, new_box, match_iou):
|
136 |
+
best_iou = match_iou
|
137 |
+
best_index = -1
|
138 |
+
for i in range(len(boxes_list)):
|
139 |
+
box = boxes_list[i]
|
140 |
+
if box[0] != new_box[0]:
|
141 |
+
continue
|
142 |
+
iou = bb_intersection_over_union(box[4:], new_box[4:])
|
143 |
+
if iou > best_iou:
|
144 |
+
best_index = i
|
145 |
+
best_iou = iou
|
146 |
+
|
147 |
+
return best_index, best_iou
|
148 |
+
|
149 |
+
|
150 |
+
def find_matching_box_quickly(boxes_list, new_box, match_iou):
|
151 |
+
""" Reimplementation of find_matching_box with numpy instead of loops. Gives significant speed up for larger arrays
|
152 |
+
(~100x). This was previously the bottleneck since the function is called for every entry in the array.
|
153 |
+
"""
|
154 |
+
def bb_iou_array(boxes, new_box):
|
155 |
+
# bb interesection over union
|
156 |
+
xA = np.maximum(boxes[:, 0], new_box[0])
|
157 |
+
yA = np.maximum(boxes[:, 1], new_box[1])
|
158 |
+
xB = np.minimum(boxes[:, 2], new_box[2])
|
159 |
+
yB = np.minimum(boxes[:, 3], new_box[3])
|
160 |
+
|
161 |
+
interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0)
|
162 |
+
|
163 |
+
# compute the area of both the prediction and ground-truth rectangles
|
164 |
+
boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
165 |
+
boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1])
|
166 |
+
|
167 |
+
iou = interArea / (boxAArea + boxBArea - interArea)
|
168 |
+
|
169 |
+
return iou
|
170 |
+
|
171 |
+
if boxes_list.shape[0] == 0:
|
172 |
+
return -1, match_iou
|
173 |
+
|
174 |
+
# boxes = np.array(boxes_list)
|
175 |
+
boxes = boxes_list
|
176 |
+
|
177 |
+
ious = bb_iou_array(boxes[:, 4:], new_box[4:])
|
178 |
+
|
179 |
+
ious[boxes[:, 0] != new_box[0]] = -1
|
180 |
+
|
181 |
+
best_idx = np.argmax(ious)
|
182 |
+
best_iou = ious[best_idx]
|
183 |
+
|
184 |
+
if best_iou <= match_iou:
|
185 |
+
best_iou = match_iou
|
186 |
+
best_idx = -1
|
187 |
+
|
188 |
+
return best_idx, best_iou
|
189 |
+
|
190 |
+
|
191 |
+
def weighted_boxes_fusion(boxes_list, scores_list, labels_list, weights=None, iou_thr=0.55, skip_box_thr=0.0, conf_type='avg', allows_overflow=False):
|
192 |
+
'''
|
193 |
+
:param boxes_list: list of boxes predictions from each model, each box is 4 numbers.
|
194 |
+
It has 3 dimensions (models_number, model_preds, 4)
|
195 |
+
Order of boxes: x1, y1, x2, y2. We expect float normalized coordinates [0; 1]
|
196 |
+
:param scores_list: list of scores for each model
|
197 |
+
:param labels_list: list of labels for each model
|
198 |
+
:param weights: list of weights for each model. Default: None, which means weight == 1 for each model
|
199 |
+
:param iou_thr: IoU value for boxes to be a match
|
200 |
+
:param skip_box_thr: exclude boxes with score lower than this variable
|
201 |
+
:param conf_type: how to calculate confidence in weighted boxes. 'avg': average value, 'max': maximum value, 'box_and_model_avg': box and model wise hybrid weighted average, 'absent_model_aware_avg': weighted average that takes into account the absent model.
|
202 |
+
:param allows_overflow: false if we want confidence score not exceed 1.0
|
203 |
+
|
204 |
+
:return: boxes: boxes coordinates (Order of boxes: x1, y1, x2, y2).
|
205 |
+
:return: scores: confidence scores
|
206 |
+
:return: labels: boxes labels
|
207 |
+
'''
|
208 |
+
|
209 |
+
if weights is None:
|
210 |
+
weights = np.ones(len(boxes_list))
|
211 |
+
if len(weights) != len(boxes_list):
|
212 |
+
print('Warning: incorrect number of weights {}. Must be: {}. Set weights equal to 1.'.format(len(weights), len(boxes_list)))
|
213 |
+
weights = np.ones(len(boxes_list))
|
214 |
+
weights = np.array(weights)
|
215 |
+
|
216 |
+
if conf_type not in ['avg', 'max', 'box_and_model_avg', 'absent_model_aware_avg']:
|
217 |
+
print('Unknown conf_type: {}. Must be "avg", "max" or "box_and_model_avg", or "absent_model_aware_avg"'.format(conf_type))
|
218 |
+
exit()
|
219 |
+
|
220 |
+
filtered_boxes = prefilter_boxes(boxes_list, scores_list, labels_list, weights, skip_box_thr)
|
221 |
+
if len(filtered_boxes) == 0:
|
222 |
+
return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,))
|
223 |
+
|
224 |
+
overall_boxes = []
|
225 |
+
for label in filtered_boxes:
|
226 |
+
boxes = filtered_boxes[label]
|
227 |
+
new_boxes = []
|
228 |
+
weighted_boxes = np.empty((0,8))
|
229 |
+
# Clusterize boxes
|
230 |
+
for j in range(0, len(boxes)):
|
231 |
+
index, best_iou = find_matching_box_quickly(weighted_boxes, boxes[j], iou_thr)
|
232 |
+
|
233 |
+
if index != -1:
|
234 |
+
new_boxes[index].append(boxes[j])
|
235 |
+
weighted_boxes[index] = get_weighted_box(new_boxes[index], conf_type)
|
236 |
+
else:
|
237 |
+
new_boxes.append([boxes[j].copy()])
|
238 |
+
weighted_boxes = np.vstack((weighted_boxes, boxes[j].copy()))
|
239 |
+
# Rescale confidence based on number of models and boxes
|
240 |
+
for i in range(len(new_boxes)):
|
241 |
+
clustered_boxes = np.array(new_boxes[i])
|
242 |
+
if conf_type == 'box_and_model_avg':
|
243 |
+
# weighted average for boxes
|
244 |
+
weighted_boxes[i, 1] = weighted_boxes[i, 1] * len(clustered_boxes) / weighted_boxes[i, 2]
|
245 |
+
# identify unique model index by model index column
|
246 |
+
_, idx = np.unique(clustered_boxes[:, 3], return_index=True)
|
247 |
+
# rescale by unique model weights
|
248 |
+
weighted_boxes[i, 1] = weighted_boxes[i, 1] * clustered_boxes[idx, 2].sum() / weights.sum()
|
249 |
+
elif conf_type == 'absent_model_aware_avg':
|
250 |
+
# get unique model index in the cluster
|
251 |
+
models = np.unique(clustered_boxes[:, 3]).astype(int)
|
252 |
+
# create a mask to get unused model weights
|
253 |
+
mask = np.ones(len(weights), dtype=bool)
|
254 |
+
mask[models] = False
|
255 |
+
# absent model aware weighted average
|
256 |
+
weighted_boxes[i, 1] = weighted_boxes[i, 1] * len(clustered_boxes) / (weighted_boxes[i, 2] + weights[mask].sum())
|
257 |
+
elif conf_type == 'max':
|
258 |
+
weighted_boxes[i, 1] = weighted_boxes[i, 1] / weights.max()
|
259 |
+
elif not allows_overflow:
|
260 |
+
weighted_boxes[i, 1] = weighted_boxes[i, 1] * min(len(weights), len(clustered_boxes)) / weights.sum()
|
261 |
+
else:
|
262 |
+
weighted_boxes[i, 1] = weighted_boxes[i, 1] * len(clustered_boxes) / weights.sum()
|
263 |
+
overall_boxes.append(weighted_boxes)
|
264 |
+
overall_boxes = np.concatenate(overall_boxes, axis=0)
|
265 |
+
overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]]
|
266 |
+
boxes = overall_boxes[:, 4:]
|
267 |
+
scores = overall_boxes[:, 1]
|
268 |
+
labels = overall_boxes[:, 0]
|
269 |
+
return boxes, scores, labels
|
DenseMammogram/ensemble_boxes/ensemble_boxes_wbf_3d.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
__author__ = 'ZFTurbo: https://kaggle.com/zfturbo'
|
3 |
+
|
4 |
+
|
5 |
+
import warnings
|
6 |
+
import numpy as np
|
7 |
+
from numba import jit
|
8 |
+
|
9 |
+
|
10 |
+
@jit(nopython=True)
|
11 |
+
def bb_intersection_over_union_3d(A, B) -> float:
|
12 |
+
xA = max(A[0], B[0])
|
13 |
+
yA = max(A[1], B[1])
|
14 |
+
zA = max(A[2], B[2])
|
15 |
+
xB = min(A[3], B[3])
|
16 |
+
yB = min(A[4], B[4])
|
17 |
+
zB = min(A[5], B[5])
|
18 |
+
|
19 |
+
interVol = max(0, xB - xA) * max(0, yB - yA) * max(0, zB - zA)
|
20 |
+
if interVol == 0:
|
21 |
+
return 0.0
|
22 |
+
|
23 |
+
# compute the volume of both the prediction and ground-truth rectangular boxes
|
24 |
+
boxAVol = (A[3] - A[0]) * (A[4] - A[1]) * (A[5] - A[2])
|
25 |
+
boxBVol = (B[3] - B[0]) * (B[4] - B[1]) * (B[5] - B[2])
|
26 |
+
|
27 |
+
iou = interVol / float(boxAVol + boxBVol - interVol)
|
28 |
+
return iou
|
29 |
+
|
30 |
+
|
31 |
+
def prefilter_boxes(boxes, scores, labels, weights, thr):
|
32 |
+
# Create dict with boxes stored by its label
|
33 |
+
new_boxes = dict()
|
34 |
+
|
35 |
+
for t in range(len(boxes)):
|
36 |
+
|
37 |
+
if len(boxes[t]) != len(scores[t]):
|
38 |
+
print('Error. Length of boxes arrays not equal to length of scores array: {} != {}'.format(len(boxes[t]), len(scores[t])))
|
39 |
+
exit()
|
40 |
+
|
41 |
+
if len(boxes[t]) != len(labels[t]):
|
42 |
+
print('Error. Length of boxes arrays not equal to length of labels array: {} != {}'.format(len(boxes[t]), len(labels[t])))
|
43 |
+
exit()
|
44 |
+
|
45 |
+
for j in range(len(boxes[t])):
|
46 |
+
score = scores[t][j]
|
47 |
+
if score < thr:
|
48 |
+
continue
|
49 |
+
label = int(labels[t][j])
|
50 |
+
box_part = boxes[t][j]
|
51 |
+
x1 = float(box_part[0])
|
52 |
+
y1 = float(box_part[1])
|
53 |
+
z1 = float(box_part[2])
|
54 |
+
x2 = float(box_part[3])
|
55 |
+
y2 = float(box_part[4])
|
56 |
+
z2 = float(box_part[5])
|
57 |
+
|
58 |
+
# Box data checks
|
59 |
+
if x2 < x1:
|
60 |
+
warnings.warn('X2 < X1 value in box. Swap them.')
|
61 |
+
x1, x2 = x2, x1
|
62 |
+
if y2 < y1:
|
63 |
+
warnings.warn('Y2 < Y1 value in box. Swap them.')
|
64 |
+
y1, y2 = y2, y1
|
65 |
+
if z2 < z1:
|
66 |
+
warnings.warn('Z2 < Z1 value in box. Swap them.')
|
67 |
+
z1, z2 = z2, z1
|
68 |
+
if x1 < 0:
|
69 |
+
warnings.warn('X1 < 0 in box. Set it to 0.')
|
70 |
+
x1 = 0
|
71 |
+
if x1 > 1:
|
72 |
+
warnings.warn('X1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
|
73 |
+
x1 = 1
|
74 |
+
if x2 < 0:
|
75 |
+
warnings.warn('X2 < 0 in box. Set it to 0.')
|
76 |
+
x2 = 0
|
77 |
+
if x2 > 1:
|
78 |
+
warnings.warn('X2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
|
79 |
+
x2 = 1
|
80 |
+
if y1 < 0:
|
81 |
+
warnings.warn('Y1 < 0 in box. Set it to 0.')
|
82 |
+
y1 = 0
|
83 |
+
if y1 > 1:
|
84 |
+
warnings.warn('Y1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
|
85 |
+
y1 = 1
|
86 |
+
if y2 < 0:
|
87 |
+
warnings.warn('Y2 < 0 in box. Set it to 0.')
|
88 |
+
y2 = 0
|
89 |
+
if y2 > 1:
|
90 |
+
warnings.warn('Y2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
|
91 |
+
y2 = 1
|
92 |
+
if z1 < 0:
|
93 |
+
warnings.warn('Z1 < 0 in box. Set it to 0.')
|
94 |
+
z1 = 0
|
95 |
+
if z1 > 1:
|
96 |
+
warnings.warn('Z1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
|
97 |
+
z1 = 1
|
98 |
+
if z2 < 0:
|
99 |
+
warnings.warn('Z2 < 0 in box. Set it to 0.')
|
100 |
+
z2 = 0
|
101 |
+
if z2 > 1:
|
102 |
+
warnings.warn('Z2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
|
103 |
+
z2 = 1
|
104 |
+
if (x2 - x1) * (y2 - y1) * (z2 - z1) == 0.0:
|
105 |
+
warnings.warn("Zero volume box skipped: {}.".format(box_part))
|
106 |
+
continue
|
107 |
+
|
108 |
+
b = [int(label), float(score) * weights[t], x1, y1, z1, x2, y2, z2]
|
109 |
+
if label not in new_boxes:
|
110 |
+
new_boxes[label] = []
|
111 |
+
new_boxes[label].append(b)
|
112 |
+
|
113 |
+
# Sort each list in dict by score and transform it to numpy array
|
114 |
+
for k in new_boxes:
|
115 |
+
current_boxes = np.array(new_boxes[k])
|
116 |
+
new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]]
|
117 |
+
|
118 |
+
return new_boxes
|
119 |
+
|
120 |
+
|
121 |
+
def get_weighted_box(boxes, conf_type='avg'):
|
122 |
+
"""
|
123 |
+
Create weighted box for set of boxes
|
124 |
+
:param boxes: set of boxes to fuse
|
125 |
+
:param conf_type: type of confidence one of 'avg' or 'max'
|
126 |
+
:return: weighted box
|
127 |
+
"""
|
128 |
+
|
129 |
+
box = np.zeros(8, dtype=np.float32)
|
130 |
+
conf = 0
|
131 |
+
conf_list = []
|
132 |
+
for b in boxes:
|
133 |
+
box[2:] += (b[1] * b[2:])
|
134 |
+
conf += b[1]
|
135 |
+
conf_list.append(b[1])
|
136 |
+
box[0] = boxes[0][0]
|
137 |
+
if conf_type == 'avg':
|
138 |
+
box[1] = conf / len(boxes)
|
139 |
+
elif conf_type == 'max':
|
140 |
+
box[1] = np.array(conf_list).max()
|
141 |
+
box[2:] /= conf
|
142 |
+
return box
|
143 |
+
|
144 |
+
|
145 |
+
def find_matching_box(boxes_list, new_box, match_iou):
|
146 |
+
best_iou = match_iou
|
147 |
+
best_index = -1
|
148 |
+
for i in range(len(boxes_list)):
|
149 |
+
box = boxes_list[i]
|
150 |
+
if box[0] != new_box[0]:
|
151 |
+
continue
|
152 |
+
iou = bb_intersection_over_union_3d(box[2:], new_box[2:])
|
153 |
+
if iou > best_iou:
|
154 |
+
best_index = i
|
155 |
+
best_iou = iou
|
156 |
+
|
157 |
+
return best_index, best_iou
|
158 |
+
|
159 |
+
|
160 |
+
def weighted_boxes_fusion_3d(boxes_list, scores_list, labels_list, weights=None, iou_thr=0.55, skip_box_thr=0.0, conf_type='avg', allows_overflow=False):
|
161 |
+
'''
|
162 |
+
:param boxes_list: list of boxes predictions from each model, each box is 6 numbers.
|
163 |
+
It has 3 dimensions (models_number, model_preds, 6)
|
164 |
+
Order of boxes: x1, y1, z1, x2, y2 z2. We expect float normalized coordinates [0; 1]
|
165 |
+
:param scores_list: list of scores for each model
|
166 |
+
:param labels_list: list of labels for each model
|
167 |
+
:param weights: list of weights for each model. Default: None, which means weight == 1 for each model
|
168 |
+
:param iou_thr: IoU value for boxes to be a match
|
169 |
+
:param skip_box_thr: exclude boxes with score lower than this variable
|
170 |
+
:param conf_type: how to calculate confidence in weighted boxes. 'avg': average value, 'max': maximum value
|
171 |
+
:param allows_overflow: false if we want confidence score not exceed 1.0
|
172 |
+
|
173 |
+
:return: boxes: boxes coordinates (Order of boxes: x1, y1, z1, x2, y2, z2).
|
174 |
+
:return: scores: confidence scores
|
175 |
+
:return: labels: boxes labels
|
176 |
+
'''
|
177 |
+
|
178 |
+
if weights is None:
|
179 |
+
weights = np.ones(len(boxes_list))
|
180 |
+
if len(weights) != len(boxes_list):
|
181 |
+
print('Warning: incorrect number of weights {}. Must be: {}. Set weights equal to 1.'.format(len(weights), len(boxes_list)))
|
182 |
+
weights = np.ones(len(boxes_list))
|
183 |
+
weights = np.array(weights)
|
184 |
+
|
185 |
+
if conf_type not in ['avg', 'max']:
|
186 |
+
print('Error. Unknown conf_type: {}. Must be "avg" or "max". Use "avg"'.format(conf_type))
|
187 |
+
conf_type = 'avg'
|
188 |
+
|
189 |
+
filtered_boxes = prefilter_boxes(boxes_list, scores_list, labels_list, weights, skip_box_thr)
|
190 |
+
if len(filtered_boxes) == 0:
|
191 |
+
return np.zeros((0, 6)), np.zeros((0,)), np.zeros((0,))
|
192 |
+
|
193 |
+
overall_boxes = []
|
194 |
+
for label in filtered_boxes:
|
195 |
+
boxes = filtered_boxes[label]
|
196 |
+
new_boxes = []
|
197 |
+
weighted_boxes = []
|
198 |
+
|
199 |
+
# Clusterize boxes
|
200 |
+
for j in range(0, len(boxes)):
|
201 |
+
index, best_iou = find_matching_box(weighted_boxes, boxes[j], iou_thr)
|
202 |
+
if index != -1:
|
203 |
+
new_boxes[index].append(boxes[j])
|
204 |
+
weighted_boxes[index] = get_weighted_box(new_boxes[index], conf_type)
|
205 |
+
else:
|
206 |
+
new_boxes.append([boxes[j].copy()])
|
207 |
+
weighted_boxes.append(boxes[j].copy())
|
208 |
+
|
209 |
+
# Rescale confidence based on number of models and boxes
|
210 |
+
for i in range(len(new_boxes)):
|
211 |
+
if not allows_overflow:
|
212 |
+
weighted_boxes[i][1] = weighted_boxes[i][1] * min(weights.sum(), len(new_boxes[i])) / weights.sum()
|
213 |
+
else:
|
214 |
+
weighted_boxes[i][1] = weighted_boxes[i][1] * len(new_boxes[i]) / weights.sum()
|
215 |
+
overall_boxes.append(np.array(weighted_boxes))
|
216 |
+
|
217 |
+
overall_boxes = np.concatenate(overall_boxes, axis=0)
|
218 |
+
overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]]
|
219 |
+
boxes = overall_boxes[:, 2:]
|
220 |
+
scores = overall_boxes[:, 1]
|
221 |
+
labels = overall_boxes[:, 0]
|
222 |
+
return boxes, scores, labels
|
DenseMammogram/experimenter.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Experimenter Class is responsible for mainly four things:
|
2 |
+
# 1. Configuration - Done
|
3 |
+
# 2. Logging using the AdvancedLogger class - Almost Done
|
4 |
+
# 3. Model Handling, including loading and saving models - Done(Upgrades Left)
|
5 |
+
# 4. Running Different Variants Paralelly/Sequentially of experiments
|
6 |
+
# 5. Combining frcnn training followed by bilateral training and final froc calculation - Done
|
7 |
+
# 6. Version Control
|
8 |
+
|
9 |
+
from advanced_config import AdvancedConfig
|
10 |
+
from advanced_logger import AdvancedLogger, LogPriority
|
11 |
+
import os
|
12 |
+
from os.path import join
|
13 |
+
from plot_froc import plot_froc
|
14 |
+
from train_frcnn import main as TRAIN_FRCNN
|
15 |
+
from train_bilateral import main as TRAIN_BILATERAL
|
16 |
+
import torch
|
17 |
+
from model_utils import generate_predictions, generate_predictions_bilateral
|
18 |
+
import argparse
|
19 |
+
from dataloaders import get_dict
|
20 |
+
from utils import create_backup
|
21 |
+
from torch.utils.tensorboard import SummaryWriter
|
22 |
+
|
23 |
+
class Experimenter:
|
24 |
+
|
25 |
+
def __init__(self, cfg_file, BASE_DIR = 'experiments'):
|
26 |
+
self.cfg_file = cfg_file
|
27 |
+
|
28 |
+
self.con = AdvancedConfig(cfg_file)
|
29 |
+
self.config = self.con.config
|
30 |
+
self.exp_dir = join(BASE_DIR,self.config['EXP_NAME'])
|
31 |
+
os.makedirs(self.exp_dir, exist_ok=True)
|
32 |
+
self.con.save(join(self.exp_dir,'config.cfg'))
|
33 |
+
|
34 |
+
self.logger = AdvancedLogger(self.exp_dir)
|
35 |
+
self.logger.log('Experiment:',self.config['EXP_NAME'],priority = LogPriority.STATS)
|
36 |
+
self.logger.log('Experiment Description:', self.config['EXP_DESC'], priority = LogPriority.STATS)
|
37 |
+
self.logger.log('Config File:',self.cfg_file, priority = LogPriority.STATS)
|
38 |
+
self.logger.log('Experiment started', priority = LogPriority.LOW)
|
39 |
+
self.losses = dict()
|
40 |
+
self.frocs = dict()
|
41 |
+
|
42 |
+
self.writer = SummaryWriter(join(self.exp_dir,'tensor_logs'))
|
43 |
+
|
44 |
+
create_backup(backup_dir=join(self.exp_dir,'scripts'))
|
45 |
+
|
46 |
+
def log(self, *args, **kwargs):
|
47 |
+
self.logger.log(*args, **kwargs)
|
48 |
+
|
49 |
+
|
50 |
+
def init_losses(self,mode):
|
51 |
+
if mode == 'FRCNN' or mode == 'FRCNN_BILATERAL':
|
52 |
+
self.losses['frcnn_loss'] = []
|
53 |
+
self.frocs['frcnn_froc'] = []
|
54 |
+
elif mode == 'BILATERAL' or mode == 'FRCNN_BILATERAL':
|
55 |
+
self.losses['bilateral_loss'] = []
|
56 |
+
self.frocs['bilateral_froc'] = []
|
57 |
+
|
58 |
+
def start_epoch(self):
|
59 |
+
self.curr_epoch += 1
|
60 |
+
self.logger.log('Epoch:',self.curr_epoch, priority = LogPriority.MEDIUM)
|
61 |
+
|
62 |
+
def end_epoch(self, loss, model = None, device = None):
|
63 |
+
if self.curr_mode == 'FRCNN':
|
64 |
+
self.losses['frcnn_loss'].append(loss)
|
65 |
+
self.best_loss = min(self.losses['frcnn_loss'])
|
66 |
+
if self.config['EVAL_METHOD'] == 'FROC':
|
67 |
+
exp_name = self.config['EXP_NAME']
|
68 |
+
_, val_path, _ = self.init_paths()
|
69 |
+
generate_predictions(model,device,val_path,f'preds_frcnn_{exp_name}')
|
70 |
+
from froc_by_pranjal import get_froc_points
|
71 |
+
senses, _ = get_froc_points(f'preds_frcnn_{exp_name}', root_fol= join(self.config['DATA_DIR'],self.config['AIIMS_DATA'], self.config['AIIMS_VAL_SPLIT']), fps_req = [0.2])
|
72 |
+
self.frocs['frcnn_froc'].append(senses[0])
|
73 |
+
self.best_froc = max(self.frocs['frcnn_froc'])
|
74 |
+
self.logger.log(f'Val FROC: {senses[0]}', LogPriority.MEDIUM)
|
75 |
+
self.logger.log(f'Best FROC: {self.best_froc}')
|
76 |
+
elif self.curr_mode == 'BILATERAL':
|
77 |
+
self.losses['bilateral_loss'].append(loss)
|
78 |
+
self.best_loss = min(self.losses['bilateral_loss'])
|
79 |
+
if self.config['EVAL_METHOD'] == 'FROC':
|
80 |
+
exp_name = self.config['EXP_NAME']
|
81 |
+
_, val_path, _ = self.init_paths()
|
82 |
+
data_dir = self.config['DATA_DIR']
|
83 |
+
print('Generating')
|
84 |
+
generate_predictions_bilateral(model,device,val_path,get_dict(data_dir,self.abs_path(self.config['AIIMS_CORRS_LIST'])),preds_folder = f'preds_bilateral_{exp_name}')
|
85 |
+
print('Generation Done')
|
86 |
+
from froc_by_pranjal import get_froc_points
|
87 |
+
senses, _ = get_froc_points(f'preds_bilateral_{exp_name}', root_fol= join(self.config['DATA_DIR'],self.config['AIIMS_DATA'], self.config['AIIMS_VAL_SPLIT']), fps_req = [0.1])
|
88 |
+
print('Reading Sens from',f'preds_bilateral_{exp_name}', join(self.config['DATA_DIR'],self.config['AIIMS_DATA'], self.config['AIIMS_VAL_SPLIT']),)
|
89 |
+
|
90 |
+
self.frocs['bilateral_froc'].append(senses[0])
|
91 |
+
self.best_froc = max(self.frocs['bilateral_froc'])
|
92 |
+
self.logger.log(f'Val FROC: {senses[0]}', priority = LogPriority.MEDIUM)
|
93 |
+
self.logger.log(f'Best FROC: {self.best_froc}')
|
94 |
+
|
95 |
+
self.writer.add_scalar(f"{self.curr_mode}/Loss/Valid", loss, self.curr_epoch)
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
def save_model(self, model):
|
100 |
+
if self.curr_mode == 'FRCNN':
|
101 |
+
self.logger.log('Saving FRCNN Model', priority = LogPriority.LOW)
|
102 |
+
model_file = join(self.exp_dir,'frcnn_models',f'frcnn_model.pth')
|
103 |
+
if self.config['EVAL_METHOD']:
|
104 |
+
SAVE = self.best_froc == self.frocs['frcnn_froc'][-1]
|
105 |
+
else:
|
106 |
+
SAVE = self.best_loss == self.losses['frcnn_loss'][-1]
|
107 |
+
elif self.curr_mode == 'BILATERAL':
|
108 |
+
self.logger.log('Saving Bilateral Model', priority = LogPriority.LOW)
|
109 |
+
model_file = join(self.exp_dir,'bilateral_models',f'bilateral_model.pth')
|
110 |
+
if self.config['EVAL_METHOD'] == 'FROC':
|
111 |
+
SAVE = self.best_froc == self.frocs['bilateral_froc'][-1]
|
112 |
+
else:
|
113 |
+
SAVE = self.best_loss == self.losses['bilateral_loss'][-1]
|
114 |
+
os.makedirs(os.path.split(model_file)[0], exist_ok=True)
|
115 |
+
if SAVE:
|
116 |
+
torch.save(model.state_dict(), model_file)
|
117 |
+
|
118 |
+
torch.save(model.state_dict(), f'{model_file[:-4]}_{self.curr_epoch}.pth')
|
119 |
+
|
120 |
+
def init_paths(self,):
|
121 |
+
train_path = join(self.config['DATA_DIR'], self.config['AIIMS_DATA'], self.config['AIIMS_TRAIN_SPLIT'])
|
122 |
+
val_path = join(self.config['DATA_DIR'], self.config['AIIMS_DATA'], self.config['AIIMS_VAL_SPLIT'])
|
123 |
+
test_path = join(self.config['DATA_DIR'], self.config['AIIMS_DATA'], self.config['AIIMS_TEST_SPLIT'])
|
124 |
+
return train_path, val_path, test_path
|
125 |
+
|
126 |
+
def abs_path(self, path):
|
127 |
+
return join(self.config['DATA_DIR'], path)
|
128 |
+
|
129 |
+
# Impure Function, upadtes the model with best state dicts
|
130 |
+
def generate_predictions(self,model, device):
|
131 |
+
self.logger.log('Generating Predictions')
|
132 |
+
self.logger.flush()
|
133 |
+
exp_name = self.config['EXP_NAME']
|
134 |
+
train_path, val_path, test_path = self.init_paths()
|
135 |
+
|
136 |
+
# Load the best val_loss model's state dicts
|
137 |
+
if self.curr_mode == 'FRCNN':
|
138 |
+
model_file = join(self.exp_dir,'frcnn_models','frcnn_model.pth')
|
139 |
+
elif self.curr_mode == 'BILATERAL':
|
140 |
+
model_file = join(self.exp_dir,'bilateral_models','bilateral_model.pth')
|
141 |
+
model.load_state_dict(torch.load(model_file))
|
142 |
+
|
143 |
+
if self.curr_mode == 'FRCNN':
|
144 |
+
generate_predictions(model,device,train_path,f'preds_frcnn_{exp_name}')
|
145 |
+
generate_predictions(model,device,val_path,f'preds_frcnn_{exp_name}')
|
146 |
+
generate_predictions(model,device,test_path,f'preds_frcnn_{exp_name}')
|
147 |
+
elif self.curr_mode == 'BILATERAL':
|
148 |
+
data_dir = self.config['DATA_DIR']
|
149 |
+
generate_predictions_bilateral(model,device,train_path,get_dict(data_dir,self.abs_path(self.config['AIIMS_CORRS_LIST'])),'aiims',f'preds_bilateral_{exp_name}')
|
150 |
+
generate_predictions_bilateral(model,device,val_path,get_dict(data_dir,self.abs_path(self.config['AIIMS_CORRS_LIST'])),'aiims',f'preds_bilateral_{exp_name}')
|
151 |
+
generate_predictions_bilateral(model,device,test_path,get_dict(data_dir,self.abs_path(self.config['AIIMS_CORRS_LIST'])),'aiims',f'preds_bilateral_{exp_name}')
|
152 |
+
test_path = join(self.config['DATA_DIR'], self.config['AIIMS_DATA'], self.config['AIIMS_TEST_SPLIT'])
|
153 |
+
|
154 |
+
def run_experiment(self):
|
155 |
+
|
156 |
+
# First Determine the mode of running the experiment
|
157 |
+
mode = self.config['MODE']
|
158 |
+
self.init_losses(mode)
|
159 |
+
self.curr_mode = 'FRCNN'
|
160 |
+
self.curr_epoch = -1
|
161 |
+
self.best_loss = 999999
|
162 |
+
self.best_froc = 0
|
163 |
+
if mode == 'FRCNN':
|
164 |
+
TRAIN_FRCNN(self.config['FRCNN'], self)
|
165 |
+
elif mode == 'BILATERAL':
|
166 |
+
self.curr_mode = 'BILATERAL'
|
167 |
+
TRAIN_BILATERAL(self.config['BILATERAL'], self)
|
168 |
+
elif mode == 'FRCNN_BILATERAL':
|
169 |
+
TRAIN_FRCNN(self.config['FRCNN'], self)
|
170 |
+
self.curr_mode = 'BILATERAL'
|
171 |
+
self.curr_epoch = -1
|
172 |
+
self.best_loss = 999999
|
173 |
+
# Note the path to frcnn model must be the same as that dictated by experiment
|
174 |
+
self.config['BILATERAL']['FRCNN_MODEL_PATH'] = join(self.exp_dir,'frcnn_models','frcnn_model.pth')
|
175 |
+
TRAIN_BILATERAL(self.config['BILATERAL'], self)
|
176 |
+
|
177 |
+
self.logger.log(f'Best Loss: {self.best_loss}', priority= LogPriority.STATS)
|
178 |
+
self.logger.log('Experiment Training and Generation Ended', priority = LogPriority.MEDIUM)
|
179 |
+
|
180 |
+
# Now evaluate the results
|
181 |
+
|
182 |
+
frcnn_file = join(self.exp_dir, 'senses_fps_frcnn.txt')
|
183 |
+
bilateral_file = join(self.exp_dir, 'senses_fps_bilateral.txt')
|
184 |
+
from froc_by_pranjal import get_froc_points
|
185 |
+
exp_name = self.config['EXP_NAME']
|
186 |
+
if mode == 'FRCNN' or mode == 'FRCNN_BILATERAL':
|
187 |
+
senses, fps = get_froc_points(f'preds_frcnn_{exp_name}', root_fol= join(self.config['DATA_DIR'],self.config['AIIMS_DATA'], self.config['AIIMS_TEST_SPLIT']), save_to = frcnn_file)
|
188 |
+
self.logger.log('FRCNN RESULTS', priority = LogPriority.STATS)
|
189 |
+
for s,f in zip(senses, fps):
|
190 |
+
self.logger.log(f'Sensitivty at {f}: {s}', priority = LogPriority.STATS)
|
191 |
+
if mode == 'BILATERAL' or mode == 'FRCNN_BILATERAL':
|
192 |
+
senses, fps = get_froc_points(f'preds_bilateral_{exp_name}', root_fol= join(self.config['DATA_DIR'],self.config['AIIMS_DATA'], self.config['AIIMS_TEST_SPLIT']), save_to = bilateral_file)
|
193 |
+
self.logger.log('BILATERAL RESULTS', priority = LogPriority.STATS)
|
194 |
+
for s,f in zip(senses, fps):
|
195 |
+
self.logger.log(f'Sensitivty at {f}: {s}', priority = LogPriority.STATS)
|
196 |
+
|
197 |
+
|
198 |
+
# Now draw the graphs.... If FRCNN and BILATERAL both done, draw them on one graph
|
199 |
+
# Else draw single graphs only
|
200 |
+
if mode == 'FRCNN':
|
201 |
+
plot_froc({frcnn_file : 'FRCNN'}, join(self.exp_dir,'plot.png'), TITLE = 'FRCNN FROC')
|
202 |
+
elif mode == 'BILATERAL':
|
203 |
+
plot_froc({bilateral_file : 'BILATERAL'}, join(self.exp_dir,'plot.png'), TITLE = 'BILATERAL FROC')
|
204 |
+
elif mode == 'FRCNN_BILATERAL':
|
205 |
+
plot_froc({frcnn_file : 'FRCNN', bilateral_file : 'BILATERAL'}, join(self.exp_dir,'plot.png'), TITLE = 'FRCNN vs BILATERAL FROC')
|
206 |
+
self.logger.flush()
|
207 |
+
|
208 |
+
if __name__ == '__main__':
|
209 |
+
parser = argparse.ArgumentParser()
|
210 |
+
parser.add_argument('--cfg_file', type=str, default='configs/AIIMS_C1.cfg')
|
211 |
+
args = parser.parse_args()
|
212 |
+
exp = Experimenter(args.cfg_file)
|
213 |
+
exp.run_experiment()
|
DenseMammogram/froc_by_pranjal.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import sys
|
4 |
+
from os.path import join
|
5 |
+
|
6 |
+
|
7 |
+
'''
|
8 |
+
Note: Anywhere empty boxes means [] and not [[]]
|
9 |
+
'''
|
10 |
+
|
11 |
+
|
12 |
+
def remove_true_positives(gts, preds):
|
13 |
+
|
14 |
+
def true_positive(gt, pred):
|
15 |
+
# If center of pred is inside the gt, it is a true positive
|
16 |
+
c_pred = ((pred[0]+pred[2])/2., (pred[1]+pred[3])/2.)
|
17 |
+
if (c_pred[0] >= gt[0] and c_pred[0] <= gt[2] and
|
18 |
+
c_pred[1] >= gt[1] and c_pred[1] <= gt[3]):
|
19 |
+
return True
|
20 |
+
return False
|
21 |
+
|
22 |
+
tps = 0
|
23 |
+
fns = 0
|
24 |
+
|
25 |
+
for gt in gts:
|
26 |
+
# First check if any true positive exists
|
27 |
+
# If more than one exists, do not include it in next set of preds
|
28 |
+
add_tp = False
|
29 |
+
new_preds = []
|
30 |
+
for pred in preds:
|
31 |
+
if true_positive(gt, pred):
|
32 |
+
add_tp = True
|
33 |
+
else:
|
34 |
+
new_preds.append(pred)
|
35 |
+
preds = new_preds
|
36 |
+
if add_tp:
|
37 |
+
tps += 1
|
38 |
+
else:
|
39 |
+
fns += 1
|
40 |
+
return preds, tps, fns
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
def calc_metric_single(gts, preds, threshold,):
|
45 |
+
'''
|
46 |
+
Returns fp, tp, tn, fn
|
47 |
+
'''
|
48 |
+
preds = list(filter(lambda x: x[0] >= threshold, preds))
|
49 |
+
preds = [pred[1:] for pred in preds] # Remove the scores
|
50 |
+
|
51 |
+
if len(gts) == 0:
|
52 |
+
return len(preds), 0, 1 if len(preds) == 0 else 0, 0
|
53 |
+
preds, tps, fns = remove_true_positives(gts, preds)
|
54 |
+
# All remaining will have to fps
|
55 |
+
fps = len(preds)
|
56 |
+
return fps, tps, 0, fns
|
57 |
+
|
58 |
+
|
59 |
+
def calc_metrics_at_thresh(im_dict, threshold):
|
60 |
+
'''
|
61 |
+
Returns fp, tp, tn, fn
|
62 |
+
'''
|
63 |
+
fps, tps, tns, fns = 0, 0, 0, 0
|
64 |
+
for key in im_dict:
|
65 |
+
fp,tp,tn,fn = calc_metric_single(im_dict[key]['gt'],
|
66 |
+
im_dict[key]['preds'], threshold)
|
67 |
+
fps+=fp
|
68 |
+
tps+=tp
|
69 |
+
tns+=tn
|
70 |
+
fns+=fn
|
71 |
+
|
72 |
+
return fps, tps, tns, fns
|
73 |
+
|
74 |
+
from joblib import Parallel, delayed
|
75 |
+
|
76 |
+
def calc_metrics(inp):
|
77 |
+
im_dict, tr = inp
|
78 |
+
out = dict()
|
79 |
+
for t in tr:
|
80 |
+
fp, tp, tn, fn = calc_metrics_at_thresh(im_dict, t)
|
81 |
+
out[t] = [fp, tp, tn, fn]
|
82 |
+
return out
|
83 |
+
|
84 |
+
|
85 |
+
def calc_froc_from_dict(im_dict, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3], save_to = None):
|
86 |
+
|
87 |
+
num_images = len(im_dict)
|
88 |
+
|
89 |
+
gap = 0.005
|
90 |
+
n = int(1/gap)
|
91 |
+
thresholds = [i * gap for i in range(n)]
|
92 |
+
fps = [0 for _ in range(n)]
|
93 |
+
tps = [0 for _ in range(n)]
|
94 |
+
tns = [0 for _ in range(n)]
|
95 |
+
fns = [0 for _ in range(n)]
|
96 |
+
|
97 |
+
|
98 |
+
for i,t in enumerate(thresholds):
|
99 |
+
fps[i], tps[i], tns[i], fns[i] = calc_metrics_at_thresh(im_dict, t)
|
100 |
+
|
101 |
+
|
102 |
+
# Now calculate the sensitivities
|
103 |
+
senses = []
|
104 |
+
for t,f in zip(tps, fns):
|
105 |
+
try: senses.append(t/(t+f))
|
106 |
+
except: senses.append(0.)
|
107 |
+
|
108 |
+
if save_to is not None:
|
109 |
+
f = open(save_to, 'w')
|
110 |
+
for fp,s in zip(fps, senses):
|
111 |
+
f.write(f'{fp/num_images} {s}\n')
|
112 |
+
f.close()
|
113 |
+
|
114 |
+
senses_req = []
|
115 |
+
for fp_req in fps_req:
|
116 |
+
for i,f in enumerate(fps):
|
117 |
+
if f/num_images < fp_req:
|
118 |
+
if fp_req == 0.1:
|
119 |
+
print(fps[i], tps[i], tns[i], fns[i])
|
120 |
+
prec = tps[i]/(tps[i] + fps[i])
|
121 |
+
recall = tps[i]/(tps[i] + fns[i])
|
122 |
+
f1 = 2*prec*recall/(prec+recall)
|
123 |
+
spec = tns[i]/ (tns[i] + fps[i])
|
124 |
+
print(f'Specificity: {spec}')
|
125 |
+
print(f'Precision: {prec}')
|
126 |
+
print(f'Recall: {recall}')
|
127 |
+
print(f'F1: {f1}')
|
128 |
+
senses_req.append(senses[i-1])
|
129 |
+
break
|
130 |
+
return senses_req, fps_req
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
def file_to_bbox(file_name):
|
136 |
+
try:
|
137 |
+
content = open(file_name, 'r').readlines()
|
138 |
+
st = 0
|
139 |
+
if len(content) == 0:
|
140 |
+
# Empty File Should Return []
|
141 |
+
return []
|
142 |
+
if content[0].split()[0].isalpha():
|
143 |
+
st = 1
|
144 |
+
return [[float(x) for x in line.split()[st:]] for line in content]
|
145 |
+
except FileNotFoundError:
|
146 |
+
print(f'No Corresponding Box Found for file {file_name}, using [] as preds')
|
147 |
+
return []
|
148 |
+
except Exception as e:
|
149 |
+
print('Some Error',e)
|
150 |
+
return []
|
151 |
+
|
152 |
+
def generate_image_dict(preds_folder_name='preds_42',
|
153 |
+
root_fol='/home/pranjal/densebreeast_datasets/AIIMS_C1',
|
154 |
+
mal_path=None, ben_path=None, gt_path=None,
|
155 |
+
mal_img_path = None, ben_img_path = None
|
156 |
+
):
|
157 |
+
|
158 |
+
mal_path = join(root_fol, mal_path) if mal_path else join(
|
159 |
+
root_fol, 'mal', preds_folder_name)
|
160 |
+
ben_path = join(root_fol, ben_path) if ben_path else join(
|
161 |
+
root_fol, 'ben', preds_folder_name)
|
162 |
+
mal_img_path = join(root_fol, mal_img_path) if mal_img_path else join(
|
163 |
+
root_fol, 'mal', 'images')
|
164 |
+
ben_img_path = join(root_fol, ben_img_path) if ben_img_path else join(
|
165 |
+
root_fol, 'ben', 'images')
|
166 |
+
gt_path = join(root_fol, gt_path) if gt_path else join(
|
167 |
+
root_fol, 'mal', 'gt')
|
168 |
+
|
169 |
+
|
170 |
+
'''
|
171 |
+
image_dict structure:
|
172 |
+
'image_name(without txt/png)' : {'gt' : [[...]], 'preds' : [[]]}
|
173 |
+
'''
|
174 |
+
image_dict = dict()
|
175 |
+
|
176 |
+
# GT Might be sightly different from images, therefore we will index gts based on
|
177 |
+
# the images folder instead.
|
178 |
+
for file in os.listdir(mal_img_path):
|
179 |
+
if not file.endswith('.png'):
|
180 |
+
continue
|
181 |
+
file = file[:-4] + '.txt'
|
182 |
+
file = join(gt_path, file)
|
183 |
+
key = os.path.split(file)[-1][:-4]
|
184 |
+
image_dict[key] = dict()
|
185 |
+
image_dict[key]['gt'] = file_to_bbox(file)
|
186 |
+
image_dict[key]['preds'] = []
|
187 |
+
|
188 |
+
for file in glob.glob(join(mal_path, '*.txt')):
|
189 |
+
key = os.path.split(file)[-1][:-4]
|
190 |
+
assert key in image_dict
|
191 |
+
image_dict[key]['preds'] = file_to_bbox(file)
|
192 |
+
|
193 |
+
for file in os.listdir(ben_img_path):
|
194 |
+
if not file.endswith('.png'):
|
195 |
+
continue
|
196 |
+
|
197 |
+
file = file[:-4] + '.txt'
|
198 |
+
file = join(ben_path, file)
|
199 |
+
key = os.path.split(file)[-1][:-4]
|
200 |
+
if key == 'Calc-Test_P_00353_LEFT_CC' or key == 'Calc-Training_P_00600_LEFT_CC': # Corrupt Files in Dataset
|
201 |
+
continue
|
202 |
+
if key in image_dict:
|
203 |
+
print(key)
|
204 |
+
# assert key not in image_dict
|
205 |
+
if key in image_dict:
|
206 |
+
print(f'Unexpected Error. {key} exists in multiple splits')
|
207 |
+
continue
|
208 |
+
image_dict[key] = dict()
|
209 |
+
image_dict[key]['preds'] = file_to_bbox(file)
|
210 |
+
image_dict[key]['gt'] = []
|
211 |
+
return image_dict
|
212 |
+
|
213 |
+
|
214 |
+
def pretty_print_fps(senses,fps):
|
215 |
+
for s,f in zip(senses,fps):
|
216 |
+
print(f'Sensitivty at {f}: {s}')
|
217 |
+
|
218 |
+
def get_froc_points(preds_image_folder, root_fol, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3], save_to = None):
|
219 |
+
im_dict = generate_image_dict(preds_image_folder, root_fol = root_fol)
|
220 |
+
# print(im_dict)
|
221 |
+
print(len(im_dict))
|
222 |
+
senses, fps = calc_froc_from_dict(im_dict, fps_req, save_to = save_to)
|
223 |
+
return senses, fps
|
224 |
+
|
225 |
+
if __name__ == '__main__':
|
226 |
+
seed = '42' if len(sys.argv)== 1 else sys.argv[1]
|
227 |
+
|
228 |
+
root_fol = '../bilateral_new/MammoDatasets/AIIMS_highres_reliable/test_2'
|
229 |
+
|
230 |
+
if len(sys.argv) <= 2:
|
231 |
+
save_to = None
|
232 |
+
else:
|
233 |
+
save_to = sys.argv[2]
|
234 |
+
senses, fps = get_froc_points(f'preds_{seed}',root_fol, save_to = save_to)
|
235 |
+
|
236 |
+
pretty_print_fps(senses, fps)
|
DenseMammogram/geenerate_aiims.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from os.path import join
|
4 |
+
from model_utils import generate_predictions, generate_predictions_bilateral
|
5 |
+
from models import get_FRCNN_model, Bilateral_model
|
6 |
+
from froc_by_pranjal import get_froc_points
|
7 |
+
from auc_by_pranjal import get_auc_score
|
8 |
+
|
9 |
+
####### PARAMETERS TO ADJUST #######
|
10 |
+
exp_name = 'BILATERAL'
|
11 |
+
OUT_FILE = 'aiims_full_test_results/bil_complete.txt'
|
12 |
+
BILATERAL = True
|
13 |
+
dataset_path = 'AIIMS_highres_reliable/test_2'
|
14 |
+
####################################
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
if os.path.split(OUT_FILE)[0]:
|
20 |
+
os.makedirs(os.path.split(OUT_FILE)[0], exist_ok=True)
|
21 |
+
|
22 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
+
frcnn_model = get_FRCNN_model().to(device)
|
24 |
+
|
25 |
+
if BILATERAL:
|
26 |
+
model = Bilateral_model(frcnn_model).to(device)
|
27 |
+
MODEL_PATH = f'experiments/{exp_name}/bilateral_models/bilateral_model.pth'
|
28 |
+
model.load_state_dict(torch.load(MODEL_PATH))
|
29 |
+
else:
|
30 |
+
model = frcnn_model
|
31 |
+
MODEL_PATH = f'experiments/{exp_name}/frcnn_models/frcnn_model.pth'
|
32 |
+
model.load_state_dict(torch.load(MODEL_PATH))
|
33 |
+
|
34 |
+
|
35 |
+
test_path = join('../bilateral_new', 'MammoDatasets',dataset_path)
|
36 |
+
|
37 |
+
|
38 |
+
def get_aiims_dict(test_path, corr_file):
|
39 |
+
extract_file = lambda x: x[x.find('test_2/')+7:]
|
40 |
+
corr_dict = {extract_file(line.split()[0].replace('"','')):extract_file(line.split()[1].replace('"','')) for line in open(corr_file).readlines()}
|
41 |
+
corr_dict = {join(test_path,k):join(test_path,v) for k,v in corr_dict.items()}
|
42 |
+
return corr_dict
|
43 |
+
|
44 |
+
if BILATERAL:
|
45 |
+
pred_dir = f'preds_bilateral_{exp_name}'
|
46 |
+
generate_predictions_bilateral(model,device,test_path, get_aiims_dict(test_path, '../bilateral_new/corr_lists/aiims_corr_list_with_val_full_test.txt'),'aiims',pred_dir)
|
47 |
+
else:
|
48 |
+
pred_dir = f'preds_frcnn_{exp_name}'
|
49 |
+
generate_predictions(model, device, test_path, preds_folder = pred_dir)
|
50 |
+
|
51 |
+
|
52 |
+
file = open(OUT_FILE, 'a')
|
53 |
+
file.writelines(f'{exp_name} FROC Score:\n')
|
54 |
+
senses, fps = get_froc_points(pred_dir, root_fol= test_path, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3,1.0,1.5])
|
55 |
+
for s,f in zip(senses, fps):
|
56 |
+
print(f'Sensitivty at {f}: {s}')
|
57 |
+
file.writelines(f'Sensitivty at {f}: {s}\n')
|
58 |
+
file.close()
|
59 |
+
|
60 |
+
print('AUC Score:',get_auc_score(pred_dir, test_path, retAcc = True, acc_thresh = 1.))
|
61 |
+
|
DenseMammogram/geenerate_ddsm_preds.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from os.path import join
|
4 |
+
from model_utils import generate_predictions, generate_predictions_bilateral
|
5 |
+
from models import get_FRCNN_model, Bilateral_model
|
6 |
+
from froc_by_pranjal import get_froc_points
|
7 |
+
from auc_by_pranjal import get_auc_score
|
8 |
+
|
9 |
+
####### PARAMETERS TO ADJUST #######
|
10 |
+
exp_name = 'frcnn_16'
|
11 |
+
OUT_FILE = 'ddsm_results/ddsm_dset.txt'
|
12 |
+
BILATERAL = False
|
13 |
+
dataset_path = 'ddsm_data_no_proc_2100_nocrop/val'
|
14 |
+
####################################
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
if os.path.split(OUT_FILE)[0]:
|
20 |
+
os.makedirs(os.path.split(OUT_FILE)[0], exist_ok=True)
|
21 |
+
|
22 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
+
frcnn_model = get_FRCNN_model().to(device)
|
24 |
+
|
25 |
+
if BILATERAL:
|
26 |
+
model = Bilateral_model(frcnn_model).to(device)
|
27 |
+
MODEL_PATH = f'experiments/{exp_name}/bilateral_models/bilateral_model.pth'
|
28 |
+
model.load_state_dict(torch.load(MODEL_PATH))
|
29 |
+
else:
|
30 |
+
model = frcnn_model
|
31 |
+
MODEL_PATH = f'experiments/{exp_name}/frcnn_models/frcnn_model.pth'
|
32 |
+
model.load_state_dict(torch.load(MODEL_PATH))
|
33 |
+
|
34 |
+
|
35 |
+
test_path = join('../bilateral_new', 'MammoDatasets',dataset_path)
|
36 |
+
|
37 |
+
|
38 |
+
def get_ddsm_dict(test_path, corr_file):
|
39 |
+
extract_file = lambda x: x[x.find('val/')+4:]
|
40 |
+
corr_dict = {extract_file(line.split()[0].replace('"','')):extract_file(line.split()[1].replace('"','')) for line in open(corr_file).readlines()}
|
41 |
+
corr_dict = {join(test_path,k):join(test_path,v) for k,v in corr_dict.items()}
|
42 |
+
return corr_dict
|
43 |
+
|
44 |
+
if BILATERAL:
|
45 |
+
pred_dir = f'preds_bilateral_{exp_name}'
|
46 |
+
generate_predictions_bilateral(model,device,test_path, get_ddsm_dict(test_path, '../bilateral_new/corr_lists/ddsm_corr_list_with_val.txt'),'ddsm',pred_dir)
|
47 |
+
else:
|
48 |
+
pred_dir = f'preds_frcnn_{exp_name}'
|
49 |
+
generate_predictions(model, device, test_path, preds_folder = pred_dir)
|
50 |
+
|
51 |
+
|
52 |
+
file = open(OUT_FILE, 'a')
|
53 |
+
file.writelines(f'{exp_name} FROC Score:\n')
|
54 |
+
senses, fps = get_froc_points(pred_dir, root_fol= test_path, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3,1.0,1.5])
|
55 |
+
for s,f in zip(senses, fps):
|
56 |
+
print(f'Sensitivty at {f}: {s}')
|
57 |
+
file.writelines(f'Sensitivty at {f}: {s}\n')
|
58 |
+
file.close()
|
59 |
+
|
60 |
+
print('AUC Score:',get_auc_score(pred_dir, test_path, retAcc = True, acc_thresh = 1.))
|
61 |
+
|
DenseMammogram/geenerate_inbreast_preds.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from os.path import join
|
4 |
+
from model_utils import generate_predictions, generate_predictions_bilateral
|
5 |
+
from models import get_FRCNN_model, Bilateral_model
|
6 |
+
from froc_by_pranjal import get_froc_points
|
7 |
+
|
8 |
+
####### PARAMETERS TO ADJUST #######
|
9 |
+
exp_name = 'AIIMS_C3'
|
10 |
+
OUT_FILE = 'ib_results/c3_frcnn.txt'
|
11 |
+
BILATERAL = False
|
12 |
+
dataset_path = 'INBREAST_C3/test'
|
13 |
+
####################################
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
if os.path.split(OUT_FILE)[0]:
|
19 |
+
os.makedirs(os.path.split(OUT_FILE)[0], exist_ok=True)
|
20 |
+
|
21 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
22 |
+
frcnn_model = get_FRCNN_model().to(device)
|
23 |
+
|
24 |
+
if BILATERAL:
|
25 |
+
model = Bilateral_model(frcnn_model).to(device)
|
26 |
+
MODEL_PATH = f'experiments/{exp_name}/bilateral_models/bilateral_model.pth'
|
27 |
+
model.load_state_dict(torch.load(MODEL_PATH))
|
28 |
+
else:
|
29 |
+
model = frcnn_model
|
30 |
+
MODEL_PATH = f'experiments/{exp_name}/frcnn_models/frcnn_model.pth'
|
31 |
+
model.load_state_dict(torch.load(MODEL_PATH))
|
32 |
+
|
33 |
+
|
34 |
+
test_path = join('../bilateral_new', 'MammoDatasets',dataset_path)
|
35 |
+
|
36 |
+
|
37 |
+
def get_inbreast_dict(test_path, corr_file):
|
38 |
+
extract_file = lambda x: x[x.find('test/')+5:]
|
39 |
+
corr_dict = {extract_file(line.split()[0]):extract_file(line.split()[1]) for line in open(corr_file).readlines()}
|
40 |
+
corr_dict = {join(test_path,k):join(test_path,v) for k,v in corr_dict.items()}
|
41 |
+
return corr_dict
|
42 |
+
|
43 |
+
if BILATERAL:
|
44 |
+
pred_dir = f'preds_bilateral_{exp_name}'
|
45 |
+
generate_predictions_bilateral(model,device,test_path, get_inbreast_dict(test_path, '../bilateral_new/corr_lists/Inbreast_final_correspondence_list.txt'),'inbreast',pred_dir)
|
46 |
+
else:
|
47 |
+
pred_dir = f'preds_frcnn_{exp_name}'
|
48 |
+
generate_predictions(model, device, test_path, preds_folder = pred_dir)
|
49 |
+
|
50 |
+
|
51 |
+
file = open(OUT_FILE, 'a')
|
52 |
+
file.writelines(f'{exp_name} FROC Score:\n')
|
53 |
+
senses, fps = get_froc_points(pred_dir, root_fol= test_path)
|
54 |
+
for s,f in zip(senses, fps):
|
55 |
+
file.writelines(f'Sensitivty at {f}: {s}\n')
|
56 |
+
file.close()
|
57 |
+
|
DenseMammogram/geenerate_irch.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from os.path import join
|
4 |
+
from model_utils import generate_predictions, generate_predictions_bilateral
|
5 |
+
from models import get_FRCNN_model, Bilateral_model
|
6 |
+
from froc_by_pranjal import get_froc_points
|
7 |
+
from auc_by_pranjal import get_auc_score
|
8 |
+
|
9 |
+
####### PARAMETERS TO ADJUST #######
|
10 |
+
exp_name = 'BILATERAL'
|
11 |
+
OUT_FILE = 'irchvalres/bil_final.txt'
|
12 |
+
BILATERAL = True
|
13 |
+
dataset_path = 'IRCHVal'
|
14 |
+
####################################
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
if os.path.split(OUT_FILE)[0]:
|
20 |
+
os.makedirs(os.path.split(OUT_FILE)[0], exist_ok=True)
|
21 |
+
|
22 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
+
frcnn_model = get_FRCNN_model().to(device)
|
24 |
+
|
25 |
+
if BILATERAL:
|
26 |
+
model = Bilateral_model(frcnn_model).to(device)
|
27 |
+
MODEL_PATH = f'experiments/{exp_name}/bilateral_models/bilateral_model.pth'
|
28 |
+
model.load_state_dict(torch.load(MODEL_PATH))
|
29 |
+
else:
|
30 |
+
model = frcnn_model
|
31 |
+
MODEL_PATH = f'experiments/{exp_name}/frcnn_models/frcnn_model.pth'
|
32 |
+
model.load_state_dict(torch.load(MODEL_PATH))
|
33 |
+
|
34 |
+
|
35 |
+
test_path = join('../bilateral_new', 'MammoDatasets',dataset_path)
|
36 |
+
|
37 |
+
|
38 |
+
def get_aiims_dict(test_path, corr_file):
|
39 |
+
extract_file = lambda x: x
|
40 |
+
corr_dict = {extract_file(line.split('" "')[0].strip().replace('"','')):extract_file(line.split('" "')[1].strip().replace('"','')) for line in open(corr_file).readlines()}
|
41 |
+
corr_dict = {join(test_path,k):join(test_path,v) for k,v in corr_dict.items()}
|
42 |
+
print(list(corr_dict.keys())[:20])
|
43 |
+
return corr_dict
|
44 |
+
|
45 |
+
if BILATERAL:
|
46 |
+
pred_dir = f'preds_bilateral_{exp_name}'
|
47 |
+
generate_predictions_bilateral(model,device,test_path, get_aiims_dict(test_path, '../bilateral_new/corr_lists/irch_val.txt'),'irch',pred_dir)
|
48 |
+
else:
|
49 |
+
pred_dir = f'preds_frcnn_{exp_name}'
|
50 |
+
generate_predictions(model, device, test_path, preds_folder = pred_dir)
|
51 |
+
|
52 |
+
|
53 |
+
file = open(OUT_FILE, 'a')
|
54 |
+
file.writelines(f'{exp_name} FROC Score:\n')
|
55 |
+
senses, fps = get_froc_points(pred_dir, root_fol= test_path, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3,1.0,1.5])
|
56 |
+
for s,f in zip(senses, fps):
|
57 |
+
print(f'Sensitivty at {f}: {s}')
|
58 |
+
file.writelines(f'Sensitivty at {f}: {s}\n')
|
59 |
+
file.close()
|
60 |
+
|
61 |
+
print('AUC Score:',get_auc_score(pred_dir, test_path, retAcc = True, acc_thresh = 1.))
|
62 |
+
|
DenseMammogram/merge_predictions.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import torch
|
4 |
+
from os.path import join
|
5 |
+
import numpy as np
|
6 |
+
from froc_by_pranjal import file_to_bbox, calc_froc_from_dict, pretty_print_fps
|
7 |
+
import sys
|
8 |
+
from ensemble_boxes import *
|
9 |
+
import json
|
10 |
+
import pickle
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
get_file_id = lambda x: x.split('_')[1]
|
15 |
+
get_acr_cat = lambda x: '0' if x not in acr_cat else acr_cat[x]
|
16 |
+
cat_to_idx = {'a':1,'b':2,'c':3,'d':4}
|
17 |
+
|
18 |
+
|
19 |
+
def get_image_dict(dataset_paths, labels = ['mal','ben'], allowed = [], USE_ACR = False, acr_cat = None, mp_dict = None):
|
20 |
+
image_dict = dict()
|
21 |
+
if allowed == []:
|
22 |
+
allowed = [i for i in range(len(dataset_paths))]
|
23 |
+
for label in labels:
|
24 |
+
images = list(set.intersection(*map(set, [os.listdir(dset.format(label)) for dset in dataset_paths])))
|
25 |
+
for image in images:
|
26 |
+
if USE_ACR:
|
27 |
+
acr = get_acr_cat(get_file_id(image))
|
28 |
+
# print(acr, image)
|
29 |
+
key = image[:-4]
|
30 |
+
gts = []
|
31 |
+
preds = []
|
32 |
+
for i,dset in enumerate(dataset_paths):
|
33 |
+
if i not in allowed:
|
34 |
+
continue
|
35 |
+
if USE_ACR:
|
36 |
+
if dset.find('AIIMS_C')!=-1:
|
37 |
+
if acr == '0': continue
|
38 |
+
if dset.find(f'AIIMS_C{cat_to_idx[acr]}') == -1:
|
39 |
+
continue
|
40 |
+
# Now choose dset to be the acr category one
|
41 |
+
dset = dset.replace('/test',f'/test_{acr}')
|
42 |
+
# print('ds',dset)
|
43 |
+
pred_file = join(dset.format(label), key+'.txt')
|
44 |
+
gt_file = join(os.path.split(dset.format(label))[0],'gt', key+'.txt')
|
45 |
+
if label == 'mal':
|
46 |
+
gts.append(file_to_bbox(gt_file))
|
47 |
+
else:
|
48 |
+
gts.append([])
|
49 |
+
|
50 |
+
# TODO: Note this
|
51 |
+
flag = False
|
52 |
+
for mp in mp_dict:
|
53 |
+
if dataset_paths[i].find(mp) != -1:
|
54 |
+
preds.append(mp_dict[mp](file_to_bbox(pred_file)))
|
55 |
+
flag = True
|
56 |
+
break
|
57 |
+
if not flag:
|
58 |
+
preds.append(file_to_bbox(pred_file))
|
59 |
+
|
60 |
+
# Ensure all gts are same
|
61 |
+
gt = gts[0]
|
62 |
+
for g in gts[1:]:
|
63 |
+
assert g == gt
|
64 |
+
gt = g
|
65 |
+
|
66 |
+
# Flatten Preds
|
67 |
+
preds = [np.array(p) for p in preds]
|
68 |
+
preds = [np.array([[0.,0.,0.,0.,0.]]) if pred.shape==(0,) else pred for pred in preds]
|
69 |
+
preds = [np.vstack((p, np.zeros((100 - len(p), 5)))) for p in preds]
|
70 |
+
image_dict[key] = dict()
|
71 |
+
image_dict[key]['gt'] = gts[0]
|
72 |
+
image_dict[key]['preds'] = preds
|
73 |
+
return image_dict
|
74 |
+
|
75 |
+
|
76 |
+
def apply_merge(image_dict, METHOD = 'wbf', weights = None, conf_type = None):
|
77 |
+
FACTOR = 5000
|
78 |
+
fusion_func = weighted_boxes_fusion if METHOD == 'wbf' else non_maximum_weighted
|
79 |
+
for key in image_dict:
|
80 |
+
preds = np.array(image_dict[key]['preds'])
|
81 |
+
if len(preds) != 0:
|
82 |
+
boxes_list = [pred[:,1:]/FACTOR for pred in preds]
|
83 |
+
scores_list = [pred[:,0] for pred in preds]
|
84 |
+
labels = [[0. for _ in range(len(p))] for p in preds]
|
85 |
+
if weights is None:
|
86 |
+
weights = [1 for _ in range(len(preds))]
|
87 |
+
if METHOD == 'wbf' and conf_type is not None:
|
88 |
+
boxes,scores,_ = fusion_func(boxes_list, scores_list, labels, weights = weights,iou_thr = 0.5, conf_type = conf_type)
|
89 |
+
else:
|
90 |
+
boxes,scores,_ = fusion_func(boxes_list, scores_list, labels, weights = weights,iou_thr = 0.5,)
|
91 |
+
preds_t = [[scores[i],FACTOR*boxes[i][0],FACTOR*boxes[i][1],FACTOR*boxes[i][2],FACTOR*boxes[i][3]] for i in range(len(boxes))]
|
92 |
+
image_dict[key]['preds'] = preds_t
|
93 |
+
return image_dict
|
94 |
+
|
95 |
+
def manipulate_preds(preds):
|
96 |
+
return preds
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
def manipulate_preds_4(preds):
|
101 |
+
return preds
|
102 |
+
|
103 |
+
tot = 0
|
104 |
+
def manipulate_preds_t1(preds): #return manipulate_preds(preds)
|
105 |
+
preds = list(filter(lambda x: x[0]>0.6,preds))
|
106 |
+
|
107 |
+
return preds
|
108 |
+
|
109 |
+
def manipulate_preds_t2(preds): return manipulate_preds_t1(preds)
|
110 |
+
|
111 |
+
|
112 |
+
if __name__ == '__main__':
|
113 |
+
USE_ACR = False
|
114 |
+
dataset_paths = [
|
115 |
+
'MammoDatasets/AIIMS_C1/test/{0}/preds_frcnn_AIIMS_C1',
|
116 |
+
'MammoDatasets/AIIMS_C2/test/{0}/preds_frcnn_AIIMS_C2',
|
117 |
+
'MammoDatasets/AIIMS_C3/test/{0}/preds_frcnn_AIIMS_C3',
|
118 |
+
'MammoDatasets/AIIMS_C4/test/{0}/preds_frcnn_AIIMS_C4',
|
119 |
+
'MammoDatasets/AIIMS_highres_reliable/test/{0}/preds_bilateral_BILATERAL',
|
120 |
+
'MammoDatasets/AIIMS_highres_reliable/test/{0}/preds_frcnn_16',
|
121 |
+
]
|
122 |
+
|
123 |
+
|
124 |
+
st = int(sys.argv[1])
|
125 |
+
end = len(dataset_paths) - int(sys.argv[2])
|
126 |
+
allowed = [i for i in range(st,end)]
|
127 |
+
allowed = [0,1,2,3,4,5]
|
128 |
+
|
129 |
+
OUT_FILE = 'contrast_frcnn.txt'
|
130 |
+
if OUT_FILE is not None:
|
131 |
+
fol = os.path.split(OUT_FILE)[0]
|
132 |
+
if fol != '':
|
133 |
+
os.makedirs(fol, exist_ok=True)
|
134 |
+
|
135 |
+
acr_cat = json.load(open('aiims_categories.json','r'))
|
136 |
+
print(allowed)
|
137 |
+
|
138 |
+
mp_dict = {
|
139 |
+
'preds_frcnn_AIIMS_C3': manipulate_preds,
|
140 |
+
'preds_frcnn_AIIMS_C4': manipulate_preds_4,
|
141 |
+
'AIIMS_T2': manipulate_preds_t2,
|
142 |
+
'AIIMS_T1': manipulate_preds_t1,
|
143 |
+
}
|
144 |
+
|
145 |
+
image_dict = get_image_dict(dataset_paths, allowed = allowed, USE_ACR = USE_ACR, acr_cat = acr_cat, mp_dict = mp_dict)
|
146 |
+
|
147 |
+
image_dict = apply_merge(image_dict, METHOD = 'nms') # or wbf
|
148 |
+
|
149 |
+
if OUT_FILE:
|
150 |
+
pickle.dump(image_dict, open(OUT_FILE.replace('.txt','.pkl'),'wb'))
|
151 |
+
senses, fps = calc_froc_from_dict(image_dict, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3,1.],save_to=OUT_FILE)
|
152 |
+
pretty_print_fps(senses, fps)
|
DenseMammogram/model_utils.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torchvision.transforms as T
|
3 |
+
import cv2
|
4 |
+
from tqdm import tqdm
|
5 |
+
import detection.transforms as transforms
|
6 |
+
from dataloaders import get_direction
|
7 |
+
|
8 |
+
def generate_predictions_bilateral(model,device,testpath_,cor_dict,dset='aiims',preds_folder='preds_new'):
|
9 |
+
transform = T.Compose([T.ToPILImage(),T.ToTensor()])
|
10 |
+
model.eval()
|
11 |
+
for label in ['mal','ben']:
|
12 |
+
testpath = os.path.join(testpath_,label)
|
13 |
+
# testpath = os.path.join(dataset_path,'Training', 'train',label)
|
14 |
+
testimg = os.path.join(testpath, 'images')
|
15 |
+
|
16 |
+
#preds_folder = 'preds_new'
|
17 |
+
os.makedirs(os.path.join(testpath, preds_folder),exist_ok=True)
|
18 |
+
|
19 |
+
if not os.path.exists(os.path.join(testpath,preds_folder)):
|
20 |
+
os.makedirs(os.path.join(testpath+preds_folder),exist_ok = True)
|
21 |
+
|
22 |
+
for file in tqdm(os.listdir(testimg)):
|
23 |
+
img1 = cv2.imread(os.path.join(testimg,file))
|
24 |
+
img1 = transform(img1)
|
25 |
+
# if False:
|
26 |
+
if(os.path.join(testimg,file) in cor_dict and os.path.isfile(cor_dict[os.path.join(testimg,file)])):
|
27 |
+
print('Using Bilateral')
|
28 |
+
img2 = cv2.imread(cor_dict[os.path.join(testimg,file)])
|
29 |
+
img2 = transform(img2)
|
30 |
+
if(get_direction(dset,file)==1):
|
31 |
+
img1,_ = transforms.RandomHorizontalFlip(1.0)(img1)
|
32 |
+
|
33 |
+
images = [img1.to(device),img2.to(device)]
|
34 |
+
output = model([images])[0]
|
35 |
+
img1,output = transforms.RandomHorizontalFlip(1.0)(img1,output)
|
36 |
+
else:
|
37 |
+
img2,_ = transforms.RandomHorizontalFlip(1.0)(img2)
|
38 |
+
|
39 |
+
images = [img1.to(device),img2.to(device)]
|
40 |
+
output = model([images])[0]
|
41 |
+
else:
|
42 |
+
print('Using FRCNN')
|
43 |
+
output = model.frcnn([img1.to(device)])[0]
|
44 |
+
#output = model.frcnn([img1.to(device)])[0]
|
45 |
+
boxes = output['boxes']
|
46 |
+
scores = output['scores']
|
47 |
+
labels = output['labels']
|
48 |
+
f = open(os.path.join(testpath,preds_folder,file[:-4]+'.txt'),'w')
|
49 |
+
for i in range(len(boxes)):
|
50 |
+
box = boxes[i].detach().cpu().numpy()
|
51 |
+
#f.write('{} {} {} {} {} {}\n'.format(scores[i].item(),labels[i].item(),box[0],box[1],box[2],box[3]))
|
52 |
+
f.write('{} {} {} {} {}\n'.format(scores[i].item(),box[0],box[1],box[2],box[3]))
|
53 |
+
|
54 |
+
|
55 |
+
def generate_predictions(model,device,testpath_,preds_folder='preds_frcnn'):
|
56 |
+
transform = T.Compose([T.ToPILImage(),T.ToTensor()])
|
57 |
+
model.eval()
|
58 |
+
for label in ['mal','ben']:
|
59 |
+
testpath = os.path.join(testpath_,label)
|
60 |
+
# testpath = os.path.join(dataset_path,'Training', 'train',label)
|
61 |
+
testimg = os.path.join(testpath, 'images')
|
62 |
+
|
63 |
+
#preds_folder = 'preds_new'
|
64 |
+
os.makedirs(os.path.join(testpath, preds_folder),exist_ok=True)
|
65 |
+
|
66 |
+
if not os.path.exists(os.path.join(testpath,preds_folder)):
|
67 |
+
os.makedirs(os.path.join(testpath+preds_folder),exist_ok = True)
|
68 |
+
|
69 |
+
for file in tqdm(os.listdir(testimg)):
|
70 |
+
im = cv2.imread(os.path.join(testimg,file))
|
71 |
+
if file == 'Mass-Training_P_00444_LEFT_CC.png':
|
72 |
+
print('Test this')
|
73 |
+
continue
|
74 |
+
im = transform(im)
|
75 |
+
|
76 |
+
output = model([im.to(device)])[0]
|
77 |
+
boxes = output['boxes'] #/ FAC
|
78 |
+
scores = output['scores']
|
79 |
+
labels = output['labels']
|
80 |
+
f = open(os.path.join(testpath,preds_folder,file[:-4]+'.txt'),'w')
|
81 |
+
for i in range(len(boxes)):
|
82 |
+
box = boxes[i].detach().cpu().numpy()
|
83 |
+
f.write('{} {} {} {} {}\n'.format(scores[i].item(),box[0],box[1],box[2],box[3]))
|
DenseMammogram/models.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, OrderedDict, Tuple
|
2 |
+
import warnings
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
import cv2
|
6 |
+
import os
|
7 |
+
from torch.nn.modules.conv import Conv2d
|
8 |
+
from torch.utils.data.dataset import ConcatDataset
|
9 |
+
from tqdm import tqdm
|
10 |
+
import argparse
|
11 |
+
from torch.utils.data import Dataset,DataLoader
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from torchvision import models
|
15 |
+
import detection.transforms as transforms
|
16 |
+
import torchvision.transforms as T
|
17 |
+
import detection.utils as utils
|
18 |
+
import torch.nn.functional as F
|
19 |
+
import shutil
|
20 |
+
import json
|
21 |
+
from detection.engine import train_one_epoch, evaluate
|
22 |
+
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
23 |
+
import torch.multiprocessing
|
24 |
+
import copy
|
25 |
+
from torchvision.ops import MultiScaleRoIAlign
|
26 |
+
from torchvision.models.detection.roi_heads import RoIHeads
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
# First we will create the FRCNN model
|
32 |
+
def get_FRCNN_model(num_classes=1):
|
33 |
+
model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True,trainable_backbone_layers=3,min_size=1800,max_size=3600,image_std=(1.0,1.0,1.0),box_score_thresh=0.001)
|
34 |
+
# get number of input features for the classifier
|
35 |
+
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
36 |
+
# replace the pre-trained head with a new one
|
37 |
+
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes+1)
|
38 |
+
return model
|
39 |
+
|
40 |
+
# Some utility heads for Bilateral Model
|
41 |
+
|
42 |
+
class RoIpool(nn.Module):
|
43 |
+
|
44 |
+
def __init__(self,pool):
|
45 |
+
super().__init__()
|
46 |
+
self.box_roi_pool1 = copy.deepcopy(pool)
|
47 |
+
self.box_roi_pool2 = copy.deepcopy(pool)
|
48 |
+
|
49 |
+
|
50 |
+
def forward(self,features,proposals,image_shapes):
|
51 |
+
x = self.box_roi_pool1(features[0],proposals,image_shapes)
|
52 |
+
y = self.box_roi_pool2(features[1],proposals,image_shapes)
|
53 |
+
z = torch.cat((x,y),dim=1)
|
54 |
+
return z
|
55 |
+
|
56 |
+
class TwoMLPHead(nn.Module):
|
57 |
+
"""
|
58 |
+
Standard heads for FPN-based models
|
59 |
+
Args:
|
60 |
+
in_channels (int): number of input channels
|
61 |
+
representation_size (int): size of the intermediate representation
|
62 |
+
"""
|
63 |
+
|
64 |
+
def __init__(self, in_channels=None, representation_size=None):
|
65 |
+
super().__init__()
|
66 |
+
|
67 |
+
self.fc6 = nn.Linear(in_channels, representation_size)
|
68 |
+
self.fc7 = nn.Linear(representation_size, representation_size)
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
x = x.flatten(start_dim=1)
|
72 |
+
|
73 |
+
x = F.relu(self.fc6(x))
|
74 |
+
x = F.relu(self.fc7(x))
|
75 |
+
return x
|
76 |
+
|
77 |
+
# Next the bilateral model
|
78 |
+
|
79 |
+
class Bilateral_model(nn.Module):
|
80 |
+
|
81 |
+
def __init__(self,frcnn_model):
|
82 |
+
super().__init__()
|
83 |
+
self.frcnn = frcnn_model
|
84 |
+
self.transform = copy.deepcopy(frcnn_model.transform)
|
85 |
+
self.backbone1 = copy.deepcopy(frcnn_model.backbone)
|
86 |
+
self.backbone2 = copy.deepcopy(frcnn_model.backbone)
|
87 |
+
self.rpn = copy.deepcopy(frcnn_model.rpn)
|
88 |
+
for param in self.rpn.parameters():
|
89 |
+
param.requires_grad = False
|
90 |
+
for param in self.backbone1.parameters():
|
91 |
+
param.requires_grad = False
|
92 |
+
for param in self.backbone2.parameters():
|
93 |
+
param.requires_grad = False
|
94 |
+
box_roi_pool = RoIpool(frcnn_model.roi_heads.box_roi_pool)
|
95 |
+
box_head = TwoMLPHead(512*7*7,1024)
|
96 |
+
box_predictor = copy.deepcopy(frcnn_model.roi_heads.box_predictor)
|
97 |
+
box_score_thresh=0.001
|
98 |
+
box_nms_thresh=0.5
|
99 |
+
box_detections_per_img=100
|
100 |
+
box_fg_iou_thresh=0.5
|
101 |
+
box_bg_iou_thresh=0.5
|
102 |
+
box_batch_size_per_image=512
|
103 |
+
box_positive_fraction=0.25
|
104 |
+
bbox_reg_weights=None
|
105 |
+
self.roi_heads = RoIHeads(
|
106 |
+
# Box
|
107 |
+
box_roi_pool,
|
108 |
+
box_head,
|
109 |
+
box_predictor,
|
110 |
+
box_fg_iou_thresh,
|
111 |
+
box_bg_iou_thresh,
|
112 |
+
box_batch_size_per_image,
|
113 |
+
box_positive_fraction,
|
114 |
+
bbox_reg_weights,
|
115 |
+
box_score_thresh,
|
116 |
+
box_nms_thresh,
|
117 |
+
box_detections_per_img,
|
118 |
+
)
|
119 |
+
|
120 |
+
@torch.jit.unused
|
121 |
+
def eager_outputs(self, losses, detections):
|
122 |
+
if self.training:
|
123 |
+
return losses
|
124 |
+
|
125 |
+
return detections
|
126 |
+
|
127 |
+
|
128 |
+
def forward(self, images, targets=None):
|
129 |
+
"""
|
130 |
+
Args:
|
131 |
+
images (list[Tensor(tuples)]): images to be processed
|
132 |
+
targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)
|
133 |
+
Returns:
|
134 |
+
result (list[BoxList] or dict[Tensor]): the output from the model.
|
135 |
+
During training, it returns a dict[Tensor] which contains the losses.
|
136 |
+
During testing, it returns list[BoxList] contains additional fields
|
137 |
+
like `scores`, `labels` and `mask` (for Mask R-CNN models).
|
138 |
+
"""
|
139 |
+
if self.training and targets is None:
|
140 |
+
raise ValueError("In training mode, targets should be passed")
|
141 |
+
if self.training:
|
142 |
+
assert targets is not None
|
143 |
+
for target in targets:
|
144 |
+
boxes = target["boxes"]
|
145 |
+
if isinstance(boxes, torch.Tensor):
|
146 |
+
if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
|
147 |
+
raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.")
|
148 |
+
else:
|
149 |
+
raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
|
150 |
+
|
151 |
+
original_image_sizes: List[Tuple[int, int]] = []
|
152 |
+
for img in images:
|
153 |
+
val = img[0].shape[-2:]
|
154 |
+
assert len(val) == 2
|
155 |
+
original_image_sizes.append((val[0], val[1]))
|
156 |
+
images1 = [img[0] for img in images]
|
157 |
+
images2 = [img[1] for img in images]
|
158 |
+
targets2 = copy.deepcopy(targets)
|
159 |
+
#print(images1.shape)
|
160 |
+
#print(images2.shape)
|
161 |
+
images1, targets = self.transform(images1, targets)
|
162 |
+
images2, targets2 = self.transform(images2, targets2)
|
163 |
+
|
164 |
+
# Check for degenerate boxes
|
165 |
+
# TODO: Move this to a function
|
166 |
+
if targets is not None:
|
167 |
+
for target_idx, target in enumerate(targets):
|
168 |
+
boxes = target["boxes"]
|
169 |
+
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
|
170 |
+
if degenerate_boxes.any():
|
171 |
+
# print the first degenerate box
|
172 |
+
bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
|
173 |
+
degen_bb: List[float] = boxes[bb_idx].tolist()
|
174 |
+
raise ValueError(
|
175 |
+
"All bounding boxes should have positive height and width."
|
176 |
+
f" Found invalid box {degen_bb} for target at index {target_idx}."
|
177 |
+
)
|
178 |
+
|
179 |
+
features1 = self.backbone1(images1.tensors)
|
180 |
+
features2 = self.backbone2(images2.tensors)
|
181 |
+
#print(self.backbone1.out_channels)
|
182 |
+
if isinstance(features1, torch.Tensor):
|
183 |
+
features1 = OrderedDict([("0", features1)])
|
184 |
+
if isinstance(features2, torch.Tensor):
|
185 |
+
features2 = OrderedDict([("0", features2)])
|
186 |
+
proposals, proposal_losses = self.rpn(images1, features1, targets)
|
187 |
+
features = {0:features1,1:features2}
|
188 |
+
detections, detector_losses = self.roi_heads(features, proposals, images1.image_sizes, targets)
|
189 |
+
detections = self.transform.postprocess(detections, images1.image_sizes, original_image_sizes) # type: ignore[operator]
|
190 |
+
|
191 |
+
losses = {}
|
192 |
+
losses.update(detector_losses)
|
193 |
+
losses.update(proposal_losses)
|
194 |
+
|
195 |
+
if torch.jit.is_scripting():
|
196 |
+
if not self._has_warned:
|
197 |
+
warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
|
198 |
+
self._has_warned = True
|
199 |
+
return losses, detections
|
200 |
+
else:
|
201 |
+
return self.eager_outputs(losses, detections)
|
DenseMammogram/plot_froc.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
####### PARAMETERS TO ADJUST #######
|
5 |
+
|
6 |
+
# Specify the files generated from merge_nms and plot corresponding graphs
|
7 |
+
base_fol = 'normal_test'
|
8 |
+
input_files = {
|
9 |
+
f'thresh_uni.txt' : 'Thresh + Uni',
|
10 |
+
f'thresh_nouni.txt' : 'Thresh + NoUni',
|
11 |
+
}
|
12 |
+
save_file = 'uni_vs_nouni.png'
|
13 |
+
# TITLE = 'Thresh + Contrast + Bilateral vs Contrast + Bilateral FROC Comparison (Normal Test)'
|
14 |
+
TITLE = 'Uni vs NoUni FROC Comparison (Normal Test)'
|
15 |
+
|
16 |
+
SHOW = False
|
17 |
+
CLIP_FPI = 1.2
|
18 |
+
MIN_CLIP_FPI = 0.0
|
19 |
+
####################################
|
20 |
+
|
21 |
+
def plot_froc(input_files, save_file, TITLE = 'FRCNN vs BILATERAL FROC', SHOW = False, CLIP_FPI = 1.2):
|
22 |
+
for file in input_files:
|
23 |
+
lines = open(file).readlines()
|
24 |
+
x = np.array([float(line.split()[0]) for line in lines])
|
25 |
+
y = np.array([float(line.split()[1]) for line in lines])
|
26 |
+
y = y[x<CLIP_FPI]
|
27 |
+
x = x[x<CLIP_FPI]
|
28 |
+
y = y[MIN_CLIP_FPI<x]
|
29 |
+
x = x[MIN_CLIP_FPI<x]
|
30 |
+
plt.plot(x, y, label = input_files[file])
|
31 |
+
plt.legend()
|
32 |
+
|
33 |
+
plt.title(TITLE)
|
34 |
+
plt.xlabel('Average False Positive Per Image')
|
35 |
+
plt.ylabel('Sensetivity')
|
36 |
+
|
37 |
+
if SHOW:
|
38 |
+
plt.show()
|
39 |
+
plt.savefig(save_file)
|
40 |
+
plt.clf()
|
41 |
+
|
42 |
+
if __name__ == '__main__':
|
43 |
+
plot_froc(input_files, save_file, TITLE, SHOW, CLIP_FPI)
|
DenseMammogram/requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.10.2
|
2 |
+
tqdm==4.62.3
|
3 |
+
torchvision==0.11.3
|
4 |
+
scipy==1.7.3
|
5 |
+
scikit-learn==1.0.2
|
6 |
+
PyYAML==6.0
|
7 |
+
Pillow==8.4.0
|
8 |
+
pandas==1.4.0
|
9 |
+
matplotlib==3.5.1
|
10 |
+
numpy
|
11 |
+
easydict==1.9
|
DenseMammogram/train_bilateral.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import math
|
3 |
+
from advanced_logger import LogPriority
|
4 |
+
from dataloaders import get_bilateral_dataloaders
|
5 |
+
from models import get_FRCNN_model, Bilateral_model
|
6 |
+
from detection.engine import evaluate_loss, train_one_epoch_simplified
|
7 |
+
|
8 |
+
def main(cfg, experimenter):
|
9 |
+
|
10 |
+
LR = cfg['LR']
|
11 |
+
WEIGHT_DECAY = cfg['WEIGHT_DECAY']
|
12 |
+
NUM_EPOCHS = cfg['NUM_EPOCHS']
|
13 |
+
BATCH_SIZE = cfg['BATCH_SIZE']
|
14 |
+
|
15 |
+
|
16 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
17 |
+
frcnn_model = get_FRCNN_model().to(device)
|
18 |
+
frcnn_model.load_state_dict(torch.load(cfg['FRCNN_MODEL_PATH']))
|
19 |
+
|
20 |
+
model = Bilateral_model(frcnn_model).to(device)
|
21 |
+
|
22 |
+
train_loader, val_loader = get_bilateral_dataloaders(experimenter.config, batch_size = BATCH_SIZE, data_dir = experimenter.config['DATA_DIR'])
|
23 |
+
|
24 |
+
if cfg["OPTIM"] == "SGD":
|
25 |
+
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad , model.roi_heads.parameters()),lr=LR,momentum=0.9,weight_decay=WEIGHT_DECAY)
|
26 |
+
elif cfg["OPTIM"] == "ADAM":
|
27 |
+
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = LR, weight_decay = WEIGHT_DECAY)
|
28 |
+
elif cfg["OPTIM"] == "ADAGRAD":
|
29 |
+
optimizer = torch.optim.Adagrad(filter(lambda p: p.requires_grad, model.roi_heads.parameters()), lr = LR, weight_decay = WEIGHT_DECAY)
|
30 |
+
for epoch in range(NUM_EPOCHS):
|
31 |
+
experimenter.start_epoch()
|
32 |
+
train_one_epoch_simplified(model, optimizer, train_loader, device, epoch, experimenter = experimenter,optimizer_backbone=None)
|
33 |
+
loss = evaluate_loss(model, device, val_loader, experimenter = experimenter)
|
34 |
+
experimenter.log('Validation Loss: {}'.format(loss), priority = LogPriority.MEDIUM)
|
35 |
+
|
36 |
+
experimenter.end_epoch(loss, model, device)
|
37 |
+
experimenter.save_model(model)
|
38 |
+
experimenter.generate_predictions(model, device)
|
39 |
+
|
40 |
+
|
41 |
+
if __name__ == '__main__':
|
42 |
+
from experimenter import Experimenter
|
43 |
+
import os
|
44 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '4'
|
45 |
+
cfg_file = 'configs/default.cfg'
|
46 |
+
experimenter = Experimenter(cfg_file)
|
47 |
+
main(experimenter.config['BILATERAL'], experimenter)
|
DenseMammogram/train_frcnn.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import math
|
3 |
+
from advanced_logger import LogPriority
|
4 |
+
from dataloaders import get_FRCNN_dataloaders
|
5 |
+
from models import get_FRCNN_model
|
6 |
+
from detection.engine import evaluate_loss, evaluate_simplified, train_one_epoch_simplified, evaluate_simplified
|
7 |
+
|
8 |
+
def main(cfg, experimenter):
|
9 |
+
|
10 |
+
LR = cfg['LR']
|
11 |
+
WEIGHT_DECAY = cfg['WEIGHT_DECAY']
|
12 |
+
NUM_EPOCHS = cfg['NUM_EPOCHS']
|
13 |
+
BATCH_SIZE = cfg['BATCH_SIZE']
|
14 |
+
|
15 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
+
model = get_FRCNN_model().to(device)
|
17 |
+
train_loader, val_loader = get_FRCNN_dataloaders(experimenter.config, batch_size=BATCH_SIZE, data_dir = experimenter.config['DATA_DIR'])
|
18 |
+
optimizer = torch.optim.SGD(model.parameters(),lr=LR,momentum=0.9,weight_decay=WEIGHT_DECAY)
|
19 |
+
|
20 |
+
for epoch in range(NUM_EPOCHS):
|
21 |
+
experimenter.start_epoch()
|
22 |
+
train_one_epoch_simplified(model, optimizer, train_loader, device, epoch, experimenter = experimenter)
|
23 |
+
evaluate_simplified(model, val_loader, device=device, experimenter = experimenter)
|
24 |
+
loss = evaluate_loss(model, device, val_loader, experimenter = experimenter)
|
25 |
+
experimenter.log('Validation Loss: {}'.format(loss), priority = LogPriority.MEDIUM)
|
26 |
+
experimenter.end_epoch(loss, model = model, device = device)
|
27 |
+
experimenter.save_model(model)
|
28 |
+
experimenter.generate_predictions(model, device)
|
29 |
+
|
30 |
+
if __name__ == '__main__':
|
31 |
+
from experimenter import Experimenter
|
32 |
+
cfg_file = 'configs/AIIMS_C1.cfg'
|
33 |
+
experimenter = Experimenter(cfg_file)
|
34 |
+
main(experimenter.config['FRCNN'], experimenter)
|
DenseMammogram/utils.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shutil
|
2 |
+
import os
|
3 |
+
from os.path import join
|
4 |
+
|
5 |
+
|
6 |
+
class AverageMeter:
|
7 |
+
"""Computes and stores the average and current value"""
|
8 |
+
def __init__(self):
|
9 |
+
self.reset()
|
10 |
+
|
11 |
+
def reset(self):
|
12 |
+
self.val = 0
|
13 |
+
self.avg = 0
|
14 |
+
self.sum = 0
|
15 |
+
self.count = 0
|
16 |
+
|
17 |
+
def update(self, val, n=1):
|
18 |
+
self.val = val
|
19 |
+
self.sum += val * n
|
20 |
+
self.count += n
|
21 |
+
self.avg = self.sum / self.count
|
22 |
+
|
23 |
+
def create_backup(folders = None, files = None, backup_dir = 'experiments'):
|
24 |
+
if folders is None:
|
25 |
+
folders = ['.', 'corr_lists','detection']
|
26 |
+
if files is None:
|
27 |
+
files = ['.py', '.txt', '.json','.cfg']
|
28 |
+
|
29 |
+
for folder in folders:
|
30 |
+
if not os.path.isdir(folder):
|
31 |
+
continue
|
32 |
+
for file in os.listdir(folder):
|
33 |
+
if file.endswith(tuple(files)):
|
34 |
+
if folder != '.':
|
35 |
+
src = join(folder, file)
|
36 |
+
dest = join(backup_dir, folder, file)
|
37 |
+
else:
|
38 |
+
src = file
|
39 |
+
dest = join(backup_dir, file)
|
40 |
+
os.makedirs(os.path.split(dest)[0], exist_ok=True)
|
41 |
+
shutil.copy(src, dest)
|
app.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from model import predict
|
3 |
+
import cv2
|
4 |
+
|
5 |
+
|
6 |
+
with gr.Blocks() as demo:
|
7 |
+
with gr.Column():
|
8 |
+
title = "<h1 style='margin-bottom: -10px; text-align: center'>Deep Learning for Detection of iso-dense, obscure masses in mammographically dense breasts</h1>"
|
9 |
+
# gr.HTML(title)
|
10 |
+
gr.Markdown(
|
11 |
+
"<h1 style='text-align: center; margin-bottom: 1rem'>"
|
12 |
+
+ title
|
13 |
+
+ "</h1>"
|
14 |
+
)
|
15 |
+
|
16 |
+
description = "<p style='font-size: 14px; margin: 5px; font-weight: w300; text-align: center'> <a href='' style='text-decoration:none' target='_blank'>Krithika Rangarajan<sup>*</sup>, </a> <a href='https://github.com/Pranjal2041' style='text-decoration:none' target='_blank'>Pranjal Aggarwal<sup>*</sup>, </a> <a href='' style='text-decoration:none' target='_blank'>Dhruv Kumar Gupta, </a> <a href='' style='text-decoration:none' target='_blank'>Rohan Dhanakshirur, </a> <a href='' style='text-decoration:none' target='_blank'>Akhil Baby, </a> <a href='' style='text-decoration:none' target='_blank'>Chandan Pal, </a> <a href='' style='text-decoration:none' target='_blank'>Arun Kumar Gupta, </a> <a href='' style='text-decoration:none' target='_blank'>Smriti Hari, </a> <a href='' style='text-decoration:none' target='_blank'>Subhashis Banerjee, </a> <a href='' style='text-decoration:none' target='_blank'>Chetan Arora, </a> </p>" \
|
17 |
+
+ "<p style='font-size: 16px; margin: 5px; font-weight: w600; text-align: center'> <a href='https://link.springer.com/article/10.1007/s00330-023-09717-7' target='_blank'>Publication</a> | <a href='https://github.com/Pranjal2041/DenseMammogram' target='_blank'>Website</a> | <a href='https://github.com/Pranjal2041/DenseMammogram' target='_blank'>Github Repo</a></p>" \
|
18 |
+
+ "<p style='text-align: center; margin: 5px; font-size: 14px; font-weight: w300;'> \
|
19 |
+
Deep learning suffers from some problems similar to human radiologists, such as poor sensitivity to detection of isodense, obscure masses or cancers in dense breasts. Traditional radiology teaching can be incorporated into the deep learning approach to tackle these problems in the network. Our method suggests collaborative network design, and incorporates core radiology principles resulting in SOTA results. You can use this demo to run inference by providing bilateral mammogram images. To get started, you can try one of the preset examples. \
|
20 |
+
</p>" \
|
21 |
+
+ "<p style='text-align: center; font-size: 14px; margin: 5px; font-weight: w300;'> [Note: Inference on CPU may take upto 2 minutes. On a GPU, inference time is approximately 1s.]</p>"
|
22 |
+
# gr.HTML(description)
|
23 |
+
gr.Markdown(description)
|
24 |
+
|
25 |
+
# head_html = gr.HTML('''
|
26 |
+
# <h1>
|
27 |
+
# Deep Learning for Detection of iso-dense, obscure masses in mammographically dense breasts
|
28 |
+
# </h1>
|
29 |
+
# <p style='text-align: center;'>
|
30 |
+
# Give bilateral mammograms(both left and right sides), and let our model find the cancers!
|
31 |
+
# </p>
|
32 |
+
|
33 |
+
# <p style='text-align: center;'>
|
34 |
+
# This is an official demo for our paper:
|
35 |
+
# `Deep Learning for Detection of iso-dense, obscure masses in mammographically dense breasts`.
|
36 |
+
# Check out the paper and code for details!
|
37 |
+
# </p>
|
38 |
+
# ''')
|
39 |
+
|
40 |
+
# gr.Markdown(
|
41 |
+
# """
|
42 |
+
# [![report](https://img.shields.io/badge/arxiv-report-red)](https://arxiv.org/abs/) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/)
|
43 |
+
# """)
|
44 |
+
|
45 |
+
def generate_preds(img1, img2):
|
46 |
+
print(img1, img2)
|
47 |
+
print(img1, img2)
|
48 |
+
img_out1 = predict(img1, img2)
|
49 |
+
if img_out1.shape[1] < img_out1.shape[2]:
|
50 |
+
ratio = img_out1.shape[2] / 800
|
51 |
+
else:
|
52 |
+
ratio = img_out1.shape[1] / 800
|
53 |
+
img_out1 = cv2.resize(img_out1, (0,0), fx=1 / ratio, fy=1 / ratio)
|
54 |
+
img_out2 = predict(img2, img1, baseIsLeft = False)
|
55 |
+
if img_out2.shape[1] < img_out2.shape[2]:
|
56 |
+
ratio = img_out2.shape[2] / 800
|
57 |
+
else:
|
58 |
+
ratio = img_out2.shape[1] / 800
|
59 |
+
img_out2 = cv2.resize(img_out2, (0,0), fx= 1 / ratio, fy= 1 / ratio)
|
60 |
+
|
61 |
+
cv2.imwrite('img_out1.jpg', img_out1)
|
62 |
+
cv2.imwrite('img_out2.jpg', img_out2)
|
63 |
+
|
64 |
+
|
65 |
+
return 'img_out1.jpg', 'img_out2.jpg'
|
66 |
+
|
67 |
+
with gr.Column():
|
68 |
+
with gr.Row(variant = 'panel'):
|
69 |
+
|
70 |
+
with gr.Column(variant = 'panel'):
|
71 |
+
img1 = gr.Image(type="filepath", label="Left Image" )
|
72 |
+
img2 = gr.Image(type="filepath", label="Right Image")
|
73 |
+
# with gr.Row():
|
74 |
+
# sub_btn = gr.Button("Predict!", variant="primary")
|
75 |
+
|
76 |
+
with gr.Column(variant = 'panel'):
|
77 |
+
# img_out1 = gr.inputs.Image(type="file", label="Output Left Image")
|
78 |
+
# img_out2 = gr.inputs.Image(type="file", label="Output for Right Image")
|
79 |
+
img_out1 = gr.Image(type="filepath", label="Output for Left Image", shape = None)
|
80 |
+
img_out1.style(height=250 * 2)
|
81 |
+
|
82 |
+
with gr.Column(variant = 'panel'):
|
83 |
+
img_out2 = gr.Image(type="filepath", label="Output for Right Image", shape = None)
|
84 |
+
img_out2.style(height=250 * 2)
|
85 |
+
|
86 |
+
with gr.Row():
|
87 |
+
sub_btn = gr.Button("Predict!", variant="primary")
|
88 |
+
|
89 |
+
gr.Examples([[f'sample_images/img{idx}_l.jpg', f'sample_images/img{idx}_r.jpg'] for idx in range(1,6)], inputs = [img1, img2])
|
90 |
+
|
91 |
+
sub_btn.click(fn = lambda x,y: generate_preds(x,y), inputs = [img1, img2], outputs = [img_out1, img_out2])
|
92 |
+
|
93 |
+
# sub_btn.click(fn = lambda x: gr.update(visible = True), inputs = [sub_btn], outputs = [img_out1, img_out2])
|
94 |
+
|
95 |
+
# gr.Examples(
|
96 |
+
|
97 |
+
# )
|
98 |
+
|
99 |
+
|
100 |
+
# interface.render()
|
101 |
+
# Object Detection Interface
|
102 |
+
|
103 |
+
# def generate_predictions(img1, img2):
|
104 |
+
# return img1
|
105 |
+
|
106 |
+
# interface = gr.Interface(
|
107 |
+
# fn=generate_predictions,
|
108 |
+
# inputs=[gr.inputs.Image(type="pil", label="Left Image"), gr.inputs.Image(type="pil", label="Right Image")],
|
109 |
+
# outputs=[gr.outputs.Image(type="pil", label="Output Image")],
|
110 |
+
# title="Object Detection",
|
111 |
+
# description="This model is trained on DenseMammogram dataset. It can detect objects in images. Try it out!",
|
112 |
+
# allow_flagging = False
|
113 |
+
# ).launch(share = True, show_api=False)
|
114 |
+
|
115 |
+
|
116 |
+
if __name__ == '__main__':
|
117 |
+
demo.launch(share = True, show_api=False)
|
img_out1.jpg
ADDED
Git LFS Details
|
img_out2.jpg
ADDED
Git LFS Details
|
model.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append('DenseMammogram')
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from models import get_FRCNN_model, Bilateral_model
|
7 |
+
|
8 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
9 |
+
frcnn_model = get_FRCNN_model().to(device)
|
10 |
+
bilat_model = Bilateral_model(frcnn_model).to(device)
|
11 |
+
|
12 |
+
FRCNN_PATH = 'pretrained_models/frcnn/frcnn_models/frcnn_model.pth'
|
13 |
+
BILAR_PATH = 'pretrained_models/BILATERAL/bilateral_models/bilateral_model.pth'
|
14 |
+
|
15 |
+
frcnn_model.load_state_dict(torch.load(FRCNN_PATH, map_location=device))
|
16 |
+
bilat_model.load_state_dict(torch.load(BILAR_PATH, map_location=device))
|
17 |
+
|
18 |
+
import os
|
19 |
+
import torchvision.transforms as T
|
20 |
+
import cv2
|
21 |
+
from tqdm import tqdm
|
22 |
+
import detection.transforms as transforms
|
23 |
+
from dataloaders import get_direction
|
24 |
+
|
25 |
+
def predict(left_file, right_file, threshold = 0.80, baseIsLeft = True):
|
26 |
+
model = bilat_model
|
27 |
+
with torch.no_grad():
|
28 |
+
transform = T.Compose([T.ToPILImage(),T.ToTensor()])
|
29 |
+
model.eval()
|
30 |
+
# First is left, then right
|
31 |
+
img1 = cv2.imread(left_file)
|
32 |
+
img1 = transform(img1)
|
33 |
+
img2 = cv2.imread(right_file)
|
34 |
+
img2 = transform(img2)
|
35 |
+
|
36 |
+
if baseIsLeft:
|
37 |
+
img1,_ = transforms.RandomHorizontalFlip(1.0)(img1)
|
38 |
+
else:
|
39 |
+
img2,_ = transforms.RandomHorizontalFlip(1.0)(img2)
|
40 |
+
|
41 |
+
|
42 |
+
images = [img1.to(device),img2.to(device)]
|
43 |
+
output = model([images])[0]
|
44 |
+
if baseIsLeft:
|
45 |
+
img1,output = transforms.RandomHorizontalFlip(1.0)(img1,output)
|
46 |
+
|
47 |
+
image = cv2.imread(left_file)
|
48 |
+
for b,s,l in zip(output['boxes'], output['scores'], output['labels']):
|
49 |
+
# Convert img1 tensor to numpy array
|
50 |
+
if l == 1 and s > threshold:
|
51 |
+
# Draw the bounding boxes
|
52 |
+
b = b.detach().cpu().numpy().astype(int)
|
53 |
+
# return image, b
|
54 |
+
cv2.rectangle(image, (b[0], b[1]), (b[2], b[3]), (0, 255, 0), 2)
|
55 |
+
# Print the % probability just above the box
|
56 |
+
cv2.putText(image, 'Cancer: '+str(round(round(s.item(), 2) * 100, 1)) + '%', (b[0], b[1] - 40), cv2.FONT_HERSHEY_SIMPLEX, 3.6, (36,255,12), 6)
|
57 |
+
return image
|
pretrained_models/AIIMS_C1/frcnn_models/frcnn_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e4253bd5cda58b57e1ed38cbaadd7fa7698cbc47bcd4c795f27cf0a63a7da669
|
3 |
+
size 165725683
|
pretrained_models/AIIMS_C2/frcnn_models/frcnn_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:07ca463a86317a4db3f3ed24358ddf292701ea2a0daf67b966ac325e7d0bebae
|
3 |
+
size 165725683
|
pretrained_models/AIIMS_C3/frcnn_models/frcnn_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:51ec560b1b56b9199480dee4eaaa10f45b4b96feab9397dd90f4eb05f21fd6d5
|
3 |
+
size 165725683
|
pretrained_models/AIIMS_C4/frcnn_models/frcnn_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d18b23c2a1e06a11a27ebd77e87dbb6b27d54e88d92fc55d58c64957b8cdfcfb
|
3 |
+
size 165725683
|
pretrained_models/AIIMS_T1/frcnn_models/frcnn_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8d8a1d133d3629e9c717070a66e1f2f2f846daca6765097622c2fe9f95c5a513
|
3 |
+
size 165725683
|
pretrained_models/AIIMS_T2/frcnn_models/frcnn_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5db00c682eec86bb2b4e764b64feffa26774643dae780bf3cf81313f5ca6f8de
|
3 |
+
size 165725683
|
pretrained_models/BILATERAL/bilateral_models/bilateral_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dce00a005fd102839f17c490b4a58191e92e99965b1ac7e323b71b0e75043d37
|
3 |
+
size 490558451
|
pretrained_models/frcnn/frcnn_models/frcnn_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e92090fd249484577db1c9e2560c82abddffd4c62203195bf8c35a32beeed4ad
|
3 |
+
size 165725683
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
torch==1.10.2
|
3 |
+
tqdm==4.62.3
|
4 |
+
torchvision==0.11.3
|
5 |
+
scipy==1.7.3
|
6 |
+
scikit-learn==1.0.2
|
7 |
+
PyYAML==6.0
|
8 |
+
Pillow==8.4.0
|
9 |
+
pandas==1.4.0
|
10 |
+
matplotlib==3.5.1
|
11 |
+
numpy
|
12 |
+
easydict==1.9
|