LINC-BIT commited on
Commit
b84549f
·
verified ·
1 Parent(s): 14c6c9d

Upload 1912 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +7 -0
  2. .gitignore +5 -0
  3. 1.png +0 -0
  4. 2.png +0 -0
  5. 3.png +0 -0
  6. 4.png +0 -0
  7. 5.png +0 -0
  8. Implementation.png +0 -0
  9. README.md +3 -8
  10. clip-baseline.png +0 -0
  11. clip-index.png +0 -0
  12. clip-online.png +0 -0
  13. cls_lora.png +0 -0
  14. cls_md_w_fbs_index.png +0 -0
  15. cls_md_wo_fbs.png +0 -0
  16. cls_online.png +0 -0
  17. data/README.md +94 -0
  18. data/__init__.py +14 -0
  19. data/__pycache__/__init__.cpython-38.pyc +0 -0
  20. data/__pycache__/dataloader.cpython-38.pyc +0 -0
  21. data/__pycache__/dataset.cpython-38.pyc +0 -0
  22. data/build/__init__.py +0 -0
  23. data/build/__pycache__/__init__.cpython-38.pyc +0 -0
  24. data/build/__pycache__/build.cpython-38.pyc +0 -0
  25. data/build/__pycache__/merge_alias.cpython-38.pyc +0 -0
  26. data/build/__pycache__/scenario.cpython-38.pyc +0 -0
  27. data/build/build.py +495 -0
  28. data/build/merge_alias.py +106 -0
  29. data/build/scenario.py +466 -0
  30. data/build_cl/__pycache__/build.cpython-38.pyc +0 -0
  31. data/build_cl/__pycache__/scenario.cpython-38.pyc +0 -0
  32. data/build_cl/build.py +161 -0
  33. data/build_cl/scenario.py +146 -0
  34. data/build_gen/__pycache__/build.cpython-38.pyc +0 -0
  35. data/build_gen/__pycache__/merge_alias.cpython-38.pyc +0 -0
  36. data/build_gen/__pycache__/scenario.cpython-38.pyc +0 -0
  37. data/build_gen/build.py +495 -0
  38. data/build_gen/merge_alias.py +106 -0
  39. data/build_gen/scenario.py +473 -0
  40. data/convert_all_load_to_single_load.py +56 -0
  41. data/convert_det_dataset_to_cls.py +55 -0
  42. data/convert_seg_dataset_to_cls.py +324 -0
  43. data/convert_seg_dataset_to_det.py +399 -0
  44. data/dataloader.py +131 -0
  45. data/dataset.py +43 -0
  46. data/datasets/__init__.py +12 -0
  47. data/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
  48. data/datasets/__pycache__/ab_dataset.cpython-38.pyc +0 -0
  49. data/datasets/__pycache__/data_aug.cpython-38.pyc +0 -0
  50. data/datasets/__pycache__/dataset_cache.cpython-38.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/datasets/visual_question_answering/generate_c_image/imagenet_c/frost/frost1.png filter=lfs diff=lfs merge=lfs -text
37
+ data/datasets/visual_question_answering/generate_c_image/robustness-master/assets/spatter.gif filter=lfs diff=lfs merge=lfs -text
38
+ data/datasets/visual_question_answering/generate_c_image/robustness-master/assets/tilt.gif filter=lfs diff=lfs merge=lfs -text
39
+ data/datasets/visual_question_answering/generate_c_image/robustness-master/assets/translate.gif filter=lfs diff=lfs merge=lfs -text
40
+ data/datasets/visual_question_answering/generate_c_image/robustness-master/ImageNet-C/create_c/frost1.png filter=lfs diff=lfs merge=lfs -text
41
+ data/datasets/visual_question_answering/generate_c_image/robustness-master/ImageNet-C/imagenet_c/imagenet_c/frost/frost1.png filter=lfs diff=lfs merge=lfs -text
42
+ new_results.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ results
2
+ logs
3
+ entry_model
4
+ __pycache__
5
+ backup_codes
1.png ADDED
2.png ADDED
3.png ADDED
4.png ADDED
5.png ADDED
Implementation.png ADDED
README.md CHANGED
@@ -1,10 +1,5 @@
1
  ---
2
- title: EdgeTA
3
- emoji: 📊
4
- colorFrom: green
5
- colorTo: red
6
  sdk: static
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: EdgeFM
 
 
 
3
  sdk: static
4
+ app_file: index.html
5
+ ---
 
 
clip-baseline.png ADDED
clip-index.png ADDED
clip-online.png ADDED
cls_lora.png ADDED
cls_md_w_fbs_index.png ADDED
cls_md_wo_fbs.png ADDED
cls_online.png ADDED
data/README.md ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## How to implement a dataset?
2
+
3
+ For example, we want to implement a image classification dataset.
4
+
5
+ 1. create a file in corresponding directory, i.e. `benchmark/data/datasets/image_classification`
6
+
7
+ 2. create a class (inherited from `benchmark.data.datasets.ab_dataset.ABDataset`), e.g. `class YourDataset(ABDataset)`
8
+
9
+ 3. register your dataset with `benchmark.data.datasets.registry.dataset_register(name, classes, classes_aliases)`, which represents the name of your dataset, the classes of your dataset, and the possible aliases of the classes. Examples refer to `benchmark/data/datasets/image_classification/cifar10.py` or other files.
10
+
11
+ Note that the order of `classes` must match the indexes. For example, `classes` of MNIST must be `['0', '1', '2', ..., '9']`, which means 0-th class is '0', 1-st class is '1', 2-nd class is '2', ...; `['1', '2', '0', ...]` is not correct because 0-th class is not '1' and 1-st class is not '2'.
12
+
13
+ How to get `classes` of a dataset? For PyTorch built-in dataset (CIFAR10, MNIST, ...) and general dataset build by `ImageFolder`, you can initialize it (e.g. `dataset = CIFAR10(...)`) and get its classes by `dataset.classes`.
14
+
15
+ ```python
16
+ # How to get classes in CIFAR10?
17
+ from torchvision.datasets import CIFAR10
18
+ dataset = CIFAR10(...)
19
+ print(dataset.classes)
20
+ # copy this output to @dataset_register(classes=<what you copied>)
21
+
22
+ # it's not recommended to dynamically get classes, e.g.:
23
+ # this works but runs slowly!
24
+ from torchvision.datasets import CIFAR10 as RawCIFAR10
25
+ dataset = RawCIFAR10(...)
26
+
27
+ @dataset_register(
28
+ name='CIFAR10',
29
+ classes=dataset.classes
30
+ )
31
+ class CIFAR10(ABDataset):
32
+ # ...
33
+ ```
34
+
35
+ For object detection dataset, you can read the annotation JSON file and find `categories` information in it.
36
+
37
+ 4. implement abstract function `create_dataset(self, root_dir: str, split: str, transform: Optional[Compose], classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]])`.
38
+
39
+ Arguments:
40
+
41
+ - `root_dir`: the location of data
42
+ - `split`: `train / val / test`
43
+ - `transform`: preprocess function in `torchvision.transforms`
44
+ - `classes`: the same value with `dataset_register.classes`
45
+ - `ignore_classes`: **classes should be discarded. You should remove images which belong to these ignore classes.**
46
+ - `idx_map`: **map the original class index to new class index. For example, `{0: 2}` means the index of 0-th class will be 2 instead of 0. You should implement this by modifying the stored labels in the original dataset. **
47
+
48
+ You should do five things in this function:
49
+
50
+ 1. if no user-defined transform is passed, you should implemented the default transform
51
+ 2. create the original dataset
52
+ 3. remove ignored classes in the original dataset if there are ignored classes
53
+ 4. map the original class index to new class index if there is index map
54
+ 5. split the original dataset to train / val / test dataset. If there's no val dataset in original dataset (e.g. DomainNetReal), you should split the original dataset to train / val / test dataset. If there's already val dataset in original dataset (e.g. CIFAR10 and ImageNet), regard the original val dataset as test dataset, and split the original train dataset into train / val dataset. Details just refer to existed files.
55
+
56
+ Example (`benchmark/data/datasets/image_classification/cifar10.py`):
57
+
58
+ ```python
59
+ @dataset_register(
60
+ name='CIFAR10',
61
+ # means in the original CIFAR10, 0-th class is airplane, 1-st class is automobile, ...
62
+ classes=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
63
+ # means 'automobile' and 'car' are the same thing actually
64
+ class_aliases=[['automobile', 'car']]
65
+ )
66
+ class CIFAR10(ABDataset):
67
+ def create_dataset(self, root_dir: str, split: str, transform: Optional[Compose],
68
+ classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
69
+ # 1. if no user-defined transform is passed, you should implemented the default transform
70
+ if transform is None:
71
+ transform = cifar_like_image_train_aug() if split == 'train' else cifar_like_image_test_aug()
72
+ # 2. create the original dataset
73
+ dataset = RawCIFAR10(root_dir, split != 'test', transform=transform, download=True)
74
+
75
+ # 3. remove ignored classes in the original dataset if there are ignored classes
76
+ dataset.targets = np.asarray(dataset.targets)
77
+ if len(ignore_classes) > 0:
78
+ for ignore_class in ignore_classes:
79
+ dataset.data = dataset.data[dataset.targets != classes.index(ignore_class)]
80
+ dataset.targets = dataset.targets[dataset.targets != classes.index(ignore_class)]
81
+
82
+ # 4. map the original class index to new class index if there is index map
83
+ if idx_map is not None:
84
+ for ti, t in enumerate(dataset.targets):
85
+ dataset.targets[ti] = idx_map[t]
86
+
87
+ # 5. split the original dataset to train / val / test dataset.
88
+ # there is not val dataset in CIFAR10 dataset, so we split the val dataset from the train dataset.
89
+ if split != 'test':
90
+ dataset = train_val_split(dataset, split)
91
+ return dataset
92
+ ```
93
+
94
+ After implementing a new dataset, you can create a test file in `example` and load the dataset by `benchmark.data.dataset.get_dataset()`. Try using this dataset to ensure it works. (Example: `example/1.py`)
data/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .dataset import get_dataset
2
+ from .build.build import build_scenario_manually_v2 as build_scenario
3
+ from .dataloader import build_dataloader
4
+ from .build.scenario import IndexReturnedDataset, MergedDataset
5
+ from .datasets.ab_dataset import ABDataset
6
+ from .build.scenario import Scenario
7
+
8
+ from .build_cl.build import build_cl_scenario
9
+ from .build_cl.scenario import Scenario as CLScenario
10
+
11
+ from .build_gen.build import build_scenario_manually_v2 as build_gen_scenario
12
+ from .build_gen.scenario import Scenario as GenScenario
13
+
14
+ from .datasets.dataset_split import split_dataset
data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (722 Bytes). View file
 
data/__pycache__/dataloader.cpython-38.pyc ADDED
Binary file (3.48 kB). View file
 
data/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (1.24 kB). View file
 
data/build/__init__.py ADDED
File without changes
data/build/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (130 Bytes). View file
 
data/build/__pycache__/build.cpython-38.pyc ADDED
Binary file (9.06 kB). View file
 
data/build/__pycache__/merge_alias.cpython-38.pyc ADDED
Binary file (2.5 kB). View file
 
data/build/__pycache__/scenario.cpython-38.pyc ADDED
Binary file (10.6 kB). View file
 
data/build/build.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Type, Union
2
+ from ..datasets.ab_dataset import ABDataset
3
+ # from benchmark.data.visualize import visualize_classes_in_object_detection
4
+ # from benchmark.scenario.val_domain_shift import get_val_domain_shift_transform
5
+ from ..dataset import get_dataset
6
+ import copy
7
+ from torchvision.transforms import Compose
8
+
9
+ from .merge_alias import merge_the_same_meaning_classes
10
+ from ..datasets.registery import static_dataset_registery
11
+
12
+
13
+ # some legacy aliases of variables:
14
+ # ignore_classes == discarded classes
15
+ # private_classes == unknown classes in partial / open-set / universal DA
16
+
17
+
18
+ def _merge_the_same_meaning_classes(classes_info_of_all_datasets):
19
+ final_classes_of_all_datasets, rename_map = merge_the_same_meaning_classes(classes_info_of_all_datasets)
20
+ return final_classes_of_all_datasets, rename_map
21
+
22
+
23
+ def _find_ignore_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode):
24
+ thres = {'da': 3, 'partial_da': 2, 'open_set_da': 1, 'universal_da': 0}[da_mode]
25
+
26
+ from functools import reduce
27
+ a_classes = reduce(lambda res, cur: res | set(cur), as_classes, set())
28
+
29
+ if set(a_classes) == set(b_classes):
30
+ # a is equal to b, normal
31
+ # 1. no ignore classes; 2. match class idx
32
+ a_ignore_classes, b_ignore_classes = [], []
33
+
34
+ elif set(a_classes) > set(b_classes):
35
+ # a contains b, partial
36
+ a_ignore_classes, b_ignore_classes = [], []
37
+ if thres == 3 or thres == 1: # ignore extra classes in a
38
+ a_ignore_classes = set(a_classes) - set(b_classes)
39
+
40
+ elif set(a_classes) < set(b_classes):
41
+ # a is contained by b, open set
42
+ a_ignore_classes, b_ignore_classes = [], []
43
+ if thres == 3 or thres == 2: # ignore extra classes in b
44
+ b_ignore_classes = set(b_classes) - set(a_classes)
45
+
46
+ elif len(set(a_classes) & set(b_classes)) > 0:
47
+ a_ignore_classes, b_ignore_classes = [], []
48
+ if thres == 3:
49
+ a_ignore_classes = set(a_classes) - (set(a_classes) & set(b_classes))
50
+ b_ignore_classes = set(b_classes) - (set(a_classes) & set(b_classes))
51
+ elif thres == 2:
52
+ b_ignore_classes = set(b_classes) - (set(a_classes) & set(b_classes))
53
+ elif thres == 1:
54
+ a_ignore_classes = set(a_classes) - (set(a_classes) & set(b_classes))
55
+
56
+ else:
57
+ return None # a has no intersection with b, none
58
+
59
+ as_ignore_classes = [list(set(a_classes) & set(a_ignore_classes)) for a_classes in as_classes]
60
+
61
+ return as_ignore_classes, list(b_ignore_classes)
62
+
63
+
64
+ def _find_private_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode):
65
+ thres = {'da': 3, 'partial_da': 2, 'open_set_da': 1, 'universal_da': 0}[da_mode]
66
+
67
+ from functools import reduce
68
+ a_classes = reduce(lambda res, cur: res | set(cur), as_classes, set())
69
+
70
+ if set(a_classes) == set(b_classes):
71
+ # a is equal to b, normal
72
+ # 1. no ignore classes; 2. match class idx
73
+ a_private_classes, b_private_classes = [], []
74
+
75
+ elif set(a_classes) > set(b_classes):
76
+ # a contains b, partial
77
+ a_private_classes, b_private_classes = [], []
78
+ # if thres == 2 or thres == 0: # ignore extra classes in a
79
+ # a_private_classes = set(a_classes) - set(b_classes)
80
+ # if thres == 0: # ignore extra classes in a
81
+ # a_private_classes = set(a_classes) - set(b_classes)
82
+
83
+ elif set(a_classes) < set(b_classes):
84
+ # a is contained by b, open set
85
+ a_private_classes, b_private_classes = [], []
86
+ if thres == 1 or thres == 0: # ignore extra classes in b
87
+ b_private_classes = set(b_classes) - set(a_classes)
88
+
89
+ elif len(set(a_classes) & set(b_classes)) > 0:
90
+ a_private_classes, b_private_classes = [], []
91
+ if thres == 0:
92
+ # a_private_classes = set(a_classes) - (set(a_classes) & set(b_classes))
93
+
94
+ b_private_classes = set(b_classes) - (set(a_classes) & set(b_classes))
95
+ elif thres == 1:
96
+ b_private_classes = set(b_classes) - (set(a_classes) & set(b_classes))
97
+ elif thres == 2:
98
+ # a_private_classes = set(a_classes) - (set(a_classes) & set(b_classes))
99
+ pass
100
+
101
+ else:
102
+ return None # a has no intersection with b, none
103
+
104
+ return list(b_private_classes)
105
+
106
+
107
+ class _ABDatasetMetaInfo:
108
+ def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type):
109
+ self.name = name
110
+ self.classes = classes
111
+ self.class_aliases = class_aliases
112
+ self.shift_type = shift_type
113
+ self.task_type = task_type
114
+ self.object_type = object_type
115
+
116
+
117
+ def _get_dist_shift_type_when_source_a_to_target_b(a: _ABDatasetMetaInfo, b: _ABDatasetMetaInfo):
118
+ if b.shift_type is None:
119
+ return 'Dataset Shifts'
120
+
121
+ if a.name in b.shift_type.keys():
122
+ return b.shift_type[a.name]
123
+
124
+ mid_dataset_name = list(b.shift_type.keys())[0]
125
+ mid_dataset_meta_info = _ABDatasetMetaInfo(mid_dataset_name, *static_dataset_registery[mid_dataset_name][1:])
126
+
127
+ return _get_dist_shift_type_when_source_a_to_target_b(a, mid_dataset_meta_info) + ' + ' + list(b.shift_type.values())[0]
128
+
129
+
130
+ def _handle_all_datasets_v2(source_datasets: List[_ABDatasetMetaInfo], target_datasets: List[_ABDatasetMetaInfo], da_mode):
131
+
132
+ # 1. merge the same meaning classes
133
+ classes_info_of_all_datasets = {
134
+ d.name: (d.classes, d.class_aliases)
135
+ for d in source_datasets + target_datasets
136
+ }
137
+ final_classes_of_all_datasets, rename_map = _merge_the_same_meaning_classes(classes_info_of_all_datasets)
138
+ all_datasets_classes = copy.deepcopy(final_classes_of_all_datasets)
139
+
140
+ # print(all_datasets_known_classes)
141
+
142
+ # 2. find ignored classes according to DA mode
143
+ # source_datasets_ignore_classes, target_datasets_ignore_classes = {d.name: [] for d in source_datasets}, \
144
+ # {d.name: [] for d in target_datasets}
145
+ # source_datasets_private_classes, target_datasets_private_classes = {d.name: [] for d in source_datasets}, \
146
+ # {d.name: [] for d in target_datasets}
147
+ target_source_relationship_map = {td.name: {} for td in target_datasets}
148
+ # source_target_relationship_map = {sd.name: [] for sd in source_datasets}
149
+
150
+ # 1. construct target_source_relationship_map
151
+ for sd in source_datasets:#sd和td使列表中每一个元素(类)的实例
152
+ for td in target_datasets:
153
+ sc = all_datasets_classes[sd.name]
154
+ tc = all_datasets_classes[td.name]
155
+
156
+ if len(set(sc) & set(tc)) == 0:#只保留有相似类别的源域和目标域
157
+ continue
158
+
159
+ target_source_relationship_map[td.name][sd.name] = _get_dist_shift_type_when_source_a_to_target_b(sd, td)
160
+
161
+ # print(target_source_relationship_map)
162
+ # exit()
163
+
164
+ source_datasets_ignore_classes = {}
165
+ for td_name, v1 in target_source_relationship_map.items():
166
+ for sd_name, v2 in v1.items():
167
+ source_datasets_ignore_classes[sd_name + '|' + td_name] = []
168
+ target_datasets_ignore_classes = {d.name: [] for d in target_datasets}
169
+ target_datasets_private_classes = {d.name: [] for d in target_datasets}
170
+ # 保证对于每个目标域上的DA都符合给定的label shift
171
+ # 所以不同目标域就算对应同一个源域,该源域也可能不相同
172
+
173
+ for td_name, v1 in target_source_relationship_map.items():
174
+ sd_names = list(v1.keys())
175
+
176
+ sds_classes = [all_datasets_classes[sd_name] for sd_name in sd_names]
177
+ td_classes = all_datasets_classes[td_name]
178
+ ss_ignore_classes, t_ignore_classes = _find_ignore_classes_when_sources_as_to_target_b(sds_classes, td_classes, da_mode)#根据DA方式不同产生ignore_classes
179
+ t_private_classes = _find_private_classes_when_sources_as_to_target_b(sds_classes, td_classes, da_mode)
180
+
181
+ for sd_name, s_ignore_classes in zip(sd_names, ss_ignore_classes):
182
+ source_datasets_ignore_classes[sd_name + '|' + td_name] = s_ignore_classes
183
+ target_datasets_ignore_classes[td_name] = t_ignore_classes
184
+ target_datasets_private_classes[td_name] = t_private_classes
185
+
186
+ source_datasets_ignore_classes = {k: sorted(set(v), key=v.index) for k, v in source_datasets_ignore_classes.items()}
187
+ target_datasets_ignore_classes = {k: sorted(set(v), key=v.index) for k, v in target_datasets_ignore_classes.items()}
188
+ target_datasets_private_classes = {k: sorted(set(v), key=v.index) for k, v in target_datasets_private_classes.items()}
189
+
190
+ # for k, v in source_datasets_ignore_classes.items():
191
+ # print(k, len(v))
192
+ # print()
193
+ # for k, v in target_datasets_ignore_classes.items():
194
+ # print(k, len(v))
195
+ # print()
196
+ # for k, v in target_datasets_private_classes.items():
197
+ # print(k, len(v))
198
+ # print()
199
+
200
+ # print(source_datasets_private_classes, target_datasets_private_classes)
201
+ # 3. reparse classes idx
202
+ # 3.1. agg all used classes
203
+ # all_used_classes = []
204
+ # all_datasets_private_class_idx_map = {}
205
+
206
+ # source_datasets_classes_idx_map = {}
207
+ # for td_name, v1 in target_source_relationship_map.items():
208
+ # for sd_name, v2 in v1.items():
209
+ # source_datasets_classes_idx_map[sd_name + '|' + td_name] = []
210
+ # target_datasets_classes_idx_map = {}
211
+
212
+ global_idx = 0
213
+ all_used_classes_idx_map = {}
214
+ # all_datasets_known_classes = {d: [] for d in final_classes_of_all_datasets.keys()}
215
+ for dataset_name, classes in all_datasets_classes.items():
216
+ if dataset_name not in target_datasets_ignore_classes.keys():
217
+ ignore_classes = [0] * 100000
218
+ for sn, sic in source_datasets_ignore_classes.items():
219
+ if sn.startswith(dataset_name):
220
+ if len(sic) < len(ignore_classes):
221
+ ignore_classes = sic
222
+ else:
223
+ ignore_classes = target_datasets_ignore_classes[dataset_name]
224
+ private_classes = [] \
225
+ if dataset_name not in target_datasets_ignore_classes.keys() else target_datasets_private_classes[dataset_name]
226
+
227
+ for c in classes:
228
+ if c not in ignore_classes and c not in all_used_classes_idx_map.keys() and c not in private_classes:
229
+ all_used_classes_idx_map[c] = global_idx
230
+ global_idx += 1
231
+
232
+ # print(all_used_classes_idx_map)
233
+
234
+ # dataset_private_class_idx_offset = 0
235
+ target_private_class_idx = global_idx
236
+ target_datasets_private_class_idx = {d: None for d in target_datasets_private_classes.keys()}
237
+
238
+ for dataset_name, classes in final_classes_of_all_datasets.items():
239
+ if dataset_name not in target_datasets_private_classes.keys():
240
+ continue
241
+
242
+ # ignore_classes = target_datasets_ignore_classes[dataset_name]
243
+ private_classes = target_datasets_private_classes[dataset_name]
244
+ # private_classes = [] \
245
+ # if dataset_name in source_datasets_private_classes.keys() else target_datasets_private_classes[dataset_name]
246
+ # for c in classes:
247
+ # if c not in ignore_classes and c not in all_used_classes_idx_map.keys() and c in private_classes:
248
+ # all_used_classes_idx_map[c] = global_idx + dataset_private_class_idx_offset
249
+
250
+ if len(private_classes) > 0:
251
+ # all_datasets_private_class_idx[dataset_name] = global_idx + dataset_private_class_idx_offset
252
+ # dataset_private_class_idx_offset += 1
253
+ # if dataset_name in source_datasets_private_classes.keys():
254
+ # if source_private_class_idx is None:
255
+ # source_private_class_idx = global_idx if target_private_class_idx is None else target_private_class_idx + 1
256
+ # all_datasets_private_class_idx[dataset_name] = source_private_class_idx
257
+ # else:
258
+ # if target_private_class_idx is None:
259
+ # target_private_class_idx = global_idx if source_private_class_idx is None else source_private_class_idx + 1
260
+ # all_datasets_private_class_idx[dataset_name] = target_private_class_idx
261
+ target_datasets_private_class_idx[dataset_name] = target_private_class_idx
262
+ target_private_class_idx += 1
263
+
264
+
265
+ # all_used_classes = sorted(set(all_used_classes), key=all_used_classes.index)
266
+ # all_used_classes_idx_map = {c: i for i, c in enumerate(all_used_classes)}
267
+
268
+ # print('rename_map', rename_map)
269
+
270
+ # 3.2 raw_class -> rename_map[raw_classes] -> all_used_classes_idx_map
271
+ all_datasets_e2e_idx_map = {}
272
+ all_datasets_e2e_class_to_idx_map = {}
273
+
274
+ for td_name, v1 in target_source_relationship_map.items():
275
+ sd_names = list(v1.keys())
276
+ sds_classes = [all_datasets_classes[sd_name] for sd_name in sd_names]
277
+ td_classes = all_datasets_classes[td_name]
278
+
279
+ for sd_name, sd_classes in zip(sd_names, sds_classes):
280
+ cur_e2e_idx_map = {}
281
+ cur_e2e_class_to_idx_map = {}
282
+
283
+ for raw_ci, raw_c in enumerate(sd_classes):
284
+ renamed_c = raw_c if raw_c not in rename_map[dataset_name] else rename_map[dataset_name][raw_c]
285
+
286
+ ignore_classes = source_datasets_ignore_classes[sd_name + '|' + td_name]
287
+ if renamed_c in ignore_classes:
288
+ continue
289
+
290
+ idx = all_used_classes_idx_map[renamed_c]
291
+
292
+ cur_e2e_idx_map[raw_ci] = idx
293
+ cur_e2e_class_to_idx_map[raw_c] = idx
294
+
295
+ all_datasets_e2e_idx_map[sd_name + '|' + td_name] = cur_e2e_idx_map
296
+ all_datasets_e2e_class_to_idx_map[sd_name + '|' + td_name] = cur_e2e_class_to_idx_map
297
+ cur_e2e_idx_map = {}
298
+ cur_e2e_class_to_idx_map = {}
299
+ for raw_ci, raw_c in enumerate(td_classes):
300
+ renamed_c = raw_c if raw_c not in rename_map[dataset_name] else rename_map[dataset_name][raw_c]
301
+
302
+ ignore_classes = target_datasets_ignore_classes[td_name]
303
+ if renamed_c in ignore_classes:
304
+ continue
305
+
306
+ if renamed_c in target_datasets_private_classes[td_name]:
307
+ idx = target_datasets_private_class_idx[td_name]
308
+ else:
309
+ idx = all_used_classes_idx_map[renamed_c]
310
+
311
+ cur_e2e_idx_map[raw_ci] = idx
312
+ cur_e2e_class_to_idx_map[raw_c] = idx
313
+
314
+ all_datasets_e2e_idx_map[td_name] = cur_e2e_idx_map
315
+ all_datasets_e2e_class_to_idx_map[td_name] = cur_e2e_class_to_idx_map
316
+
317
+ all_datasets_ignore_classes = {**source_datasets_ignore_classes, **target_datasets_ignore_classes}
318
+ # all_datasets_private_classes = {**source_datasets_private_classes, **target_datasets_private_classes}
319
+
320
+ classes_idx_set = []
321
+ for d, m in all_datasets_e2e_class_to_idx_map.items():
322
+ classes_idx_set += list(m.values())
323
+ classes_idx_set = set(classes_idx_set)
324
+ num_classes = len(classes_idx_set)
325
+
326
+ return all_datasets_ignore_classes, target_datasets_private_classes, \
327
+ all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
328
+ target_source_relationship_map, rename_map, num_classes
329
+
330
+
331
+ def _build_scenario_info_v2(
332
+ source_datasets_name: List[str],
333
+ target_datasets_order: List[str],
334
+ da_mode: str
335
+ ):
336
+ assert da_mode in ['close_set', 'partial', 'open_set', 'universal']
337
+ da_mode = {'close_set': 'da', 'partial': 'partial_da', 'open_set': 'open_set_da', 'universal': 'universal_da'}[da_mode]
338
+
339
+ source_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in source_datasets_name]#获知对应的名字和对应属性,要添加数据集时,直接register就行
340
+ target_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in list(set(target_datasets_order))]
341
+
342
+ all_datasets_ignore_classes, target_datasets_private_classes, \
343
+ all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
344
+ target_source_relationship_map, rename_map, num_classes \
345
+ = _handle_all_datasets_v2(source_datasets_meta_info, target_datasets_meta_info, da_mode)
346
+
347
+ return all_datasets_ignore_classes, target_datasets_private_classes, \
348
+ all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
349
+ target_source_relationship_map, rename_map, num_classes
350
+
351
+
352
+ def build_scenario_manually_v2(
353
+ source_datasets_name: List[str],
354
+ target_datasets_order: List[str],
355
+ da_mode: str,
356
+ data_dirs: Dict[str, str],
357
+ # transforms: Optional[Dict[str, Compose]] = None
358
+ ):
359
+ configs = copy.deepcopy(locals())#返回当前局部变量
360
+
361
+ source_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in source_datasets_name]
362
+ target_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in list(set(target_datasets_order))]
363
+
364
+ all_datasets_ignore_classes, target_datasets_private_classes, \
365
+ all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
366
+ target_source_relationship_map, rename_map, num_classes \
367
+ = _build_scenario_info_v2(source_datasets_name, target_datasets_order, da_mode)
368
+ # from rich.console import Console
369
+ # console = Console(width=10000)
370
+
371
+ # def print_obj(_o):
372
+ # # import pprint
373
+ # # s = pprint.pformat(_o, width=140, compact=True)
374
+ # console.print(_o)
375
+
376
+ # console.print('configs:', style='bold red')
377
+ # print_obj(configs)
378
+ # console.print('renamed classes:', style='bold red')
379
+ # print_obj(rename_map)
380
+ # console.print('discarded classes:', style='bold red')
381
+ # print_obj(all_datasets_ignore_classes)
382
+ # console.print('unknown classes:', style='bold red')
383
+ # print_obj(target_datasets_private_classes)
384
+ # console.print('class to index map:', style='bold red')
385
+ # print_obj(all_datasets_e2e_class_to_idx_map)
386
+ # console.print('index map:', style='bold red')
387
+ # print_obj(all_datasets_e2e_idx_map)
388
+ # console = Console()
389
+ # # console.print('class distribution:', style='bold red')
390
+ # # class_dist = {
391
+ # # k: {
392
+ # # '#known classes': len(all_datasets_known_classes[k]),
393
+ # # '#unknown classes': len(all_datasets_private_classes[k]),
394
+ # # '#discarded classes': len(all_datasets_ignore_classes[k])
395
+ # # } for k in all_datasets_ignore_classes.keys()
396
+ # # }
397
+ # # print_obj(class_dist)
398
+ # console.print('corresponding sources of each target:', style='bold red')
399
+ # print_obj(target_source_relationship_map)
400
+
401
+ # return
402
+
403
+ # res_source_datasets_map = {d: {split: get_dataset(d, data_dirs[d], split, getattr(transforms, d, None),
404
+ # all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
405
+ # for split in ['train', 'val', 'test']}
406
+ # for d in source_datasets_name}
407
+ # res_target_datasets_map = {d: {'train': get_num_limited_dataset(get_dataset(d, data_dirs[d], 'test', getattr(transforms, d, None),
408
+ # all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d]),
409
+ # num_samples_in_each_target_domain),
410
+ # 'test': get_dataset(d, data_dirs[d], 'test', getattr(transforms, d, None),
411
+ # all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
412
+ # }
413
+ # for d in list(set(target_datasets_order))}
414
+
415
+ # res_source_datasets_map = {d: {split: get_dataset(d.split('|')[0], data_dirs[d.split('|')[0]], split,
416
+ # getattr(transforms, d.split('|')[0], None),
417
+ # all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
418
+ # for split in ['train', 'val', 'test']}
419
+ # for d in all_datasets_ignore_classes.keys() if d.split('|')[0] in source_datasets_name}
420
+
421
+ # from functools import reduce
422
+ # res_offline_train_source_datasets_map = {}
423
+ # res_offline_train_source_datasets_map_names = {}
424
+
425
+ # for d in source_datasets_name:
426
+ # source_dataset_with_max_num_classes = None
427
+
428
+ # for ed_name, ed in res_source_datasets_map.items():
429
+ # if not ed_name.startswith(d):
430
+ # continue
431
+
432
+ # if source_dataset_with_max_num_classes is None:
433
+ # source_dataset_with_max_num_classes = ed
434
+ # res_offline_train_source_datasets_map_names[d] = ed_name
435
+
436
+ # if len(ed['train'].ignore_classes) < len(source_dataset_with_max_num_classes['train'].ignore_classes):
437
+ # source_dataset_with_max_num_classes = ed
438
+ # res_offline_train_source_datasets_map_names[d] = ed_name
439
+
440
+ # res_offline_train_source_datasets_map[d] = source_dataset_with_max_num_classes
441
+
442
+ # res_target_datasets_map = {d: {split: get_dataset(d, data_dirs[d], split, getattr(transforms, d, None),
443
+ # all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
444
+ # for split in ['train', 'val', 'test']}
445
+ # for d in list(set(target_datasets_order))}
446
+
447
+ from .scenario import Scenario, DatasetMetaInfo
448
+
449
+ # test_scenario = Scenario(
450
+ # config=configs,
451
+ # offline_source_datasets_meta_info={
452
+ # d: DatasetMetaInfo(d,
453
+ # {k: v for k, v in all_datasets_e2e_class_to_idx_map[res_offline_train_source_datasets_map_names[d]].items()},
454
+ # None)
455
+ # for d in source_datasets_name
456
+ # },
457
+ # offline_source_datasets={d: res_offline_train_source_datasets_map[d] for d in source_datasets_name},
458
+
459
+ # online_datasets_meta_info=[
460
+ # (
461
+ # {sd + '|' + d: DatasetMetaInfo(d,
462
+ # {k: v for k, v in all_datasets_e2e_class_to_idx_map[sd + '|' + d].items()},
463
+ # None)
464
+ # for sd in target_source_relationship_map[d].keys()},
465
+ # DatasetMetaInfo(d,
466
+ # {k: v for k, v in all_datasets_e2e_class_to_idx_map[d].items() if k not in target_datasets_private_classes[d]},
467
+ # target_datasets_private_class_idx[d])
468
+ # )
469
+ # for d in target_datasets_order
470
+ # ],
471
+ # online_datasets={**res_source_datasets_map, **res_target_datasets_map},
472
+ # target_domains_order=target_datasets_order,
473
+ # target_source_map=target_source_relationship_map,
474
+ # num_classes=num_classes
475
+ # )
476
+ import os
477
+ os.environ['_ZQL_NUMC'] = str(num_classes)
478
+
479
+ test_scenario = Scenario(config=configs, all_datasets_ignore_classes_map=all_datasets_ignore_classes,
480
+ all_datasets_idx_map=all_datasets_e2e_idx_map,
481
+ target_domains_order=target_datasets_order,
482
+ target_source_map=target_source_relationship_map,
483
+ all_datasets_e2e_class_to_idx_map=all_datasets_e2e_class_to_idx_map,
484
+ num_classes=num_classes)
485
+
486
+
487
+ return test_scenario
488
+
489
+
490
+ if __name__ == '__main__':
491
+ test_scenario = build_scenario_manually_v2(['CIFAR10', 'SVHN'],
492
+ ['STL10', 'MNIST', 'STL10', 'USPS', 'MNIST', 'STL10'],
493
+ 'close_set')
494
+ print(test_scenario.num_classes)
495
+
data/build/merge_alias.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from re import L
2
+ from typing import Dict, List
3
+ from collections import Counter
4
+
5
+
6
+ def grouping(bondlist):
7
+ # reference: https://blog.csdn.net/YnagShanwen/article/details/111344386
8
+ groups = []
9
+ break1 = False
10
+ while bondlist:
11
+ pair1 = bondlist.pop(0)
12
+ a = 11111
13
+ b = 10000
14
+ while b != a:
15
+ a = b
16
+ for atomid in pair1:
17
+ for i,pair2 in enumerate(bondlist):
18
+ if atomid in pair2:
19
+ pair1 = pair1 + pair2
20
+ bondlist.pop(i)
21
+ if not bondlist:
22
+ break1 = True
23
+ break
24
+ if break1:
25
+ break
26
+ b = len(pair1)
27
+ groups.append(pair1)
28
+ return groups
29
+
30
+
31
+ def build_semantic_class_info(classes: List[str], aliases: List[List[str]]):
32
+ res = []
33
+ for c in classes:
34
+ # print(res)
35
+ if len(aliases) == 0:
36
+ res += [[c]]
37
+ else:
38
+ find_alias = False
39
+ for alias in aliases:
40
+ if c in alias:
41
+ res += [alias]
42
+ find_alias = True
43
+ break
44
+ if not find_alias:
45
+ res += [[c]]
46
+ # print(classes, res)
47
+ return res
48
+
49
+
50
+ def merge_the_same_meaning_classes(classes_info_of_all_datasets):
51
+ # print(classes_info_of_all_datasets)
52
+
53
+ semantic_classes_of_all_datasets = []
54
+ all_aliases = []
55
+ for classes, aliases in classes_info_of_all_datasets.values():
56
+ all_aliases += aliases
57
+ for classes, aliases in classes_info_of_all_datasets.values():
58
+ semantic_classes_of_all_datasets += build_semantic_class_info(classes, all_aliases)
59
+
60
+ # print(semantic_classes_of_all_datasets)
61
+
62
+ grouped_classes_of_all_datasets = grouping(semantic_classes_of_all_datasets)#匹配过后的数据
63
+
64
+ # print(grouped_classes_of_all_datasets)
65
+
66
+ # final_grouped_classes_of_all_datasets = [Counter(c).most_common()[0][0] for c in grouped_classes_of_all_datasets]
67
+ # use most common class name; if the same common, use shortest class name!
68
+ final_grouped_classes_of_all_datasets = []
69
+ for c in grouped_classes_of_all_datasets:
70
+ counter = Counter(c).most_common()
71
+ max_times = counter[0][1]
72
+ candidate_class_names = []
73
+ for item, times in counter:
74
+ if times < max_times:
75
+ break
76
+ candidate_class_names += [item]
77
+ candidate_class_names.sort(key=lambda x: len(x))
78
+
79
+ final_grouped_classes_of_all_datasets += [candidate_class_names[0]]
80
+ res = {}
81
+ res_map = {d: {} for d in classes_info_of_all_datasets.keys()}
82
+
83
+ for dataset_name, (classes, _) in classes_info_of_all_datasets.items():
84
+ final_classes = []
85
+ for c in classes:
86
+ for grouped_names, final_name in zip(grouped_classes_of_all_datasets, final_grouped_classes_of_all_datasets):
87
+ if c in grouped_names:
88
+ final_classes += [final_name]
89
+ if final_name != c:
90
+ res_map[dataset_name][c] = final_name
91
+ break
92
+ res[dataset_name] = sorted(set(final_classes), key=final_classes.index)
93
+ return res, res_map
94
+
95
+
96
+ if __name__ == '__main__':
97
+ cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
98
+ cifar10_aliases = [['automobile', 'car']]
99
+ stl10_classes = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck']
100
+
101
+ final_classes_of_all_datasets, rename_map = merge_the_same_meaning_classes({
102
+ 'CIFAR10': (cifar10_classes, cifar10_aliases),
103
+ 'STL10': (stl10_classes, [])
104
+ })
105
+
106
+ print(final_classes_of_all_datasets, rename_map)
data/build/scenario.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ from functools import reduce
3
+ from typing import Dict, List, Tuple
4
+ import numpy as np
5
+ import copy
6
+ from utils.common.log import logger
7
+ from ..datasets.ab_dataset import ABDataset
8
+ from ..dataloader import FastDataLoader, InfiniteDataLoader, build_dataloader
9
+ from data import get_dataset
10
+
11
+
12
+ class DatasetMetaInfo:
13
+ def __init__(self, name,
14
+ known_classes_name_idx_map, unknown_class_idx):
15
+
16
+ assert unknown_class_idx not in known_classes_name_idx_map.keys()
17
+
18
+ self.name = name
19
+ self.unknown_class_idx = unknown_class_idx
20
+ self.known_classes_name_idx_map = known_classes_name_idx_map
21
+
22
+ @property
23
+ def num_classes(self):
24
+ return len(self.known_classes_idx) + 1
25
+
26
+
27
+ class MergedDataset:
28
+ def __init__(self, datasets: List[ABDataset]):
29
+ self.datasets = datasets
30
+ self.datasets_len = [len(i) for i in self.datasets]
31
+ logger.info(f'create MergedDataset: len of datasets {self.datasets_len}')
32
+ self.datasets_cum_len = np.cumsum(self.datasets_len)
33
+
34
+ def __getitem__(self, idx):
35
+ for i, cum_len in enumerate(self.datasets_cum_len):
36
+ if idx < cum_len:
37
+ return self.datasets[i][idx - sum(self.datasets_len[0: i])]
38
+
39
+ def __len__(self):
40
+ return sum(self.datasets_len)
41
+
42
+
43
+ class IndexReturnedDataset:
44
+ def __init__(self, dataset: ABDataset):
45
+ self.dataset = dataset
46
+
47
+ def __getitem__(self, idx):
48
+ res = self.dataset[idx]
49
+
50
+ if isinstance(res, (tuple, list)):
51
+ return (*res, idx)
52
+ else:
53
+ return res, idx
54
+
55
+ def __len__(self):
56
+ return len(self.dataset)
57
+
58
+
59
+ # class Scenario:
60
+ # def __init__(self, config,
61
+ # source_datasets_meta_info: Dict[str, DatasetMetaInfo], target_datasets_meta_info: Dict[str, DatasetMetaInfo],
62
+ # target_source_map: Dict[str, Dict[str, str]],
63
+ # target_domains_order: List[str],
64
+ # source_datasets: Dict[str, Dict[str, ABDataset]], target_datasets: Dict[str, Dict[str, ABDataset]]):
65
+
66
+ # self.__config = config
67
+ # self.__source_datasets_meta_info = source_datasets_meta_info
68
+ # self.__target_datasets_meta_info = target_datasets_meta_info
69
+ # self.__target_source_map = target_source_map
70
+ # self.__target_domains_order = target_domains_order
71
+ # self.__source_datasets = source_datasets
72
+ # self.__target_datasets = target_datasets
73
+
74
+ # # 1. basic
75
+ # def get_config(self):
76
+ # return copy.deepcopy(self.__config)
77
+
78
+ # def get_task_type(self):
79
+ # return list(self.__source_datasets.values())[0]['train'].task_type
80
+
81
+ # def get_num_classes(self):
82
+ # known_classes_idx = []
83
+ # unknown_classes_idx = []
84
+ # for v in self.__source_datasets_meta_info.values():
85
+ # known_classes_idx += list(v.known_classes_name_idx_map.values())
86
+ # unknown_classes_idx += [v.unknown_class_idx]
87
+ # for v in self.__target_datasets_meta_info.values():
88
+ # known_classes_idx += list(v.known_classes_name_idx_map.values())
89
+ # unknown_classes_idx += [v.unknown_class_idx]
90
+ # unknown_classes_idx = [i for i in unknown_classes_idx if i is not None]
91
+ # # print(known_classes_idx, unknown_classes_idx)
92
+ # res = len(set(known_classes_idx)), len(set(unknown_classes_idx)), len(set(known_classes_idx + unknown_classes_idx))
93
+ # # print(res)
94
+ # assert res[0] + res[1] == res[2]
95
+ # return res
96
+
97
+ # def build_dataloader(self, dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool):
98
+ # if infinite:
99
+ # dataloader = InfiniteDataLoader(
100
+ # dataset, None, batch_size, num_workers=num_workers)
101
+ # else:
102
+ # dataloader = FastDataLoader(
103
+ # dataset, batch_size, num_workers, shuffle=shuffle_when_finite)
104
+
105
+ # return dataloader
106
+
107
+ # def build_sub_dataset(self, dataset: ABDataset, indexes: List[int]):
108
+ # from ..data.datasets.dataset_split import _SplitDataset
109
+ # dataset.dataset = _SplitDataset(dataset.dataset, indexes)
110
+ # return dataset
111
+
112
+ # def build_index_returned_dataset(self, dataset: ABDataset):
113
+ # return IndexReturnedDataset(dataset)
114
+
115
+ # # 2. source
116
+ # def get_source_datasets_meta_info(self):
117
+ # return self.__source_datasets_meta_info
118
+
119
+ # def get_source_datasets_name(self):
120
+ # return list(self.__source_datasets.keys())
121
+
122
+ # def get_merged_source_dataset(self, split):
123
+ # source_train_datasets = {n: d[split] for n, d in self.__source_datasets.items()}
124
+ # return MergedDataset(list(source_train_datasets.values()))
125
+
126
+ # def get_source_datasets(self, split):
127
+ # source_train_datasets = {n: d[split] for n, d in self.__source_datasets.items()}
128
+ # return source_train_datasets
129
+
130
+ # # 3. target **domain**
131
+ # # (do we need such API `get_ith_target_domain()`?)
132
+ # def get_target_domains_meta_info(self):
133
+ # return self.__source_datasets_meta_info
134
+
135
+ # def get_target_domains_order(self):
136
+ # return self.__target_domains_order
137
+
138
+ # def get_corr_source_datasets_name_of_target_domain(self, target_domain_name):
139
+ # return self.__target_source_map[target_domain_name]
140
+
141
+ # def get_limited_target_train_dataset(self):
142
+ # if len(self.__target_domains_order) > 1:
143
+ # raise RuntimeError('this API is only for pass-in scenario in user-defined online DA algorithm')
144
+ # return list(self.__target_datasets.values())[0]['train']
145
+
146
+ # def get_target_domains_iterator(self, split):
147
+ # for target_domain_index, target_domain_name in enumerate(self.__target_domains_order):
148
+ # target_dataset = self.__target_datasets[target_domain_name]
149
+ # target_domain_meta_info = self.__target_datasets_meta_info[target_domain_name]
150
+
151
+ # yield target_domain_index, target_domain_name, target_dataset[split], target_domain_meta_info
152
+
153
+ # # 4. permission management
154
+ # def get_sub_scenario(self, source_datasets_name, source_splits, target_domains_order, target_splits):
155
+ # def get_split(dataset, splits):
156
+ # res = {}
157
+ # for s, d in dataset.items():
158
+ # if s in splits:
159
+ # res[s] = d
160
+ # return res
161
+
162
+ # return Scenario(
163
+ # config=self.__config,
164
+ # source_datasets_meta_info={k: v for k, v in self.__source_datasets_meta_info.items() if k in source_datasets_name},
165
+ # target_datasets_meta_info={k: v for k, v in self.__target_datasets_meta_info.items() if k in target_domains_order},
166
+ # target_source_map={k: v for k, v in self.__target_source_map.items() if k in target_domains_order},
167
+ # target_domains_order=target_domains_order,
168
+ # source_datasets={k: get_split(v, source_splits) for k, v in self.__source_datasets.items() if k in source_datasets_name},
169
+ # target_datasets={k: get_split(v, target_splits) for k, v in self.__target_datasets.items() if k in target_domains_order}
170
+ # )
171
+
172
+ # def get_only_source_sub_scenario_for_exp_tracker(self):
173
+ # return self.get_sub_scenario(self.get_source_datasets_name(), ['train', 'val', 'test'], [], [])
174
+
175
+ # def get_only_source_sub_scenario_for_alg(self):
176
+ # return self.get_sub_scenario(self.get_source_datasets_name(), ['train'], [], [])
177
+
178
+ # def get_one_da_sub_scenario_for_alg(self, target_domain_name):
179
+ # return self.get_sub_scenario(self.get_corr_source_datasets_name_of_target_domain(target_domain_name),
180
+ # ['train', 'val'], [target_domain_name], ['train'])
181
+
182
+
183
+ # class Scenario:
184
+ # def __init__(self, config,
185
+
186
+ # offline_source_datasets_meta_info: Dict[str, DatasetMetaInfo],
187
+ # offline_source_datasets: Dict[str, ABDataset],
188
+
189
+ # online_datasets_meta_info: List[Tuple[Dict[str, DatasetMetaInfo], DatasetMetaInfo]],
190
+ # online_datasets: Dict[str, ABDataset],
191
+ # target_domains_order: List[str],
192
+ # target_source_map: Dict[str, Dict[str, str]],
193
+
194
+ # num_classes: int):
195
+
196
+ # self.config = config
197
+
198
+ # self.offline_source_datasets_meta_info = offline_source_datasets_meta_info
199
+ # self.offline_source_datasets = offline_source_datasets
200
+
201
+ # self.online_datasets_meta_info = online_datasets_meta_info
202
+ # self.online_datasets = online_datasets
203
+
204
+ # self.target_domains_order = target_domains_order
205
+ # self.target_source_map = target_source_map
206
+
207
+ # self.num_classes = num_classes
208
+
209
+ # def get_offline_source_datasets(self, split):
210
+ # return {n: d[split] for n, d in self.offline_source_datasets.items()}
211
+
212
+ # def get_offline_source_merged_dataset(self, split):
213
+ # return MergedDataset([d[split] for d in self.offline_source_datasets.values()])
214
+
215
+ # def get_online_current_corresponding_source_datasets(self, domain_index, split):
216
+ # cur_target_domain_name = self.target_domains_order[domain_index]
217
+ # cur_source_datasets_name = list(self.target_source_map[cur_target_domain_name].keys())
218
+ # cur_source_datasets = {n: self.online_datasets[n + '|' + cur_target_domain_name][split] for n in cur_source_datasets_name}
219
+ # return cur_source_datasets
220
+
221
+ # def get_online_current_corresponding_merged_source_dataset(self, domain_index, split):
222
+ # cur_target_domain_name = self.target_domains_order[domain_index]
223
+ # cur_source_datasets_name = list(self.target_source_map[cur_target_domain_name].keys())
224
+ # cur_source_datasets = {n: self.online_datasets[n + '|' + cur_target_domain_name][split] for n in cur_source_datasets_name}
225
+ # return MergedDataset([d for d in cur_source_datasets.values()])
226
+
227
+ # def get_online_current_target_dataset(self, domain_index, split):
228
+ # cur_target_domain_name = self.target_domains_order[domain_index]
229
+ # return self.online_datasets[cur_target_domain_name][split]
230
+
231
+ # def build_dataloader(self, dataset: ABDataset, batch_size: int, num_workers: int,
232
+ # infinite: bool, shuffle_when_finite: bool, to_iterator: bool):
233
+ # if infinite:
234
+ # dataloader = InfiniteDataLoader(
235
+ # dataset, None, batch_size, num_workers=num_workers)
236
+ # else:
237
+ # dataloader = FastDataLoader(
238
+ # dataset, batch_size, num_workers, shuffle=shuffle_when_finite)
239
+
240
+ # if to_iterator:
241
+ # dataloader = iter(dataloader)
242
+
243
+ # return dataloader
244
+
245
+ # def build_sub_dataset(self, dataset: ABDataset, indexes: List[int]):
246
+ # from data.datasets.dataset_split import _SplitDataset
247
+ # dataset.dataset = _SplitDataset(dataset.dataset, indexes)
248
+ # return dataset
249
+
250
+ # def build_index_returned_dataset(self, dataset: ABDataset):
251
+ # return IndexReturnedDataset(dataset)
252
+
253
+ # def get_config(self):
254
+ # return copy.deepcopy(self.config)
255
+
256
+ # def get_task_type(self):
257
+ # return list(self.online_datasets.values())[0]['train'].task_type
258
+
259
+ # def get_num_classes(self):
260
+ # return self.num_classes
261
+
262
+
263
+ class Scenario:
264
+ def __init__(self, config, all_datasets_ignore_classes_map, all_datasets_idx_map, target_domains_order, target_source_map,
265
+ all_datasets_e2e_class_to_idx_map,
266
+ num_classes):
267
+ self.config = config
268
+ self.all_datasets_ignore_classes_map = all_datasets_ignore_classes_map
269
+ self.all_datasets_idx_map = all_datasets_idx_map
270
+ self.target_domains_order = target_domains_order
271
+ self.target_source_map = target_source_map
272
+ self.all_datasets_e2e_class_to_idx_map = all_datasets_e2e_class_to_idx_map
273
+ self.num_classes = num_classes
274
+ self.cur_domain_index = 0
275
+
276
+ logger.info(f'[scenario build] # classes: {num_classes}')
277
+ logger.debug(f'[scenario build] idx map: {all_datasets_idx_map}')
278
+
279
+ def to_json(self):
280
+ return dict(
281
+ config=self.config, all_datasets_ignore_classes_map=self.all_datasets_ignore_classes_map,
282
+ all_datasets_idx_map=self.all_datasets_idx_map, target_domains_order=self.target_domains_order,
283
+ target_source_map=self.target_source_map,
284
+ all_datasets_e2e_class_to_idx_map=self.all_datasets_e2e_class_to_idx_map,
285
+ num_classes=self.num_classes
286
+ )
287
+
288
+ def __str__(self):
289
+ return f'Scenario({self.to_json()})'
290
+
291
+ def get_offline_datasets(self, transform=None):
292
+ # make source datasets which contains all unioned classes
293
+ res_offline_train_source_datasets_map = {}
294
+
295
+ from .. import get_dataset
296
+ data_dirs = self.config['data_dirs']
297
+
298
+ source_datasets_name = self.config['source_datasets_name']
299
+ res_source_datasets_map = {d: {split: get_dataset(d.split('|')[0], data_dirs[d.split('|')[0]], split,
300
+ transform,
301
+ self.all_datasets_ignore_classes_map[d], self.all_datasets_idx_map[d])
302
+ for split in ['train', 'val', 'test']}
303
+ for d in self.all_datasets_ignore_classes_map.keys() if d.split('|')[0] in source_datasets_name}
304
+
305
+ for source_dataset_name in self.config['source_datasets_name']:
306
+ source_datasets = [v for k, v in res_source_datasets_map.items() if source_dataset_name in k]
307
+
308
+ # how to merge idx map?
309
+ # 35 79 97
310
+ idx_maps = [d['train'].idx_map for d in source_datasets]
311
+ ignore_classes_list = [d['train'].ignore_classes for d in source_datasets]
312
+
313
+ union_idx_map = {}
314
+ for idx_map in idx_maps:
315
+ for k, v in idx_map.items():
316
+ if k not in union_idx_map:
317
+ union_idx_map[k] = v
318
+ else:
319
+ assert union_idx_map[k] == v
320
+
321
+ union_ignore_classes = reduce(lambda res, cur: res & set(cur), ignore_classes_list, set(ignore_classes_list[0]))
322
+ assert len(union_ignore_classes) + len(union_idx_map) == len(source_datasets[0]['train'].raw_classes)
323
+
324
+ logger.info(f'[scenario build] {source_dataset_name} has {len(union_idx_map)} classes in offline training')
325
+
326
+ d = source_dataset_name
327
+ res_offline_train_source_datasets_map[d] = {split: get_dataset(d, data_dirs[d], split,
328
+ transform,
329
+ union_ignore_classes, union_idx_map)
330
+ for split in ['train', 'val', 'test']}
331
+
332
+ return res_offline_train_source_datasets_map
333
+
334
+ def get_offline_datasets_args(self):
335
+ # make source datasets which contains all unioned classes
336
+ res_offline_train_source_datasets_map = {}
337
+
338
+ from .. import get_dataset
339
+ data_dirs = self.config['data_dirs']
340
+
341
+ source_datasets_name = self.config['source_datasets_name']
342
+ res_source_datasets_map = {d: {split: get_dataset(d.split('|')[0], data_dirs[d.split('|')[0]], split,
343
+ None,
344
+ self.all_datasets_ignore_classes_map[d], self.all_datasets_idx_map[d])
345
+ for split in ['train', 'val', 'test']}
346
+ for d in self.all_datasets_ignore_classes_map.keys() if d.split('|')[0] in source_datasets_name}
347
+
348
+ for source_dataset_name in self.config['source_datasets_name']:
349
+ source_datasets = [v for k, v in res_source_datasets_map.items() if source_dataset_name in k]
350
+
351
+ # how to merge idx map?
352
+ # 35 79 97
353
+ idx_maps = [d['train'].idx_map for d in source_datasets]
354
+ ignore_classes_list = [d['train'].ignore_classes for d in source_datasets]
355
+
356
+ union_idx_map = {}
357
+ for idx_map in idx_maps:
358
+ for k, v in idx_map.items():
359
+ if k not in union_idx_map:
360
+ union_idx_map[k] = v
361
+ else:
362
+ assert union_idx_map[k] == v
363
+
364
+ union_ignore_classes = reduce(lambda res, cur: res & set(cur), ignore_classes_list, set(ignore_classes_list[0]))
365
+ assert len(union_ignore_classes) + len(union_idx_map) == len(source_datasets[0]['train'].raw_classes)
366
+
367
+ logger.info(f'[scenario build] {source_dataset_name} has {len(union_idx_map)} classes in offline training')
368
+
369
+ d = source_dataset_name
370
+ res_offline_train_source_datasets_map[d] = {split: dict(d, data_dirs[d], split,
371
+ None,
372
+ union_ignore_classes, union_idx_map)
373
+ for split in ['train', 'val', 'test']}
374
+
375
+ return res_offline_train_source_datasets_map
376
+
377
+ # for d in source_datasets_name:
378
+ # source_dataset_with_max_num_classes = None
379
+
380
+ # for ed_name, ed in res_source_datasets_map.items():
381
+ # if not ed_name.startswith(d):
382
+ # continue
383
+
384
+ # if source_dataset_with_max_num_classes is None:
385
+ # source_dataset_with_max_num_classes = ed
386
+ # res_offline_train_source_datasets_map_names[d] = ed_name
387
+
388
+ # if len(ed['train'].ignore_classes) < len(source_dataset_with_max_num_classes['train'].ignore_classes):
389
+ # source_dataset_with_max_num_classes = ed
390
+ # res_offline_train_source_datasets_map_names[d] = ed_name
391
+
392
+ # res_offline_train_source_datasets_map[d] = source_dataset_with_max_num_classes
393
+
394
+ # return res_offline_train_source_datasets_map
395
+
396
+ def get_online_ith_domain_datasets_args_for_inference(self, domain_index):
397
+ target_dataset_name = self.target_domains_order[domain_index]
398
+ # dataset_name: Any, root_dir: Any, split: Any, transform: Any | None = None, ignore_classes: Any = [], idx_map: Any | None = None
399
+
400
+ if 'MM-CityscapesDet' in self.target_domains_order or 'CityscapesDet' in self.target_domains_order or 'BaiduPersonDet' in self.target_domains_order:
401
+ logger.info(f'use val split for inference test (only Det workload)')
402
+ split = 'test'
403
+ else:
404
+ split = 'train'
405
+
406
+ return dict(dataset_name=target_dataset_name,
407
+ root_dir=self.config['data_dirs'][target_dataset_name],
408
+ split=split,
409
+ transform=None,
410
+ ignore_classes=self.all_datasets_ignore_classes_map[target_dataset_name],
411
+ idx_map=self.all_datasets_idx_map[target_dataset_name])
412
+
413
+ def get_online_ith_domain_datasets_args_for_training(self, domain_index):
414
+ target_dataset_name = self.target_domains_order[domain_index]
415
+ source_datasets_name = list(self.target_source_map[target_dataset_name].keys())
416
+
417
+ res = {}
418
+ # dataset_name: Any, root_dir: Any, split: Any, transform: Any | None = None, ignore_classes: Any = [], idx_map: Any | None = None
419
+ res[target_dataset_name] = {split: dict(dataset_name=target_dataset_name,
420
+ root_dir=self.config['data_dirs'][target_dataset_name],
421
+ split=split,
422
+ transform=None,
423
+ ignore_classes=self.all_datasets_ignore_classes_map[target_dataset_name],
424
+ idx_map=self.all_datasets_idx_map[target_dataset_name]) for split in ['train', 'val']}
425
+ for d in source_datasets_name:
426
+ res[d] = {split: dict(dataset_name=d,
427
+ root_dir=self.config['data_dirs'][d],
428
+ split=split,
429
+ transform=None,
430
+ ignore_classes=self.all_datasets_ignore_classes_map[d + '|' + target_dataset_name],
431
+ idx_map=self.all_datasets_idx_map[d + '|' + target_dataset_name]) for split in ['train', 'val']}
432
+
433
+ return res
434
+
435
+ def get_online_cur_domain_datasets_args_for_inference(self):
436
+ return self.get_online_ith_domain_datasets_args_for_inference(self.cur_domain_index)
437
+
438
+ def get_online_cur_domain_datasets_args_for_training(self):
439
+ return self.get_online_ith_domain_datasets_args_for_training(self.cur_domain_index)
440
+
441
+ def get_online_cur_domain_datasets_for_training(self, transform=None):
442
+ res = {}
443
+ datasets_args = self.get_online_ith_domain_datasets_args_for_training(self.cur_domain_index)
444
+ for dataset_name, dataset_args in datasets_args.items():
445
+ res[dataset_name] = {}
446
+ for split, args in dataset_args.items():
447
+ if transform is not None:
448
+ args['transform'] = transform
449
+ dataset = get_dataset(**args)
450
+ res[dataset_name][split] = dataset
451
+ return res
452
+
453
+ def get_online_cur_domain_datasets_for_inference(self, transform=None):
454
+ datasets_args = self.get_online_ith_domain_datasets_args_for_inference(self.cur_domain_index)
455
+ if transform is not None:
456
+ datasets_args['transform'] = transform
457
+ return get_dataset(**datasets_args)
458
+
459
+ def get_online_cur_domain_samples_for_training(self, num_samples, transform=None, collate_fn=None):
460
+ dataset = self.get_online_cur_domain_datasets_for_training(transform=transform)
461
+ dataset = dataset[self.target_domains_order[self.cur_domain_index]]['train']
462
+ return next(iter(build_dataloader(dataset, num_samples, 0, True, None, collate_fn=collate_fn)))[0]
463
+
464
+ def next_domain(self):
465
+ self.cur_domain_index += 1
466
+
data/build_cl/__pycache__/build.cpython-38.pyc ADDED
Binary file (4.31 kB). View file
 
data/build_cl/__pycache__/scenario.cpython-38.pyc ADDED
Binary file (5.42 kB). View file
 
data/build_cl/build.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Type, Union
2
+ from ..datasets.ab_dataset import ABDataset
3
+ # from benchmark.data.visualize import visualize_classes_in_object_detection
4
+ # from benchmark.scenario.val_domain_shift import get_val_domain_shift_transform
5
+ from ..dataset import get_dataset
6
+ import copy
7
+ from torchvision.transforms import Compose
8
+ from ..datasets.registery import static_dataset_registery
9
+ from ..build.scenario import Scenario as DAScenario
10
+ from copy import deepcopy
11
+ from utils.common.log import logger
12
+ import random
13
+ from .scenario import _ABDatasetMetaInfo, Scenario
14
+
15
+
16
+ def _check(source_datasets_meta_info: List[_ABDatasetMetaInfo], target_datasets_meta_info: List[_ABDatasetMetaInfo]):
17
+ # requirements for simplity
18
+ # 1. no same class in source datasets
19
+
20
+ source_datasets_class = [i.classes for i in source_datasets_meta_info]
21
+ for ci1, c1 in enumerate(source_datasets_class):
22
+ for ci2, c2 in enumerate(source_datasets_class):
23
+ if ci1 == ci2:
24
+ continue
25
+
26
+ c1_name = source_datasets_meta_info[ci1].name
27
+ c2_name = source_datasets_meta_info[ci2].name
28
+ intersection = set(c1).intersection(set(c2))
29
+ assert len(intersection) == 0, f'{c1_name} has intersection with {c2_name}: {intersection}'
30
+
31
+
32
+ def build_cl_scenario(
33
+ da_scenario: DAScenario,
34
+ target_datasets_name: List[str],
35
+ num_classes_per_task: int,
36
+ max_num_tasks: int,
37
+ data_dirs,
38
+ sanity_check=False
39
+ ):
40
+ config = deepcopy(locals())
41
+
42
+ source_datasets_idx_map = {}
43
+ source_class_idx_max = 0
44
+
45
+ for sd in da_scenario.config['source_datasets_name']:
46
+ da_scenario_idx_map = None
47
+ for k, v in da_scenario.all_datasets_idx_map.items():
48
+ if k.startswith(sd):
49
+ da_scenario_idx_map = v
50
+ break
51
+
52
+ source_datasets_idx_map[sd] = da_scenario_idx_map
53
+ source_class_idx_max = max(source_class_idx_max, max(list(da_scenario_idx_map.values())))
54
+
55
+
56
+ target_class_idx_start = source_class_idx_max + 1
57
+
58
+ target_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:], None, None) for d in target_datasets_name]
59
+
60
+ task_datasets_seq = []
61
+
62
+ num_tasks_per_dataset = {}
63
+
64
+ for td_info_i, td_info in enumerate(target_datasets_meta_info):
65
+
66
+ if td_info_i >= 1:
67
+ for _td_info_i, _td_info in enumerate(target_datasets_meta_info[0: td_info_i]):
68
+ if _td_info.name == td_info.name:
69
+ # print(111)
70
+ # class_idx_offset = sum([len(t.classes) for t in target_datasets_meta_info[0: td_info_i]])
71
+ print(len(task_datasets_seq))
72
+
73
+ task_index_offset = sum([v if __i < _td_info_i else 0 for __i, v in enumerate(num_tasks_per_dataset.values())])
74
+
75
+ task_datasets_seq += task_datasets_seq[task_index_offset: task_index_offset + num_tasks_per_dataset[_td_info_i]]
76
+ print(len(task_datasets_seq))
77
+ break
78
+ continue
79
+
80
+ td_classes = td_info.classes
81
+ num_tasks_per_dataset[td_info_i] = 0
82
+
83
+ for ci in range(0, len(td_classes), num_classes_per_task):
84
+ task_i = ci // num_classes_per_task
85
+ task_datasets_seq += [_ABDatasetMetaInfo(
86
+ f'{td_info.name}|task-{task_i}|ci-{ci}-{ci + num_classes_per_task - 1}',
87
+ td_classes[ci: ci + num_classes_per_task],
88
+ td_info.task_type,
89
+ td_info.object_type,
90
+ td_info.class_aliases,
91
+ td_info.shift_type,
92
+
93
+ td_classes[:ci] + td_classes[ci + num_classes_per_task: ],
94
+ {cii: cii + target_class_idx_start for cii in range(ci, ci + num_classes_per_task)}
95
+ )]
96
+ num_tasks_per_dataset[td_info_i] += 1
97
+
98
+ if ci + num_classes_per_task < len(td_classes) - 1:
99
+ task_datasets_seq += [_ABDatasetMetaInfo(
100
+ f'{td_info.name}-task-{task_i + 1}|ci-{ci}-{ci + num_classes_per_task - 1}',
101
+ td_classes[ci: len(td_classes)],
102
+ td_info.task_type,
103
+ td_info.object_type,
104
+ td_info.class_aliases,
105
+ td_info.shift_type,
106
+
107
+ td_classes[:ci],
108
+ {cii: cii + target_class_idx_start for cii in range(ci, len(td_classes))}
109
+ )]
110
+ num_tasks_per_dataset[td_info_i] += 1
111
+
112
+ target_class_idx_start += len(td_classes)
113
+
114
+ if len(task_datasets_seq) < max_num_tasks:
115
+ print(len(task_datasets_seq), max_num_tasks)
116
+ raise RuntimeError()
117
+
118
+ task_datasets_seq = task_datasets_seq[0: max_num_tasks]
119
+ target_class_idx_start = max([max(list(td.idx_map.values())) + 1 for td in task_datasets_seq])
120
+
121
+ scenario = Scenario(config, task_datasets_seq, target_class_idx_start, source_class_idx_max + 1, data_dirs)
122
+
123
+ if sanity_check:
124
+ selected_tasks_index = []
125
+ for task_index, _ in enumerate(scenario.target_tasks_order):
126
+ cur_datasets = scenario.get_cur_task_train_datasets()
127
+
128
+ if len(cur_datasets) < 300:
129
+ # empty_tasks_index += [task_index]
130
+ # while True:
131
+ # replaced_task_index = random.randint(0, task_index - 1) # ensure no random
132
+ replaced_task_index = task_index // 2
133
+ assert replaced_task_index != task_index
134
+ while replaced_task_index in selected_tasks_index:
135
+ replaced_task_index += 1
136
+
137
+ task_datasets_seq[task_index] = deepcopy(task_datasets_seq[replaced_task_index])
138
+ selected_tasks_index += [replaced_task_index]
139
+
140
+ logger.warning(f'replace {task_index}-th task with {replaced_task_index}-th task')
141
+
142
+ # print(task_index, [t.name for t in task_datasets_seq])
143
+
144
+ scenario.next_task()
145
+
146
+ # print([t.name for t in task_datasets_seq])
147
+
148
+ if len(selected_tasks_index) > 0:
149
+ target_class_idx_start = max([max(list(td.idx_map.values())) + 1 for td in task_datasets_seq])
150
+ scenario = Scenario(config, task_datasets_seq, target_class_idx_start, source_class_idx_max + 1, data_dirs)
151
+
152
+ for task_index, _ in enumerate(scenario.target_tasks_order):
153
+ cur_datasets = scenario.get_cur_task_train_datasets()
154
+ logger.info(f'task {task_index}, len {len(cur_datasets)}')
155
+ assert len(cur_datasets) > 0
156
+
157
+ scenario.next_task()
158
+
159
+ scenario = Scenario(config, task_datasets_seq, target_class_idx_start, source_class_idx_max + 1, data_dirs)
160
+
161
+ return scenario
data/build_cl/scenario.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ from functools import reduce
3
+ from typing import Dict, List, Tuple
4
+ import numpy as np
5
+ import copy
6
+ from utils.common.log import logger
7
+ from ..datasets.ab_dataset import ABDataset
8
+ from ..dataloader import FastDataLoader, InfiniteDataLoader, build_dataloader
9
+ from data import get_dataset, MergedDataset, Scenario as DAScenario
10
+
11
+
12
+ class _ABDatasetMetaInfo:
13
+ def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type, ignore_classes, idx_map):
14
+ self.name = name
15
+ self.classes = classes
16
+ self.class_aliases = class_aliases
17
+ self.shift_type = shift_type
18
+ self.task_type = task_type
19
+ self.object_type = object_type
20
+
21
+ self.ignore_classes = ignore_classes
22
+ self.idx_map = idx_map
23
+
24
+ def __repr__(self) -> str:
25
+ return f'({self.name}, {self.classes}, {self.idx_map})'
26
+
27
+
28
+ class Scenario:
29
+ def __init__(self, config, target_datasets_info: List[_ABDatasetMetaInfo], num_classes: int, num_source_classes: int, data_dirs):
30
+ self.config = config
31
+ self.target_datasets_info = target_datasets_info
32
+ self.num_classes = num_classes
33
+ self.cur_task_index = 0
34
+ self.num_source_classes = num_source_classes
35
+ self.cur_class_offset = num_source_classes
36
+ self.data_dirs = data_dirs
37
+
38
+ self.target_tasks_order = [i.name for i in self.target_datasets_info]
39
+ self.num_tasks_to_be_learn = sum([len(i.classes) for i in target_datasets_info])
40
+
41
+ logger.info(f'[scenario build] # classes: {num_classes}, # tasks to be learnt: {len(target_datasets_info)}, '
42
+ f'# classes per task: {config["num_classes_per_task"]}')
43
+
44
+ def to_json(self):
45
+ config = copy.deepcopy(self.config)
46
+ config['da_scenario'] = config['da_scenario'].to_json()
47
+ target_datasets_info = [str(i) for i in self.target_datasets_info]
48
+ return dict(
49
+ config=config, target_datasets_info=target_datasets_info,
50
+ num_classes=self.num_classes
51
+ )
52
+
53
+ def __str__(self):
54
+ return f'Scenario({self.to_json()})'
55
+
56
+ def get_cur_class_offset(self):
57
+ return self.cur_class_offset
58
+
59
+ def get_cur_num_class(self):
60
+ return len(self.target_datasets_info[self.cur_task_index].classes)
61
+
62
+ def get_nc_per_task(self):
63
+ return len(self.target_datasets_info[0].classes)
64
+
65
+ def next_task(self):
66
+ self.cur_class_offset += len(self.target_datasets_info[self.cur_task_index].classes)
67
+ self.cur_task_index += 1
68
+
69
+ print(f'now, cur task: {self.cur_task_index}, cur_class_offset: {self.cur_class_offset}')
70
+
71
+ def get_cur_task_datasets(self):
72
+ dataset_info = self.target_datasets_info[self.cur_task_index]
73
+ dataset_name = dataset_info.name.split('|')[0]
74
+ # print()
75
+
76
+ # source_datasets_info = []
77
+
78
+ res ={ **{split: get_dataset(dataset_name=dataset_name,
79
+ root_dir=self.data_dirs[dataset_name],
80
+ split=split,
81
+ transform=None,
82
+ ignore_classes=dataset_info.ignore_classes,
83
+ idx_map=dataset_info.idx_map) for split in ['train']},
84
+
85
+ **{split: MergedDataset([get_dataset(dataset_name=dataset_name,
86
+ root_dir=self.data_dirs[dataset_name],
87
+ split=split,
88
+ transform=None,
89
+ ignore_classes=di.ignore_classes,
90
+ idx_map=di.idx_map) for di in self.target_datasets_info[0: self.cur_task_index + 1]])
91
+ for split in ['val', 'test']}
92
+ }
93
+
94
+ # if len(res['train']) < 200 or len(res['val']) < 200 or len(res['test']) < 200:
95
+ # return None
96
+
97
+
98
+ if len(res['train']) < 1000:
99
+ res['train'] = MergedDataset([res['train']] * 5)
100
+ logger.info('aug train dataset')
101
+ if len(res['val']) < 1000:
102
+ res['val'] = MergedDataset(res['val'].datasets * 5)
103
+ logger.info('aug val dataset')
104
+ if len(res['test']) < 1000:
105
+ res['test'] = MergedDataset(res['test'].datasets * 5)
106
+ logger.info('aug test dataset')
107
+ # da_scenario: DAScenario = self.config['da_scenario']
108
+ # offline_datasets = da_scenario.get_offline_datasets()
109
+
110
+ for k, v in res.items():
111
+ logger.info(f'{k} dataset: {len(v)}')
112
+
113
+ # new_val_datasets = [
114
+ # *[d['val'] for d in offline_datasets.values()],
115
+ # res['val']
116
+ # ]
117
+ # res['val'] = MergedDataset(new_val_datasets)
118
+
119
+ # new_test_datasets = [
120
+ # *[d['test'] for d in offline_datasets.values()],
121
+ # res['test']
122
+ # ]
123
+ # res['test'] = MergedDataset(new_test_datasets)
124
+
125
+ return res
126
+
127
+ def get_cur_task_train_datasets(self):
128
+ dataset_info = self.target_datasets_info[self.cur_task_index]
129
+ dataset_name = dataset_info.name.split('|')[0]
130
+ # print()
131
+
132
+ # source_datasets_info = []
133
+
134
+ res = get_dataset(dataset_name=dataset_name,
135
+ root_dir=self.data_dirs[dataset_name],
136
+ split='train',
137
+ transform=None,
138
+ ignore_classes=dataset_info.ignore_classes,
139
+ idx_map=dataset_info.idx_map)
140
+
141
+ return res
142
+
143
+ def get_online_cur_task_samples_for_training(self, num_samples):
144
+ dataset = self.get_cur_task_datasets()
145
+ dataset = dataset['train']
146
+ return next(iter(build_dataloader(dataset, num_samples, 0, True, None)))[0]
data/build_gen/__pycache__/build.cpython-38.pyc ADDED
Binary file (9.07 kB). View file
 
data/build_gen/__pycache__/merge_alias.cpython-38.pyc ADDED
Binary file (2.5 kB). View file
 
data/build_gen/__pycache__/scenario.cpython-38.pyc ADDED
Binary file (9.65 kB). View file
 
data/build_gen/build.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Type, Union
2
+ from ..datasets.ab_dataset import ABDataset
3
+ # from benchmark.data.visualize import visualize_classes_in_object_detection
4
+ # from benchmark.scenario.val_domain_shift import get_val_domain_shift_transform
5
+ from ..dataset import get_dataset
6
+ import copy
7
+ from torchvision.transforms import Compose
8
+
9
+ from .merge_alias import merge_the_same_meaning_classes
10
+ from ..datasets.registery import static_dataset_registery
11
+
12
+
13
+ # some legacy aliases of variables:
14
+ # ignore_classes == discarded classes
15
+ # private_classes == unknown classes in partial / open-set / universal DA
16
+
17
+
18
+ def _merge_the_same_meaning_classes(classes_info_of_all_datasets):
19
+ final_classes_of_all_datasets, rename_map = merge_the_same_meaning_classes(classes_info_of_all_datasets)
20
+ return final_classes_of_all_datasets, rename_map
21
+
22
+
23
+ def _find_ignore_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode):
24
+ thres = {'da': 3, 'partial_da': 2, 'open_set_da': 1, 'universal_da': 0}[da_mode]
25
+
26
+ from functools import reduce
27
+ a_classes = reduce(lambda res, cur: res | set(cur), as_classes, set())
28
+
29
+ if set(a_classes) == set(b_classes):
30
+ # a is equal to b, normal
31
+ # 1. no ignore classes; 2. match class idx
32
+ a_ignore_classes, b_ignore_classes = [], []
33
+
34
+ elif set(a_classes) > set(b_classes):
35
+ # a contains b, partial
36
+ a_ignore_classes, b_ignore_classes = [], []
37
+ if thres == 3 or thres == 1: # ignore extra classes in a
38
+ a_ignore_classes = set(a_classes) - set(b_classes)
39
+
40
+ elif set(a_classes) < set(b_classes):
41
+ # a is contained by b, open set
42
+ a_ignore_classes, b_ignore_classes = [], []
43
+ if thres == 3 or thres == 2: # ignore extra classes in b
44
+ b_ignore_classes = set(b_classes) - set(a_classes)
45
+
46
+ elif len(set(a_classes) & set(b_classes)) > 0:
47
+ a_ignore_classes, b_ignore_classes = [], []
48
+ if thres == 3:
49
+ a_ignore_classes = set(a_classes) - (set(a_classes) & set(b_classes))
50
+ b_ignore_classes = set(b_classes) - (set(a_classes) & set(b_classes))
51
+ elif thres == 2:
52
+ b_ignore_classes = set(b_classes) - (set(a_classes) & set(b_classes))
53
+ elif thres == 1:
54
+ a_ignore_classes = set(a_classes) - (set(a_classes) & set(b_classes))
55
+
56
+ else:
57
+ return None # a has no intersection with b, none
58
+
59
+ as_ignore_classes = [list(set(a_classes) & set(a_ignore_classes)) for a_classes in as_classes]
60
+
61
+ return as_ignore_classes, list(b_ignore_classes)
62
+
63
+
64
+ def _find_private_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode):
65
+ thres = {'da': 3, 'partial_da': 2, 'open_set_da': 1, 'universal_da': 0}[da_mode]
66
+
67
+ from functools import reduce
68
+ a_classes = reduce(lambda res, cur: res | set(cur), as_classes, set())
69
+
70
+ if set(a_classes) == set(b_classes):
71
+ # a is equal to b, normal
72
+ # 1. no ignore classes; 2. match class idx
73
+ a_private_classes, b_private_classes = [], []
74
+
75
+ elif set(a_classes) > set(b_classes):
76
+ # a contains b, partial
77
+ a_private_classes, b_private_classes = [], []
78
+ # if thres == 2 or thres == 0: # ignore extra classes in a
79
+ # a_private_classes = set(a_classes) - set(b_classes)
80
+ # if thres == 0: # ignore extra classes in a
81
+ # a_private_classes = set(a_classes) - set(b_classes)
82
+
83
+ elif set(a_classes) < set(b_classes):
84
+ # a is contained by b, open set
85
+ a_private_classes, b_private_classes = [], []
86
+ if thres == 1 or thres == 0: # ignore extra classes in b
87
+ b_private_classes = set(b_classes) - set(a_classes)
88
+
89
+ elif len(set(a_classes) & set(b_classes)) > 0:
90
+ a_private_classes, b_private_classes = [], []
91
+ if thres == 0:
92
+ # a_private_classes = set(a_classes) - (set(a_classes) & set(b_classes))
93
+
94
+ b_private_classes = set(b_classes) - (set(a_classes) & set(b_classes))
95
+ elif thres == 1:
96
+ b_private_classes = set(b_classes) - (set(a_classes) & set(b_classes))
97
+ elif thres == 2:
98
+ # a_private_classes = set(a_classes) - (set(a_classes) & set(b_classes))
99
+ pass
100
+
101
+ else:
102
+ return None # a has no intersection with b, none
103
+
104
+ return list(b_private_classes)
105
+
106
+
107
+ class _ABDatasetMetaInfo:
108
+ def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type):
109
+ self.name = name
110
+ self.classes = classes
111
+ self.class_aliases = class_aliases
112
+ self.shift_type = shift_type
113
+ self.task_type = task_type
114
+ self.object_type = object_type
115
+
116
+
117
+ def _get_dist_shift_type_when_source_a_to_target_b(a: _ABDatasetMetaInfo, b: _ABDatasetMetaInfo):
118
+ if b.shift_type is None:
119
+ return 'Dataset Shifts'
120
+
121
+ if a.name in b.shift_type.keys():
122
+ return b.shift_type[a.name]
123
+
124
+ mid_dataset_name = list(b.shift_type.keys())[0]
125
+ mid_dataset_meta_info = _ABDatasetMetaInfo(mid_dataset_name, *static_dataset_registery[mid_dataset_name][1:])
126
+
127
+ return _get_dist_shift_type_when_source_a_to_target_b(a, mid_dataset_meta_info) + ' + ' + list(b.shift_type.values())[0]
128
+
129
+
130
+ def _handle_all_datasets_v2(source_datasets: List[_ABDatasetMetaInfo], target_datasets: List[_ABDatasetMetaInfo], da_mode):
131
+
132
+ # 1. merge the same meaning classes
133
+ classes_info_of_all_datasets = {
134
+ d.name: (d.classes, d.class_aliases)
135
+ for d in source_datasets + target_datasets
136
+ }
137
+ final_classes_of_all_datasets, rename_map = _merge_the_same_meaning_classes(classes_info_of_all_datasets)
138
+ all_datasets_classes = copy.deepcopy(final_classes_of_all_datasets)
139
+
140
+ # print(all_datasets_known_classes)
141
+
142
+ # 2. find ignored classes according to DA mode
143
+ # source_datasets_ignore_classes, target_datasets_ignore_classes = {d.name: [] for d in source_datasets}, \
144
+ # {d.name: [] for d in target_datasets}
145
+ # source_datasets_private_classes, target_datasets_private_classes = {d.name: [] for d in source_datasets}, \
146
+ # {d.name: [] for d in target_datasets}
147
+ target_source_relationship_map = {td.name: {} for td in target_datasets}
148
+ # source_target_relationship_map = {sd.name: [] for sd in source_datasets}
149
+
150
+ # 1. construct target_source_relationship_map
151
+ for sd in source_datasets:#sd和td使列表中每一个元素(类)的实例
152
+ for td in target_datasets:
153
+ sc = all_datasets_classes[sd.name]
154
+ tc = all_datasets_classes[td.name]
155
+
156
+ if len(set(sc) & set(tc)) == 0:#只保留有相似类别的源域和目标域
157
+ continue
158
+
159
+ target_source_relationship_map[td.name][sd.name] = _get_dist_shift_type_when_source_a_to_target_b(sd, td)
160
+
161
+ # print(target_source_relationship_map)
162
+ # exit()
163
+
164
+ source_datasets_ignore_classes = {}
165
+ for td_name, v1 in target_source_relationship_map.items():
166
+ for sd_name, v2 in v1.items():
167
+ source_datasets_ignore_classes[sd_name + '|' + td_name] = []
168
+ target_datasets_ignore_classes = {d.name: [] for d in target_datasets}
169
+ target_datasets_private_classes = {d.name: [] for d in target_datasets}
170
+ # 保证对于每个目标域上的DA都符合给定的label shift
171
+ # 所以不同目标域就算对应同一个源域,该源域也可能不相同
172
+
173
+ for td_name, v1 in target_source_relationship_map.items():
174
+ sd_names = list(v1.keys())
175
+
176
+ sds_classes = [all_datasets_classes[sd_name] for sd_name in sd_names]
177
+ td_classes = all_datasets_classes[td_name]
178
+ ss_ignore_classes, t_ignore_classes = _find_ignore_classes_when_sources_as_to_target_b(sds_classes, td_classes, da_mode)#根据DA方式不同产生ignore_classes
179
+ t_private_classes = _find_private_classes_when_sources_as_to_target_b(sds_classes, td_classes, da_mode)
180
+
181
+ for sd_name, s_ignore_classes in zip(sd_names, ss_ignore_classes):
182
+ source_datasets_ignore_classes[sd_name + '|' + td_name] = s_ignore_classes
183
+ target_datasets_ignore_classes[td_name] = t_ignore_classes
184
+ target_datasets_private_classes[td_name] = t_private_classes
185
+
186
+ source_datasets_ignore_classes = {k: sorted(set(v), key=v.index) for k, v in source_datasets_ignore_classes.items()}
187
+ target_datasets_ignore_classes = {k: sorted(set(v), key=v.index) for k, v in target_datasets_ignore_classes.items()}
188
+ target_datasets_private_classes = {k: sorted(set(v), key=v.index) for k, v in target_datasets_private_classes.items()}
189
+
190
+ # for k, v in source_datasets_ignore_classes.items():
191
+ # print(k, len(v))
192
+ # print()
193
+ # for k, v in target_datasets_ignore_classes.items():
194
+ # print(k, len(v))
195
+ # print()
196
+ # for k, v in target_datasets_private_classes.items():
197
+ # print(k, len(v))
198
+ # print()
199
+
200
+ # print(source_datasets_private_classes, target_datasets_private_classes)
201
+ # 3. reparse classes idx
202
+ # 3.1. agg all used classes
203
+ # all_used_classes = []
204
+ # all_datasets_private_class_idx_map = {}
205
+
206
+ # source_datasets_classes_idx_map = {}
207
+ # for td_name, v1 in target_source_relationship_map.items():
208
+ # for sd_name, v2 in v1.items():
209
+ # source_datasets_classes_idx_map[sd_name + '|' + td_name] = []
210
+ # target_datasets_classes_idx_map = {}
211
+
212
+ global_idx = 0
213
+ all_used_classes_idx_map = {}
214
+ # all_datasets_known_classes = {d: [] for d in final_classes_of_all_datasets.keys()}
215
+ for dataset_name, classes in all_datasets_classes.items():
216
+ if dataset_name not in target_datasets_ignore_classes.keys():
217
+ ignore_classes = [0] * 100000
218
+ for sn, sic in source_datasets_ignore_classes.items():
219
+ if sn.startswith(dataset_name):
220
+ if len(sic) < len(ignore_classes):
221
+ ignore_classes = sic
222
+ else:
223
+ ignore_classes = target_datasets_ignore_classes[dataset_name]
224
+ private_classes = [] \
225
+ if dataset_name not in target_datasets_ignore_classes.keys() else target_datasets_private_classes[dataset_name]
226
+
227
+ for c in classes:
228
+ if c not in ignore_classes and c not in all_used_classes_idx_map.keys() and c not in private_classes:
229
+ all_used_classes_idx_map[c] = global_idx
230
+ global_idx += 1
231
+
232
+ # print(all_used_classes_idx_map)
233
+
234
+ # dataset_private_class_idx_offset = 0
235
+ target_private_class_idx = global_idx
236
+ target_datasets_private_class_idx = {d: None for d in target_datasets_private_classes.keys()}
237
+
238
+ for dataset_name, classes in final_classes_of_all_datasets.items():
239
+ if dataset_name not in target_datasets_private_classes.keys():
240
+ continue
241
+
242
+ # ignore_classes = target_datasets_ignore_classes[dataset_name]
243
+ private_classes = target_datasets_private_classes[dataset_name]
244
+ # private_classes = [] \
245
+ # if dataset_name in source_datasets_private_classes.keys() else target_datasets_private_classes[dataset_name]
246
+ # for c in classes:
247
+ # if c not in ignore_classes and c not in all_used_classes_idx_map.keys() and c in private_classes:
248
+ # all_used_classes_idx_map[c] = global_idx + dataset_private_class_idx_offset
249
+
250
+ if len(private_classes) > 0:
251
+ # all_datasets_private_class_idx[dataset_name] = global_idx + dataset_private_class_idx_offset
252
+ # dataset_private_class_idx_offset += 1
253
+ # if dataset_name in source_datasets_private_classes.keys():
254
+ # if source_private_class_idx is None:
255
+ # source_private_class_idx = global_idx if target_private_class_idx is None else target_private_class_idx + 1
256
+ # all_datasets_private_class_idx[dataset_name] = source_private_class_idx
257
+ # else:
258
+ # if target_private_class_idx is None:
259
+ # target_private_class_idx = global_idx if source_private_class_idx is None else source_private_class_idx + 1
260
+ # all_datasets_private_class_idx[dataset_name] = target_private_class_idx
261
+ target_datasets_private_class_idx[dataset_name] = target_private_class_idx
262
+ target_private_class_idx += 1
263
+
264
+
265
+ # all_used_classes = sorted(set(all_used_classes), key=all_used_classes.index)
266
+ # all_used_classes_idx_map = {c: i for i, c in enumerate(all_used_classes)}
267
+
268
+ # print('rename_map', rename_map)
269
+
270
+ # 3.2 raw_class -> rename_map[raw_classes] -> all_used_classes_idx_map
271
+ all_datasets_e2e_idx_map = {}
272
+ all_datasets_e2e_class_to_idx_map = {}
273
+
274
+ for td_name, v1 in target_source_relationship_map.items():
275
+ sd_names = list(v1.keys())
276
+ sds_classes = [all_datasets_classes[sd_name] for sd_name in sd_names]
277
+ td_classes = all_datasets_classes[td_name]
278
+
279
+ for sd_name, sd_classes in zip(sd_names, sds_classes):
280
+ cur_e2e_idx_map = {}
281
+ cur_e2e_class_to_idx_map = {}
282
+
283
+ for raw_ci, raw_c in enumerate(sd_classes):
284
+ renamed_c = raw_c if raw_c not in rename_map[dataset_name] else rename_map[dataset_name][raw_c]
285
+
286
+ ignore_classes = source_datasets_ignore_classes[sd_name + '|' + td_name]
287
+ if renamed_c in ignore_classes:
288
+ continue
289
+
290
+ idx = all_used_classes_idx_map[renamed_c]
291
+
292
+ cur_e2e_idx_map[raw_ci] = idx
293
+ cur_e2e_class_to_idx_map[raw_c] = idx
294
+
295
+ all_datasets_e2e_idx_map[sd_name + '|' + td_name] = cur_e2e_idx_map
296
+ all_datasets_e2e_class_to_idx_map[sd_name + '|' + td_name] = cur_e2e_class_to_idx_map
297
+ cur_e2e_idx_map = {}
298
+ cur_e2e_class_to_idx_map = {}
299
+ for raw_ci, raw_c in enumerate(td_classes):
300
+ renamed_c = raw_c if raw_c not in rename_map[dataset_name] else rename_map[dataset_name][raw_c]
301
+
302
+ ignore_classes = target_datasets_ignore_classes[td_name]
303
+ if renamed_c in ignore_classes:
304
+ continue
305
+
306
+ if renamed_c in target_datasets_private_classes[td_name]:
307
+ idx = target_datasets_private_class_idx[td_name]
308
+ else:
309
+ idx = all_used_classes_idx_map[renamed_c]
310
+
311
+ cur_e2e_idx_map[raw_ci] = idx
312
+ cur_e2e_class_to_idx_map[raw_c] = idx
313
+
314
+ all_datasets_e2e_idx_map[td_name] = cur_e2e_idx_map
315
+ all_datasets_e2e_class_to_idx_map[td_name] = cur_e2e_class_to_idx_map
316
+
317
+ all_datasets_ignore_classes = {**source_datasets_ignore_classes, **target_datasets_ignore_classes}
318
+ # all_datasets_private_classes = {**source_datasets_private_classes, **target_datasets_private_classes}
319
+
320
+ classes_idx_set = []
321
+ for d, m in all_datasets_e2e_class_to_idx_map.items():
322
+ classes_idx_set += list(m.values())
323
+ classes_idx_set = set(classes_idx_set)
324
+ num_classes = len(classes_idx_set)
325
+
326
+ return all_datasets_ignore_classes, target_datasets_private_classes, \
327
+ all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
328
+ target_source_relationship_map, rename_map, num_classes
329
+
330
+
331
+ def _build_scenario_info_v2(
332
+ source_datasets_name: List[str],
333
+ target_datasets_order: List[str],
334
+ da_mode: str
335
+ ):
336
+ assert da_mode in ['close_set', 'partial', 'open_set', 'universal']
337
+ da_mode = {'close_set': 'da', 'partial': 'partial_da', 'open_set': 'open_set_da', 'universal': 'universal_da'}[da_mode]
338
+
339
+ source_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in source_datasets_name]#获知对应的名字和对应属性,要添加数据集时,直接register就行
340
+ target_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in list(set(target_datasets_order))]
341
+
342
+ all_datasets_ignore_classes, target_datasets_private_classes, \
343
+ all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
344
+ target_source_relationship_map, rename_map, num_classes \
345
+ = _handle_all_datasets_v2(source_datasets_meta_info, target_datasets_meta_info, da_mode)
346
+
347
+ return all_datasets_ignore_classes, target_datasets_private_classes, \
348
+ all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
349
+ target_source_relationship_map, rename_map, num_classes
350
+
351
+
352
+ def build_scenario_manually_v2(
353
+ source_datasets_name: List[str],
354
+ target_datasets_order: List[str],
355
+ da_mode: str,
356
+ data_dirs: Dict[str, str],
357
+ # transforms: Optional[Dict[str, Compose]] = None
358
+ ):
359
+ configs = copy.deepcopy(locals())#返回当前局部变量
360
+
361
+ source_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in source_datasets_name]
362
+ target_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in list(set(target_datasets_order))]
363
+
364
+ all_datasets_ignore_classes, target_datasets_private_classes, \
365
+ all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
366
+ target_source_relationship_map, rename_map, num_classes \
367
+ = _build_scenario_info_v2(source_datasets_name, target_datasets_order, da_mode)
368
+ # from rich.console import Console
369
+ # console = Console(width=10000)
370
+
371
+ # def print_obj(_o):
372
+ # # import pprint
373
+ # # s = pprint.pformat(_o, width=140, compact=True)
374
+ # console.print(_o)
375
+
376
+ # console.print('configs:', style='bold red')
377
+ # print_obj(configs)
378
+ # console.print('renamed classes:', style='bold red')
379
+ # print_obj(rename_map)
380
+ # console.print('discarded classes:', style='bold red')
381
+ # print_obj(all_datasets_ignore_classes)
382
+ # console.print('unknown classes:', style='bold red')
383
+ # print_obj(target_datasets_private_classes)
384
+ # console.print('class to index map:', style='bold red')
385
+ # print_obj(all_datasets_e2e_class_to_idx_map)
386
+ # console.print('index map:', style='bold red')
387
+ # print_obj(all_datasets_e2e_idx_map)
388
+ # console = Console()
389
+ # # console.print('class distribution:', style='bold red')
390
+ # # class_dist = {
391
+ # # k: {
392
+ # # '#known classes': len(all_datasets_known_classes[k]),
393
+ # # '#unknown classes': len(all_datasets_private_classes[k]),
394
+ # # '#discarded classes': len(all_datasets_ignore_classes[k])
395
+ # # } for k in all_datasets_ignore_classes.keys()
396
+ # # }
397
+ # # print_obj(class_dist)
398
+ # console.print('corresponding sources of each target:', style='bold red')
399
+ # print_obj(target_source_relationship_map)
400
+
401
+ # return
402
+
403
+ # res_source_datasets_map = {d: {split: get_dataset(d, data_dirs[d], split, getattr(transforms, d, None),
404
+ # all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
405
+ # for split in ['train', 'val', 'test']}
406
+ # for d in source_datasets_name}
407
+ # res_target_datasets_map = {d: {'train': get_num_limited_dataset(get_dataset(d, data_dirs[d], 'test', getattr(transforms, d, None),
408
+ # all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d]),
409
+ # num_samples_in_each_target_domain),
410
+ # 'test': get_dataset(d, data_dirs[d], 'test', getattr(transforms, d, None),
411
+ # all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
412
+ # }
413
+ # for d in list(set(target_datasets_order))}
414
+
415
+ # res_source_datasets_map = {d: {split: get_dataset(d.split('|')[0], data_dirs[d.split('|')[0]], split,
416
+ # getattr(transforms, d.split('|')[0], None),
417
+ # all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
418
+ # for split in ['train', 'val', 'test']}
419
+ # for d in all_datasets_ignore_classes.keys() if d.split('|')[0] in source_datasets_name}
420
+
421
+ # from functools import reduce
422
+ # res_offline_train_source_datasets_map = {}
423
+ # res_offline_train_source_datasets_map_names = {}
424
+
425
+ # for d in source_datasets_name:
426
+ # source_dataset_with_max_num_classes = None
427
+
428
+ # for ed_name, ed in res_source_datasets_map.items():
429
+ # if not ed_name.startswith(d):
430
+ # continue
431
+
432
+ # if source_dataset_with_max_num_classes is None:
433
+ # source_dataset_with_max_num_classes = ed
434
+ # res_offline_train_source_datasets_map_names[d] = ed_name
435
+
436
+ # if len(ed['train'].ignore_classes) < len(source_dataset_with_max_num_classes['train'].ignore_classes):
437
+ # source_dataset_with_max_num_classes = ed
438
+ # res_offline_train_source_datasets_map_names[d] = ed_name
439
+
440
+ # res_offline_train_source_datasets_map[d] = source_dataset_with_max_num_classes
441
+
442
+ # res_target_datasets_map = {d: {split: get_dataset(d, data_dirs[d], split, getattr(transforms, d, None),
443
+ # all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
444
+ # for split in ['train', 'val', 'test']}
445
+ # for d in list(set(target_datasets_order))}
446
+
447
+ from .scenario import Scenario, DatasetMetaInfo
448
+
449
+ # test_scenario = Scenario(
450
+ # config=configs,
451
+ # offline_source_datasets_meta_info={
452
+ # d: DatasetMetaInfo(d,
453
+ # {k: v for k, v in all_datasets_e2e_class_to_idx_map[res_offline_train_source_datasets_map_names[d]].items()},
454
+ # None)
455
+ # for d in source_datasets_name
456
+ # },
457
+ # offline_source_datasets={d: res_offline_train_source_datasets_map[d] for d in source_datasets_name},
458
+
459
+ # online_datasets_meta_info=[
460
+ # (
461
+ # {sd + '|' + d: DatasetMetaInfo(d,
462
+ # {k: v for k, v in all_datasets_e2e_class_to_idx_map[sd + '|' + d].items()},
463
+ # None)
464
+ # for sd in target_source_relationship_map[d].keys()},
465
+ # DatasetMetaInfo(d,
466
+ # {k: v for k, v in all_datasets_e2e_class_to_idx_map[d].items() if k not in target_datasets_private_classes[d]},
467
+ # target_datasets_private_class_idx[d])
468
+ # )
469
+ # for d in target_datasets_order
470
+ # ],
471
+ # online_datasets={**res_source_datasets_map, **res_target_datasets_map},
472
+ # target_domains_order=target_datasets_order,
473
+ # target_source_map=target_source_relationship_map,
474
+ # num_classes=num_classes
475
+ # )
476
+ import os
477
+ os.environ['_ZQL_NUMC'] = str(num_classes)
478
+
479
+ test_scenario = Scenario(config=configs, all_datasets_ignore_classes_map=all_datasets_ignore_classes,
480
+ all_datasets_idx_map=all_datasets_e2e_idx_map,
481
+ target_domains_order=target_datasets_order,
482
+ target_source_map=target_source_relationship_map,
483
+ all_datasets_e2e_class_to_idx_map=all_datasets_e2e_class_to_idx_map,
484
+ num_classes=num_classes)
485
+
486
+
487
+ return test_scenario
488
+
489
+
490
+ if __name__ == '__main__':
491
+ test_scenario = build_scenario_manually_v2(['CIFAR10', 'SVHN'],
492
+ ['STL10', 'MNIST', 'STL10', 'USPS', 'MNIST', 'STL10'],
493
+ 'close_set')
494
+ print(test_scenario.num_classes)
495
+
data/build_gen/merge_alias.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from re import L
2
+ from typing import Dict, List
3
+ from collections import Counter
4
+
5
+
6
+ def grouping(bondlist):
7
+ # reference: https://blog.csdn.net/YnagShanwen/article/details/111344386
8
+ groups = []
9
+ break1 = False
10
+ while bondlist:
11
+ pair1 = bondlist.pop(0)
12
+ a = 11111
13
+ b = 10000
14
+ while b != a:
15
+ a = b
16
+ for atomid in pair1:
17
+ for i,pair2 in enumerate(bondlist):
18
+ if atomid in pair2:
19
+ pair1 = pair1 + pair2
20
+ bondlist.pop(i)
21
+ if not bondlist:
22
+ break1 = True
23
+ break
24
+ if break1:
25
+ break
26
+ b = len(pair1)
27
+ groups.append(pair1)
28
+ return groups
29
+
30
+
31
+ def build_semantic_class_info(classes: List[str], aliases: List[List[str]]):
32
+ res = []
33
+ for c in classes:
34
+ # print(res)
35
+ if len(aliases) == 0:
36
+ res += [[c]]
37
+ else:
38
+ find_alias = False
39
+ for alias in aliases:
40
+ if c in alias:
41
+ res += [alias]
42
+ find_alias = True
43
+ break
44
+ if not find_alias:
45
+ res += [[c]]
46
+ # print(classes, res)
47
+ return res
48
+
49
+
50
+ def merge_the_same_meaning_classes(classes_info_of_all_datasets):
51
+ # print(classes_info_of_all_datasets)
52
+
53
+ semantic_classes_of_all_datasets = []
54
+ all_aliases = []
55
+ for classes, aliases in classes_info_of_all_datasets.values():
56
+ all_aliases += aliases
57
+ for classes, aliases in classes_info_of_all_datasets.values():
58
+ semantic_classes_of_all_datasets += build_semantic_class_info(classes, all_aliases)
59
+
60
+ # print(semantic_classes_of_all_datasets)
61
+
62
+ grouped_classes_of_all_datasets = grouping(semantic_classes_of_all_datasets)#匹配过后的数据
63
+
64
+ # print(grouped_classes_of_all_datasets)
65
+
66
+ # final_grouped_classes_of_all_datasets = [Counter(c).most_common()[0][0] for c in grouped_classes_of_all_datasets]
67
+ # use most common class name; if the same common, use shortest class name!
68
+ final_grouped_classes_of_all_datasets = []
69
+ for c in grouped_classes_of_all_datasets:
70
+ counter = Counter(c).most_common()
71
+ max_times = counter[0][1]
72
+ candidate_class_names = []
73
+ for item, times in counter:
74
+ if times < max_times:
75
+ break
76
+ candidate_class_names += [item]
77
+ candidate_class_names.sort(key=lambda x: len(x))
78
+
79
+ final_grouped_classes_of_all_datasets += [candidate_class_names[0]]
80
+ res = {}
81
+ res_map = {d: {} for d in classes_info_of_all_datasets.keys()}
82
+
83
+ for dataset_name, (classes, _) in classes_info_of_all_datasets.items():
84
+ final_classes = []
85
+ for c in classes:
86
+ for grouped_names, final_name in zip(grouped_classes_of_all_datasets, final_grouped_classes_of_all_datasets):
87
+ if c in grouped_names:
88
+ final_classes += [final_name]
89
+ if final_name != c:
90
+ res_map[dataset_name][c] = final_name
91
+ break
92
+ res[dataset_name] = sorted(set(final_classes), key=final_classes.index)
93
+ return res, res_map
94
+
95
+
96
+ if __name__ == '__main__':
97
+ cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
98
+ cifar10_aliases = [['automobile', 'car']]
99
+ stl10_classes = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck']
100
+
101
+ final_classes_of_all_datasets, rename_map = merge_the_same_meaning_classes({
102
+ 'CIFAR10': (cifar10_classes, cifar10_aliases),
103
+ 'STL10': (stl10_classes, [])
104
+ })
105
+
106
+ print(final_classes_of_all_datasets, rename_map)
data/build_gen/scenario.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ from functools import reduce
3
+ from typing import Dict, List, Tuple
4
+ import numpy as np
5
+ import copy
6
+ from utils.common.log import logger
7
+ from ..datasets.ab_dataset import ABDataset
8
+ from ..datasets.dataset_split import train_val_split
9
+ from ..dataloader import FastDataLoader, InfiniteDataLoader, build_dataloader
10
+ from data import get_dataset
11
+
12
+
13
+ class DatasetMetaInfo:
14
+ def __init__(self, name,
15
+ known_classes_name_idx_map, unknown_class_idx):
16
+
17
+ assert unknown_class_idx not in known_classes_name_idx_map.keys()
18
+
19
+ self.name = name
20
+ self.unknown_class_idx = unknown_class_idx
21
+ self.known_classes_name_idx_map = known_classes_name_idx_map
22
+
23
+ @property
24
+ def num_classes(self):
25
+ return len(self.known_classes_idx) + 1
26
+
27
+
28
+ class MergedDataset:
29
+ def __init__(self, datasets: List[ABDataset]):
30
+ self.datasets = datasets
31
+ self.datasets_len = [len(i) for i in self.datasets]
32
+ logger.info(f'create MergedDataset: len of datasets {self.datasets_len}')
33
+ self.datasets_cum_len = np.cumsum(self.datasets_len)
34
+
35
+ def __getitem__(self, idx):
36
+ for i, cum_len in enumerate(self.datasets_cum_len):
37
+ if idx < cum_len:
38
+ return self.datasets[i][idx - sum(self.datasets_len[0: i])]
39
+
40
+ def __len__(self):
41
+ return sum(self.datasets_len)
42
+
43
+
44
+ class IndexReturnedDataset:
45
+ def __init__(self, dataset: ABDataset):
46
+ self.dataset = dataset
47
+
48
+ def __getitem__(self, idx):
49
+ res = self.dataset[idx]
50
+
51
+ if isinstance(res, (tuple, list)):
52
+ return (*res, idx)
53
+ else:
54
+ return res, idx
55
+
56
+ def __len__(self):
57
+ return len(self.dataset)
58
+
59
+
60
+ # class Scenario:
61
+ # def __init__(self, config,
62
+ # source_datasets_meta_info: Dict[str, DatasetMetaInfo], target_datasets_meta_info: Dict[str, DatasetMetaInfo],
63
+ # target_source_map: Dict[str, Dict[str, str]],
64
+ # target_domains_order: List[str],
65
+ # source_datasets: Dict[str, Dict[str, ABDataset]], target_datasets: Dict[str, Dict[str, ABDataset]]):
66
+
67
+ # self.__config = config
68
+ # self.__source_datasets_meta_info = source_datasets_meta_info
69
+ # self.__target_datasets_meta_info = target_datasets_meta_info
70
+ # self.__target_source_map = target_source_map
71
+ # self.__target_domains_order = target_domains_order
72
+ # self.__source_datasets = source_datasets
73
+ # self.__target_datasets = target_datasets
74
+
75
+ # # 1. basic
76
+ # def get_config(self):
77
+ # return copy.deepcopy(self.__config)
78
+
79
+ # def get_task_type(self):
80
+ # return list(self.__source_datasets.values())[0]['train'].task_type
81
+
82
+ # def get_num_classes(self):
83
+ # known_classes_idx = []
84
+ # unknown_classes_idx = []
85
+ # for v in self.__source_datasets_meta_info.values():
86
+ # known_classes_idx += list(v.known_classes_name_idx_map.values())
87
+ # unknown_classes_idx += [v.unknown_class_idx]
88
+ # for v in self.__target_datasets_meta_info.values():
89
+ # known_classes_idx += list(v.known_classes_name_idx_map.values())
90
+ # unknown_classes_idx += [v.unknown_class_idx]
91
+ # unknown_classes_idx = [i for i in unknown_classes_idx if i is not None]
92
+ # # print(known_classes_idx, unknown_classes_idx)
93
+ # res = len(set(known_classes_idx)), len(set(unknown_classes_idx)), len(set(known_classes_idx + unknown_classes_idx))
94
+ # # print(res)
95
+ # assert res[0] + res[1] == res[2]
96
+ # return res
97
+
98
+ # def build_dataloader(self, dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool):
99
+ # if infinite:
100
+ # dataloader = InfiniteDataLoader(
101
+ # dataset, None, batch_size, num_workers=num_workers)
102
+ # else:
103
+ # dataloader = FastDataLoader(
104
+ # dataset, batch_size, num_workers, shuffle=shuffle_when_finite)
105
+
106
+ # return dataloader
107
+
108
+ # def build_sub_dataset(self, dataset: ABDataset, indexes: List[int]):
109
+ # from ..data.datasets.dataset_split import _SplitDataset
110
+ # dataset.dataset = _SplitDataset(dataset.dataset, indexes)
111
+ # return dataset
112
+
113
+ # def build_index_returned_dataset(self, dataset: ABDataset):
114
+ # return IndexReturnedDataset(dataset)
115
+
116
+ # # 2. source
117
+ # def get_source_datasets_meta_info(self):
118
+ # return self.__source_datasets_meta_info
119
+
120
+ # def get_source_datasets_name(self):
121
+ # return list(self.__source_datasets.keys())
122
+
123
+ # def get_merged_source_dataset(self, split):
124
+ # source_train_datasets = {n: d[split] for n, d in self.__source_datasets.items()}
125
+ # return MergedDataset(list(source_train_datasets.values()))
126
+
127
+ # def get_source_datasets(self, split):
128
+ # source_train_datasets = {n: d[split] for n, d in self.__source_datasets.items()}
129
+ # return source_train_datasets
130
+
131
+ # # 3. target **domain**
132
+ # # (do we need such API `get_ith_target_domain()`?)
133
+ # def get_target_domains_meta_info(self):
134
+ # return self.__source_datasets_meta_info
135
+
136
+ # def get_target_domains_order(self):
137
+ # return self.__target_domains_order
138
+
139
+ # def get_corr_source_datasets_name_of_target_domain(self, target_domain_name):
140
+ # return self.__target_source_map[target_domain_name]
141
+
142
+ # def get_limited_target_train_dataset(self):
143
+ # if len(self.__target_domains_order) > 1:
144
+ # raise RuntimeError('this API is only for pass-in scenario in user-defined online DA algorithm')
145
+ # return list(self.__target_datasets.values())[0]['train']
146
+
147
+ # def get_target_domains_iterator(self, split):
148
+ # for target_domain_index, target_domain_name in enumerate(self.__target_domains_order):
149
+ # target_dataset = self.__target_datasets[target_domain_name]
150
+ # target_domain_meta_info = self.__target_datasets_meta_info[target_domain_name]
151
+
152
+ # yield target_domain_index, target_domain_name, target_dataset[split], target_domain_meta_info
153
+
154
+ # # 4. permission management
155
+ # def get_sub_scenario(self, source_datasets_name, source_splits, target_domains_order, target_splits):
156
+ # def get_split(dataset, splits):
157
+ # res = {}
158
+ # for s, d in dataset.items():
159
+ # if s in splits:
160
+ # res[s] = d
161
+ # return res
162
+
163
+ # return Scenario(
164
+ # config=self.__config,
165
+ # source_datasets_meta_info={k: v for k, v in self.__source_datasets_meta_info.items() if k in source_datasets_name},
166
+ # target_datasets_meta_info={k: v for k, v in self.__target_datasets_meta_info.items() if k in target_domains_order},
167
+ # target_source_map={k: v for k, v in self.__target_source_map.items() if k in target_domains_order},
168
+ # target_domains_order=target_domains_order,
169
+ # source_datasets={k: get_split(v, source_splits) for k, v in self.__source_datasets.items() if k in source_datasets_name},
170
+ # target_datasets={k: get_split(v, target_splits) for k, v in self.__target_datasets.items() if k in target_domains_order}
171
+ # )
172
+
173
+ # def get_only_source_sub_scenario_for_exp_tracker(self):
174
+ # return self.get_sub_scenario(self.get_source_datasets_name(), ['train', 'val', 'test'], [], [])
175
+
176
+ # def get_only_source_sub_scenario_for_alg(self):
177
+ # return self.get_sub_scenario(self.get_source_datasets_name(), ['train'], [], [])
178
+
179
+ # def get_one_da_sub_scenario_for_alg(self, target_domain_name):
180
+ # return self.get_sub_scenario(self.get_corr_source_datasets_name_of_target_domain(target_domain_name),
181
+ # ['train', 'val'], [target_domain_name], ['train'])
182
+
183
+
184
+ # class Scenario:
185
+ # def __init__(self, config,
186
+
187
+ # offline_source_datasets_meta_info: Dict[str, DatasetMetaInfo],
188
+ # offline_source_datasets: Dict[str, ABDataset],
189
+
190
+ # online_datasets_meta_info: List[Tuple[Dict[str, DatasetMetaInfo], DatasetMetaInfo]],
191
+ # online_datasets: Dict[str, ABDataset],
192
+ # target_domains_order: List[str],
193
+ # target_source_map: Dict[str, Dict[str, str]],
194
+
195
+ # num_classes: int):
196
+
197
+ # self.config = config
198
+
199
+ # self.offline_source_datasets_meta_info = offline_source_datasets_meta_info
200
+ # self.offline_source_datasets = offline_source_datasets
201
+
202
+ # self.online_datasets_meta_info = online_datasets_meta_info
203
+ # self.online_datasets = online_datasets
204
+
205
+ # self.target_domains_order = target_domains_order
206
+ # self.target_source_map = target_source_map
207
+
208
+ # self.num_classes = num_classes
209
+
210
+ # def get_offline_source_datasets(self, split):
211
+ # return {n: d[split] for n, d in self.offline_source_datasets.items()}
212
+
213
+ # def get_offline_source_merged_dataset(self, split):
214
+ # return MergedDataset([d[split] for d in self.offline_source_datasets.values()])
215
+
216
+ # def get_online_current_corresponding_source_datasets(self, domain_index, split):
217
+ # cur_target_domain_name = self.target_domains_order[domain_index]
218
+ # cur_source_datasets_name = list(self.target_source_map[cur_target_domain_name].keys())
219
+ # cur_source_datasets = {n: self.online_datasets[n + '|' + cur_target_domain_name][split] for n in cur_source_datasets_name}
220
+ # return cur_source_datasets
221
+
222
+ # def get_online_current_corresponding_merged_source_dataset(self, domain_index, split):
223
+ # cur_target_domain_name = self.target_domains_order[domain_index]
224
+ # cur_source_datasets_name = list(self.target_source_map[cur_target_domain_name].keys())
225
+ # cur_source_datasets = {n: self.online_datasets[n + '|' + cur_target_domain_name][split] for n in cur_source_datasets_name}
226
+ # return MergedDataset([d for d in cur_source_datasets.values()])
227
+
228
+ # def get_online_current_target_dataset(self, domain_index, split):
229
+ # cur_target_domain_name = self.target_domains_order[domain_index]
230
+ # return self.online_datasets[cur_target_domain_name][split]
231
+
232
+ # def build_dataloader(self, dataset: ABDataset, batch_size: int, num_workers: int,
233
+ # infinite: bool, shuffle_when_finite: bool, to_iterator: bool):
234
+ # if infinite:
235
+ # dataloader = InfiniteDataLoader(
236
+ # dataset, None, batch_size, num_workers=num_workers)
237
+ # else:
238
+ # dataloader = FastDataLoader(
239
+ # dataset, batch_size, num_workers, shuffle=shuffle_when_finite)
240
+
241
+ # if to_iterator:
242
+ # dataloader = iter(dataloader)
243
+
244
+ # return dataloader
245
+
246
+ # def build_sub_dataset(self, dataset: ABDataset, indexes: List[int]):
247
+ # from data.datasets.dataset_split import _SplitDataset
248
+ # dataset.dataset = _SplitDataset(dataset.dataset, indexes)
249
+ # return dataset
250
+
251
+ # def build_index_returned_dataset(self, dataset: ABDataset):
252
+ # return IndexReturnedDataset(dataset)
253
+
254
+ # def get_config(self):
255
+ # return copy.deepcopy(self.config)
256
+
257
+ # def get_task_type(self):
258
+ # return list(self.online_datasets.values())[0]['train'].task_type
259
+
260
+ # def get_num_classes(self):
261
+ # return self.num_classes
262
+
263
+
264
+ class Scenario:
265
+ def __init__(self, config, all_datasets_ignore_classes_map, all_datasets_idx_map, target_domains_order, target_source_map,
266
+ all_datasets_e2e_class_to_idx_map,
267
+ num_classes):
268
+ self.config = config
269
+ self.all_datasets_ignore_classes_map = all_datasets_ignore_classes_map
270
+ self.all_datasets_idx_map = all_datasets_idx_map
271
+ self.target_domains_order = target_domains_order
272
+ self.target_source_map = target_source_map
273
+ self.all_datasets_e2e_class_to_idx_map = all_datasets_e2e_class_to_idx_map
274
+ self.num_classes = num_classes
275
+ self.cur_domain_index = 0
276
+
277
+ logger.info(f'[scenario build] # classes: {num_classes}')
278
+ logger.debug(f'[scenario build] idx map: {all_datasets_idx_map}')
279
+
280
+ def to_json(self):
281
+ return dict(
282
+ config=self.config, all_datasets_ignore_classes_map=self.all_datasets_ignore_classes_map,
283
+ all_datasets_idx_map=self.all_datasets_idx_map, target_domains_order=self.target_domains_order,
284
+ target_source_map=self.target_source_map,
285
+ all_datasets_e2e_class_to_idx_map=self.all_datasets_e2e_class_to_idx_map,
286
+ num_classes=self.num_classes
287
+ )
288
+
289
+ def __str__(self):
290
+ return f'Scenario({self.to_json()})'
291
+
292
+ def get_offline_datasets(self, transform=None):
293
+ # make source datasets which contains all unioned classes
294
+ res_offline_train_source_datasets_map = {}
295
+
296
+ from .. import get_dataset
297
+ data_dirs = self.config['data_dirs']
298
+
299
+ source_datasets_name = self.config['source_datasets_name']
300
+
301
+ # ori_datasets_map = {d: get_dataset(d, data_dirs[d], None, None, None, None) for d in source_datasets_name}
302
+ # res_source_datasets_map = {k: {split: train_val_split(copy.deepcopy(v), split, rate=0.97) for split in ['train', 'val']} for k, v in ori_datasets_map.items()}
303
+ # for ds in res_source_datasets_map.values():
304
+ # for k, v in ds.items():
305
+ # v.underlying_dataset.dataset.setSplit(k)
306
+ res_source_datasets_map = {d: {split: get_dataset(d, data_dirs[d], split,
307
+ transform,
308
+ self.all_datasets_ignore_classes_map[d], self.all_datasets_idx_map[d])
309
+ for split in ['train', 'val', 'test']}
310
+ for d in self.all_datasets_ignore_classes_map.keys() if d in source_datasets_name}
311
+
312
+ # for source_dataset_name in self.config['source_datasets_name']:
313
+ # source_datasets = [v for k, v in res_source_datasets_map.items() if source_dataset_name in k]
314
+
315
+ # # how to merge idx map?
316
+ # # 35 79 97
317
+ # idx_maps = [d['train'].idx_map for d in source_datasets]
318
+ # ignore_classes_list = [d['train'].ignore_classes for d in source_datasets]
319
+
320
+ # union_idx_map = {}
321
+ # for idx_map in idx_maps:
322
+ # for k, v in idx_map.items():
323
+ # if k not in union_idx_map:
324
+ # union_idx_map[k] = v
325
+ # else:
326
+ # assert union_idx_map[k] == v
327
+
328
+ # union_ignore_classes = reduce(lambda res, cur: res & set(cur), ignore_classes_list, set(ignore_classes_list[0]))
329
+ # assert len(union_ignore_classes) + len(union_idx_map) == len(source_datasets[0]['train'].raw_classes)
330
+
331
+ # logger.info(f'[scenario build] {source_dataset_name} has {len(union_idx_map)} classes in offline training')
332
+
333
+ # d = source_dataset_name
334
+ # res_offline_train_source_datasets_map[d] = {split: get_dataset(d, data_dirs[d], split,
335
+ # transform,
336
+ # union_ignore_classes, union_idx_map)
337
+ # for split in ['train', 'val', 'test']}
338
+
339
+ return res_source_datasets_map
340
+
341
+ def get_offline_datasets_args(self):
342
+ # make source datasets which contains all unioned classes
343
+ res_offline_train_source_datasets_map = {}
344
+
345
+ from .. import get_dataset
346
+ data_dirs = self.config['data_dirs']
347
+
348
+ source_datasets_name = self.config['source_datasets_name']
349
+ res_source_datasets_map = {d: {split: get_dataset(d.split('|')[0], data_dirs[d.split('|')[0]], split,
350
+ None,
351
+ self.all_datasets_ignore_classes_map[d], self.all_datasets_idx_map[d])
352
+ for split in ['train', 'val', 'test']}
353
+ for d in self.all_datasets_ignore_classes_map.keys() if d.split('|')[0] in source_datasets_name}
354
+
355
+ for source_dataset_name in self.config['source_datasets_name']:
356
+ source_datasets = [v for k, v in res_source_datasets_map.items() if source_dataset_name in k]
357
+
358
+ # how to merge idx map?
359
+ # 35 79 97
360
+ idx_maps = [d['train'].idx_map for d in source_datasets]
361
+ ignore_classes_list = [d['train'].ignore_classes for d in source_datasets]
362
+
363
+ union_idx_map = {}
364
+ for idx_map in idx_maps:
365
+ for k, v in idx_map.items():
366
+ if k not in union_idx_map:
367
+ union_idx_map[k] = v
368
+ else:
369
+ assert union_idx_map[k] == v
370
+
371
+ union_ignore_classes = reduce(lambda res, cur: res & set(cur), ignore_classes_list, set(ignore_classes_list[0]))
372
+ assert len(union_ignore_classes) + len(union_idx_map) == len(source_datasets[0]['train'].raw_classes)
373
+
374
+ logger.info(f'[scenario build] {source_dataset_name} has {len(union_idx_map)} classes in offline training')
375
+
376
+ d = source_dataset_name
377
+ res_offline_train_source_datasets_map[d] = {split: dict(d, data_dirs[d], split,
378
+ None,
379
+ union_ignore_classes, union_idx_map)
380
+ for split in ['train', 'val', 'test']}
381
+
382
+ return res_offline_train_source_datasets_map
383
+
384
+ # for d in source_datasets_name:
385
+ # source_dataset_with_max_num_classes = None
386
+
387
+ # for ed_name, ed in res_source_datasets_map.items():
388
+ # if not ed_name.startswith(d):
389
+ # continue
390
+
391
+ # if source_dataset_with_max_num_classes is None:
392
+ # source_dataset_with_max_num_classes = ed
393
+ # res_offline_train_source_datasets_map_names[d] = ed_name
394
+
395
+ # if len(ed['train'].ignore_classes) < len(source_dataset_with_max_num_classes['train'].ignore_classes):
396
+ # source_dataset_with_max_num_classes = ed
397
+ # res_offline_train_source_datasets_map_names[d] = ed_name
398
+
399
+ # res_offline_train_source_datasets_map[d] = source_dataset_with_max_num_classes
400
+
401
+ # return res_offline_train_source_datasets_map
402
+
403
+ def get_online_ith_domain_datasets_args_for_inference(self, domain_index):
404
+ target_dataset_name = self.target_domains_order[domain_index]
405
+ # dataset_name: Any, root_dir: Any, split: Any, transform: Any | None = None, ignore_classes: Any = [], idx_map: Any | None = None
406
+
407
+ if 'MM-CityscapesDet' in self.target_domains_order or 'CityscapesDet' in self.target_domains_order or 'BaiduPersonDet' in self.target_domains_order:
408
+ logger.info(f'use val split for inference test (only Det workload)')
409
+ split = 'test'
410
+ else:
411
+ split = 'train'
412
+
413
+ return dict(dataset_name=target_dataset_name,
414
+ root_dir=self.config['data_dirs'][target_dataset_name],
415
+ split=split,
416
+ transform=None,
417
+ ignore_classes=self.all_datasets_ignore_classes_map[target_dataset_name],
418
+ idx_map=self.all_datasets_idx_map[target_dataset_name])
419
+
420
+ def get_online_ith_domain_datasets_args_for_training(self, domain_index):
421
+ target_dataset_name = self.target_domains_order[domain_index]
422
+ source_datasets_name = list(self.target_source_map[target_dataset_name].keys())
423
+
424
+ res = {}
425
+ # dataset_name: Any, root_dir: Any, split: Any, transform: Any | None = None, ignore_classes: Any = [], idx_map: Any | None = None
426
+ res[target_dataset_name] = {split: dict(dataset_name=target_dataset_name,
427
+ root_dir=self.config['data_dirs'][target_dataset_name],
428
+ split=split,
429
+ transform=None,
430
+ ignore_classes=self.all_datasets_ignore_classes_map[target_dataset_name],
431
+ idx_map=self.all_datasets_idx_map[target_dataset_name]) for split in ['train', 'val']}
432
+ for d in source_datasets_name:
433
+ res[d] = {split: dict(dataset_name=d,
434
+ root_dir=self.config['data_dirs'][d],
435
+ split=split,
436
+ transform=None,
437
+ ignore_classes=self.all_datasets_ignore_classes_map[d + '|' + target_dataset_name],
438
+ idx_map=self.all_datasets_idx_map[d + '|' + target_dataset_name]) for split in ['train', 'val']}
439
+
440
+ return res
441
+
442
+ def get_online_cur_domain_datasets_args_for_inference(self):
443
+ return self.get_online_ith_domain_datasets_args_for_inference(self.cur_domain_index)
444
+
445
+ def get_online_cur_domain_datasets_args_for_training(self):
446
+ return self.get_online_ith_domain_datasets_args_for_training(self.cur_domain_index)
447
+
448
+ def get_online_cur_domain_datasets_for_training(self, transform=None):
449
+ res = {}
450
+ datasets_args = self.get_online_ith_domain_datasets_args_for_training(self.cur_domain_index)
451
+ for dataset_name, dataset_args in datasets_args.items():
452
+ res[dataset_name] = {}
453
+ for split, args in dataset_args.items():
454
+ if transform is not None:
455
+ args['transform'] = transform
456
+ dataset = get_dataset(**args)
457
+ res[dataset_name][split] = dataset
458
+ return res
459
+
460
+ def get_online_cur_domain_datasets_for_inference(self, transform=None):
461
+ datasets_args = self.get_online_ith_domain_datasets_args_for_inference(self.cur_domain_index)
462
+ if transform is not None:
463
+ datasets_args['transform'] = transform
464
+ return get_dataset(**datasets_args)
465
+
466
+ def get_online_cur_domain_samples_for_training(self, num_samples, transform=None, collate_fn=None):
467
+ dataset = self.get_online_cur_domain_datasets_for_training(transform=transform)
468
+ dataset = dataset[self.target_domains_order[self.cur_domain_index]]['train']
469
+ return next(iter(build_dataloader(dataset, num_samples, 0, True, None, collate_fn=collate_fn)))[0]
470
+
471
+ def next_domain(self):
472
+ self.cur_domain_index += 1
473
+
data/convert_all_load_to_single_load.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ convert load-all-images-into-memory-before-training dataset
3
+ to load-when-training-dataset
4
+
5
+
6
+ """
7
+
8
+
9
+ from torchvision.datasets import CIFAR10, STL10, MNIST, USPS, SVHN
10
+ import os
11
+ import tqdm
12
+
13
+
14
+ def convert(datasets_of_split, new_dir):
15
+ img_idx = {}
16
+
17
+ for d in datasets_of_split:
18
+ for x, y in tqdm.tqdm(d, total=len(d), dynamic_ncols=True):
19
+ # print(type(x), type(y))
20
+ # break
21
+ # y = str(y)
22
+ if y not in img_idx:
23
+ img_idx[y] = -1
24
+ img_idx[y] += 1
25
+
26
+ p = os.path.join(new_dir, f'{y:06d}', f'{img_idx[y]:06d}' + '.png')
27
+ os.makedirs(os.path.dirname(p), exist_ok=True)
28
+
29
+ x.save(p)
30
+
31
+
32
+ if __name__ == '__main__':
33
+ # convert(
34
+ # [CIFAR10('/data/zql/datasets/CIFAR10', True, download=True), CIFAR10('/data/zql/datasets/CIFAR10', False, download=True)],
35
+ # '/data/zql/datasets/CIFAR10-single'
36
+ # )
37
+
38
+ # convert(
39
+ # [STL10('/data/zql/datasets/STL10', 'train', download=False), STL10('/data/zql/datasets/STL10', 'test', download=False)],
40
+ # '/data/zql/datasets/STL10-single'
41
+ # )
42
+
43
+ # convert(
44
+ # [MNIST('/data/zql/datasets/MNIST', True, download=True), MNIST('/data/zql/datasets/MNIST', False, download=True)],
45
+ # '/data/zql/datasets/MNIST-single'
46
+ # )
47
+
48
+ convert(
49
+ [SVHN('/data/zql/datasets/SVHN', 'train', download=True), SVHN('/data/zql/datasets/SVHN', 'test', download=True)],
50
+ '/data/zql/datasets/SVHN-single'
51
+ )
52
+
53
+ # convert(
54
+ # [USPS('/data/zql/datasets/USPS', True, download=False), USPS('/data/zql/datasets/USPS', False, download=False)],
55
+ # '/data/zql/datasets/USPS-single'
56
+ # )
data/convert_det_dataset_to_cls.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data import ABDataset
2
+ from utils.common.data_record import read_json, write_json
3
+ from PIL import Image
4
+ import os
5
+ from utils.common.file import ensure_dir
6
+ import numpy as np
7
+ from itertools import groupby
8
+ from skimage import morphology, measure
9
+ from PIL import Image
10
+ from scipy import misc
11
+ import tqdm
12
+ from PIL import ImageFile
13
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
14
+ import shutil
15
+
16
+
17
+ def convert_det_dataset_to_det(coco_ann_json_path, data_dir, target_data_dir, min_img_size=224):
18
+
19
+ coco_ann = read_json(coco_ann_json_path)
20
+
21
+ img_id_to_path = {}
22
+ for img in coco_ann['images']:
23
+ img_id_to_path[img['id']] = os.path.join(data_dir, img['file_name'])
24
+
25
+ classes_imgs_id_map = {}
26
+ for ann in tqdm.tqdm(coco_ann['annotations'], total=len(coco_ann['annotations']), dynamic_ncols=True):
27
+ img_id = ann['image_id']
28
+ img_path = img_id_to_path[img_id]
29
+ img = Image.open(img_path)
30
+
31
+ bbox = ann['bbox']
32
+ if bbox[2] < min_img_size or bbox[3] < min_img_size:
33
+ continue
34
+
35
+ bbox = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]
36
+
37
+ class_idx = str(ann['category_id'])
38
+ if class_idx not in classes_imgs_id_map.keys():
39
+ classes_imgs_id_map[class_idx] = 0
40
+ target_cropped_img_path = os.path.join(target_data_dir, class_idx,
41
+ f'{classes_imgs_id_map[class_idx]}.{img_path.split(".")[-1]}')
42
+ classes_imgs_id_map[class_idx] += 1
43
+
44
+ ensure_dir(target_cropped_img_path)
45
+ img.crop(bbox).save(target_cropped_img_path)
46
+
47
+
48
+
49
+ if __name__ == '__main__':
50
+ convert_det_dataset_to_det(
51
+ coco_ann_json_path='/data/zql/datasets/coco2017/train2017/coco_ann.json',
52
+ data_dir='/data/zql/datasets/coco2017/train2017',
53
+ target_data_dir='/data/zql/datasets/coco2017_for_cls_task',
54
+ min_img_size=224
55
+ )
data/convert_seg_dataset_to_cls.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data import ABDataset
2
+ from utils.common.data_record import read_json
3
+ from PIL import Image
4
+ import os
5
+ from utils.common.file import ensure_dir
6
+ import numpy as np
7
+ from itertools import groupby
8
+ from skimage import morphology, measure
9
+ from PIL import Image
10
+ from scipy import misc
11
+ import tqdm
12
+ from PIL import ImageFile
13
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
14
+ import shutil
15
+
16
+
17
+ def convert_seg_dataset_to_cls(seg_imgs_path, seg_labels_path, target_cls_data_dir, ignore_classes_idx, thread_i, min_img_size=224, label_after_hook=lambda x: x):
18
+ """
19
+ Reference: https://blog.csdn.net/lizaijinsheng/article/details/119889946
20
+
21
+ NOTE:
22
+ Background class should not be considered.
23
+ However, if a seg dataset has only one valid class, so that the generated cls dataset also has only one class and
24
+ the cls accuracy will be 100% forever. But we do not use the generated cls dataset alone, so it is ok.
25
+ """
26
+ assert len(seg_imgs_path) == len(seg_labels_path)
27
+
28
+ classes_imgs_id_map = {}
29
+
30
+ for seg_img_path, seg_label_path in tqdm.tqdm(zip(seg_imgs_path, seg_labels_path), total=len(seg_imgs_path),
31
+ dynamic_ncols=True, leave=False, desc=f'thread {thread_i}'):
32
+
33
+ try:
34
+ seg_img = Image.open(seg_img_path)
35
+ seg_label = Image.open(seg_label_path).convert('L')
36
+ seg_label = np.array(seg_label)
37
+ seg_label = label_after_hook(seg_label)
38
+ except Exception as e:
39
+ print(e)
40
+ print(f'file {seg_img_path} error, skip')
41
+ exit()
42
+ # seg_img = Image.open(seg_img_path)
43
+ # seg_label = Image.open(seg_label_path).convert('L')
44
+ # seg_label = np.array(seg_label)
45
+
46
+ this_img_classes = set(seg_label.reshape(-1).tolist())
47
+ # print(this_img_classes)
48
+
49
+ for class_idx in this_img_classes:
50
+ if class_idx in ignore_classes_idx:
51
+ continue
52
+
53
+ if class_idx not in classes_imgs_id_map.keys():
54
+ classes_imgs_id_map[class_idx] = 0
55
+
56
+ mask = np.zeros((seg_label.shape[0], seg_label.shape[1]), dtype=np.uint8)
57
+ mask[seg_label == class_idx] = 1
58
+ mask_without_small = morphology.remove_small_objects(mask, min_size=10, connectivity=2)
59
+ label_image = measure.label(mask_without_small)
60
+
61
+ for region in measure.regionprops(label_image):
62
+ bbox = region.bbox # (top, left, bottom, right)
63
+ bbox = [bbox[1], bbox[0], bbox[3], bbox[2]] # (left, top, right, bottom)
64
+
65
+ width, height = bbox[2] - bbox[0], bbox[3] - bbox[1]
66
+ if width < min_img_size or height < min_img_size:
67
+ continue
68
+
69
+ target_cropped_img_path = os.path.join(target_cls_data_dir, str(class_idx),
70
+ f'{classes_imgs_id_map[class_idx]}.{seg_img_path.split(".")[-1]}')
71
+ ensure_dir(target_cropped_img_path)
72
+ seg_img.crop(bbox).save(target_cropped_img_path)
73
+ # print(target_cropped_img_path)
74
+ # exit()
75
+
76
+ classes_imgs_id_map[class_idx] += 1
77
+
78
+ num_cls_imgs = 0
79
+ for k, v in classes_imgs_id_map.items():
80
+ # print(f'# class {k}: {v + 1}')
81
+ num_cls_imgs += v
82
+ # print(f'total: {num_cls_imgs}')
83
+
84
+ return classes_imgs_id_map
85
+
86
+
87
+ from concurrent.futures import ThreadPoolExecutor
88
+
89
+
90
+
91
+ # def convert_seg_dataset_to_cls_multi_thread(seg_imgs_path, seg_labels_path, target_cls_data_dir, ignore_classes_idx, num_threads):
92
+ # if os.path.exists(target_cls_data_dir):
93
+ # shutil.rmtree(target_cls_data_dir)
94
+
95
+ # assert len(seg_imgs_path) == len(seg_labels_path)
96
+ # n = len(seg_imgs_path) // num_threads
97
+
98
+ # pool = ThreadPoolExecutor(max_workers=num_threads)
99
+ # # threads = []
100
+ # futures = []
101
+ # for thread_i in range(num_threads):
102
+ # # thread = threading.Thread(target=convert_seg_dataset_to_cls,
103
+ # # args=(seg_imgs_path[thread_i * n: (thread_i + 1) * n],
104
+ # # seg_labels_path[thread_i * n: (thread_i + 1) * n],
105
+ # # target_cls_data_dir, ignore_classes_idx))
106
+ # # threads += [thread]
107
+ # future = pool.submit(convert_seg_dataset_to_cls, *(seg_imgs_path[thread_i * n: (thread_i + 1) * n],
108
+ # seg_labels_path[thread_i * n: (thread_i + 1) * n],
109
+ # target_cls_data_dir, ignore_classes_idx, thread_i))
110
+ # futures += [future]
111
+
112
+ # futures += [
113
+ # pool.submit(convert_seg_dataset_to_cls, *(seg_imgs_path[(thread_i + 1) * n: ],
114
+ # seg_labels_path[(thread_i + 1) * n: ],
115
+ # target_cls_data_dir, ignore_classes_idx, thread_i))
116
+ # ]
117
+
118
+ # for f in futures:
119
+ # f.done()
120
+
121
+ # res = []
122
+ # for f in futures:
123
+ # res += [f.result()]
124
+ # print(res[-1])
125
+
126
+ # res_dist = {}
127
+ # for r in res:
128
+ # for k, v in r.items():
129
+ # if k in res_dist.keys():
130
+ # res_dist[k] += v
131
+ # else:
132
+ # res_dist[k] = v
133
+
134
+ # print('results:')
135
+ # print(res_dist)
136
+
137
+ # pool.shutdown()
138
+
139
+
140
+
141
+ import random
142
+ def random_crop_aug(target_dir):
143
+ for class_dir in os.listdir(target_dir):
144
+ class_dir = os.path.join(target_dir, class_dir)
145
+
146
+ for img_path in os.listdir(class_dir):
147
+ img_path = os.path.join(class_dir, img_path)
148
+
149
+ img = Image.open(img_path)
150
+
151
+ w, h = img.width, img.height
152
+
153
+ for ri in range(5):
154
+ img.crop(
155
+ [
156
+ random.randint(0, w // 5),
157
+ random.randint(0, h // 5),
158
+ random.randint(w // 5 * 4, w),
159
+ random.randint(h // 5 * 4, h)
160
+ ]
161
+ ).save(
162
+ os.path.join(os.path.dirname(img_path), f'randaug_{ri}_' + os.path.basename(img_path))
163
+ )
164
+ # print(img_path)
165
+ # exit()
166
+
167
+
168
+ if __name__ == '__main__':
169
+ # SuperviselyPerson
170
+ # root_dir = '/data/zql/datasets/supervisely_person/Supervisely Person Dataset'
171
+
172
+ # images_path, labels_path = [], []
173
+ # for p in os.listdir(root_dir):
174
+ # if p.startswith('ds'):
175
+ # p1 = os.path.join(root_dir, p, 'img')
176
+ # images_path += [(p, os.path.join(p1, n)) for n in os.listdir(p1)]
177
+ # for dsi, img_p in images_path:
178
+ # target_p = os.path.join(root_dir, p, dsi, img_p.split('/')[-1])
179
+ # labels_path += [target_p]
180
+ # images_path = [i[1] for i in images_path]
181
+
182
+ # target_dir = '/data/zql/datasets/supervisely_person_for_cls_task'
183
+ # if os.path.exists(target_dir):
184
+ # shutil.rmtree(target_dir)
185
+ # convert_seg_dataset_to_cls(
186
+ # seg_imgs_path=images_path,
187
+ # seg_labels_path=labels_path,
188
+ # target_cls_data_dir=target_dir,
189
+ # ignore_classes_idx=[0, 2],
190
+ # # num_threads=8
191
+ # thread_i=0
192
+ # )
193
+
194
+ # random_crop_aug('/data/zql/datasets/supervisely_person_for_cls_task')
195
+
196
+
197
+ # GTA5
198
+ # root_dir = '/data/zql/datasets/GTA-ls-copy/GTA5'
199
+ # images_path, labels_path = [], []
200
+ # for p in os.listdir(os.path.join(root_dir, 'images')):
201
+ # p = os.path.join(root_dir, 'images', p)
202
+ # if not p.endswith('png'):
203
+ # continue
204
+ # images_path += [p]
205
+ # labels_path += [p.replace('images', 'labels_gt')]
206
+
207
+ # target_dir = '/data/zql/datasets/gta5_for_cls_task'
208
+ # if os.path.exists(target_dir):
209
+ # shutil.rmtree(target_dir)
210
+
211
+ # convert_seg_dataset_to_cls(
212
+ # seg_imgs_path=images_path,
213
+ # seg_labels_path=labels_path,
214
+ # target_cls_data_dir=target_dir,
215
+ # ignore_classes_idx=[],
216
+ # thread_i=0
217
+ # )
218
+
219
+ # cityscapes
220
+ # root_dir = '/data/zql/datasets/cityscape/'
221
+
222
+ # def _get_target_suffix(mode: str, target_type: str) -> str:
223
+ # if target_type == 'instance':
224
+ # return '{}_instanceIds.png'.format(mode)
225
+ # elif target_type == 'semantic':
226
+ # return '{}_labelIds.png'.format(mode)
227
+ # elif target_type == 'color':
228
+ # return '{}_color.png'.format(mode)
229
+ # else:
230
+ # return '{}_polygons.json'.format(mode)
231
+
232
+
233
+ # images_path, labels_path = [], []
234
+ # split = 'train'
235
+ # images_dir = os.path.join(root_dir, 'leftImg8bit', split)
236
+ # targets_dir = os.path.join(root_dir, 'gtFine', split)
237
+ # for city in os.listdir(images_dir):
238
+ # img_dir = os.path.join(images_dir, city)
239
+ # target_dir = os.path.join(targets_dir, city)
240
+ # for file_name in os.listdir(img_dir):
241
+ # target_types = []
242
+ # for t in ['semantic']:
243
+ # target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
244
+ # _get_target_suffix('gtFine', t))
245
+ # target_types.append(os.path.join(target_dir, target_name))
246
+
247
+ # images_path.append(os.path.join(img_dir, file_name))
248
+ # labels_path.append(target_types[0])
249
+
250
+ # print(images_path[0: 5], '\n', labels_path[0: 5])
251
+
252
+ # target_dir = '/data/zql/datasets/cityscapes_for_cls_task'
253
+ # if os.path.exists(target_dir):
254
+ # shutil.rmtree(target_dir)
255
+ # convert_seg_dataset_to_cls(
256
+ # seg_imgs_path=images_path,
257
+ # seg_labels_path=labels_path,
258
+ # target_cls_data_dir=target_dir,
259
+ # ignore_classes_idx=[],
260
+ # # num_threads=8
261
+ # thread_i=0
262
+ # )
263
+
264
+ # import shutil
265
+
266
+ # ignore_target_dir = '/data/zql/datasets/cityscapes_for_cls_task_ignored'
267
+
268
+ # ignore_label = 255
269
+ # raw_idx_map_in_y_transform = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
270
+ # 3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
271
+ # 7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
272
+ # 14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
273
+ # 18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
274
+ # 28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18}
275
+ # ignore_classes_idx = [k for k, v in raw_idx_map_in_y_transform.items() if v == ignore_label]
276
+ # ignore_classes_idx = sorted(ignore_classes_idx)
277
+
278
+ # for class_dir in os.listdir(target_dir):
279
+ # if int(class_dir) in ignore_classes_idx:
280
+ # continue
281
+ # shutil.move(
282
+ # os.path.join(target_dir, class_dir),
283
+ # os.path.join(ignore_target_dir, class_dir)
284
+ # )
285
+ # else:
286
+ # shutil.move(
287
+ # os.path.join(target_dir, class_dir),
288
+ # os.path.join(target_dir, str(raw_idx_map_in_y_transform[int(class_dir)]))
289
+ # )
290
+ # continue
291
+ # print(class_dir)
292
+ # exit()
293
+
294
+
295
+
296
+ # baidu person
297
+ # root_dir = '/data/zql/datasets/baidu_person/clean_images/'
298
+
299
+ # images_path, labels_path = [], []
300
+ # for p in os.listdir(os.path.join(root_dir, 'images')):
301
+ # images_path += [os.path.join(root_dir, 'images', p)]
302
+ # labels_path += [os.path.join(root_dir, 'profiles', p.split('.')[0] + '-profile.jpg')]
303
+
304
+ # target_dir = '/data/zql/datasets/baiduperson_for_cls_task'
305
+ # # if os.path.exists(target_dir):
306
+ # # shutil.rmtree(target_dir)
307
+
308
+ # def label_after_hook(x):
309
+ # x[x > 1] = 1
310
+ # return x
311
+
312
+ # convert_seg_dataset_to_cls(
313
+ # seg_imgs_path=images_path,
314
+ # seg_labels_path=labels_path,
315
+ # target_cls_data_dir=target_dir,
316
+ # ignore_classes_idx=[1],
317
+ # # num_threads=8
318
+ # thread_i=1,
319
+ # min_img_size=224,
320
+ # label_after_hook=label_after_hook
321
+ # )
322
+
323
+
324
+
data/convert_seg_dataset_to_det.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data import ABDataset
2
+ from utils.common.data_record import read_json, write_json
3
+ from PIL import Image
4
+ import os
5
+ from utils.common.file import ensure_dir
6
+ import numpy as np
7
+ from itertools import groupby
8
+ from skimage import morphology, measure
9
+ from PIL import Image
10
+ from scipy import misc
11
+ import tqdm
12
+ from PIL import ImageFile
13
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
14
+ import shutil
15
+
16
+
17
+ def convert_seg_dataset_to_det(seg_imgs_path, seg_labels_path, root_dir, target_coco_ann_path, ignore_classes_idx, thread_i, min_img_size=224, label_after_hook=lambda x: x):
18
+ """
19
+ Reference: https://blog.csdn.net/lizaijinsheng/article/details/119889946
20
+
21
+ NOTE:
22
+ Background class should not be considered.
23
+ However, if a seg dataset has only one valid class, so that the generated cls dataset also has only one class and
24
+ the cls accuracy will be 100% forever. But we do not use the generated cls dataset alone, so it is ok.
25
+ """
26
+ assert len(seg_imgs_path) == len(seg_labels_path)
27
+
28
+ classes_imgs_id_map = {}
29
+
30
+ coco_ann = {
31
+ 'categories': [],
32
+ "type": "instances",
33
+ 'images': [],
34
+ 'annotations': []
35
+ }
36
+
37
+ image_id = 0
38
+ ann_id = 0
39
+
40
+ pbar = tqdm.tqdm(zip(seg_imgs_path, seg_labels_path), total=len(seg_imgs_path),
41
+ dynamic_ncols=True, leave=False, desc=f'thread {thread_i}')
42
+ for seg_img_path, seg_label_path in pbar:
43
+
44
+ try:
45
+ seg_img = Image.open(seg_img_path)
46
+ seg_label = Image.open(seg_label_path).convert('L')
47
+ seg_label = np.array(seg_label)
48
+ seg_label = label_after_hook(seg_label)
49
+ except Exception as e:
50
+ print(e)
51
+ print(f'file {seg_img_path} error, skip')
52
+ exit()
53
+ # seg_img = Image.open(seg_img_path)
54
+ # seg_label = Image.open(seg_label_path).convert('L')
55
+ # seg_label = np.array(seg_label)
56
+
57
+ image_coco_info = {'file_name': os.path.relpath(seg_img_path, root_dir), 'height': seg_img.height, 'width': seg_img.width,
58
+ 'id':image_id}
59
+ image_id += 1
60
+ coco_ann['images'] += [image_coco_info]
61
+
62
+ this_img_classes = set(seg_label.reshape(-1).tolist())
63
+ # print(this_img_classes)
64
+
65
+ for class_idx in this_img_classes:
66
+ if class_idx in ignore_classes_idx:
67
+ continue
68
+
69
+ if class_idx not in classes_imgs_id_map.keys():
70
+ classes_imgs_id_map[class_idx] = 0
71
+
72
+ mask = np.zeros((seg_label.shape[0], seg_label.shape[1]), dtype=np.uint8)
73
+ mask[seg_label == class_idx] = 1
74
+ mask_without_small = morphology.remove_small_objects(mask, min_size=10, connectivity=2)
75
+ label_image = measure.label(mask_without_small)
76
+
77
+ for region in measure.regionprops(label_image):
78
+ bbox = region.bbox # (top, left, bottom, right)
79
+ bbox = [bbox[1], bbox[0], bbox[3], bbox[2]] # (left, top, right, bottom)
80
+
81
+ width, height = bbox[2] - bbox[0], bbox[3] - bbox[1]
82
+ if width < min_img_size or height < min_img_size:
83
+ continue
84
+
85
+ # target_cropped_img_path = os.path.join(target_cls_data_dir, str(class_idx),
86
+ # f'{classes_imgs_id_map[class_idx]}.{seg_img_path.split(".")[-1]}')
87
+ # ensure_dir(target_cropped_img_path)
88
+ # seg_img.crop(bbox).save(target_cropped_img_path)
89
+ # print(target_cropped_img_path)
90
+ # exit()
91
+
92
+ ann_coco_info = {'area': width*height, 'iscrowd': 0, 'image_id':
93
+ image_id - 1, 'bbox': [bbox[0], bbox[1], width, height],
94
+ 'category_id': class_idx,
95
+ 'id': ann_id, 'ignore': 0,
96
+ 'segmentation': []}
97
+ ann_id += 1
98
+
99
+ coco_ann['annotations'] += [ann_coco_info]
100
+
101
+ classes_imgs_id_map[class_idx] += 1
102
+
103
+ pbar.set_description(f'# ann: {ann_id}')
104
+
105
+ coco_ann['categories'] = [
106
+ {'id': ci, 'name': f'class_{c}_in_seg'} for ci, c in enumerate(classes_imgs_id_map.keys())
107
+ ]
108
+ c_to_ci_map = {c: ci for ci, c in enumerate(classes_imgs_id_map.keys())}
109
+ for ann in coco_ann['annotations']:
110
+ ann['category_id'] = c_to_ci_map[
111
+ ann['category_id']
112
+ ]
113
+
114
+ write_json(target_coco_ann_path, coco_ann, indent=0, backup=True)
115
+ write_json(os.path.join(root_dir, 'coco_ann.json'), coco_ann, indent=0, backup=True)
116
+
117
+ num_cls_imgs = 0
118
+ for k, v in classes_imgs_id_map.items():
119
+ # print(f'# class {k}: {v + 1}')
120
+ num_cls_imgs += v
121
+ # print(f'total: {num_cls_imgs}')
122
+
123
+ return classes_imgs_id_map
124
+
125
+
126
+ from concurrent.futures import ThreadPoolExecutor
127
+
128
+
129
+
130
+ # def convert_seg_dataset_to_cls_multi_thread(seg_imgs_path, seg_labels_path, target_cls_data_dir, ignore_classes_idx, num_threads):
131
+ # if os.path.exists(target_cls_data_dir):
132
+ # shutil.rmtree(target_cls_data_dir)
133
+
134
+ # assert len(seg_imgs_path) == len(seg_labels_path)
135
+ # n = len(seg_imgs_path) // num_threads
136
+
137
+ # pool = ThreadPoolExecutor(max_workers=num_threads)
138
+ # # threads = []
139
+ # futures = []
140
+ # for thread_i in range(num_threads):
141
+ # # thread = threading.Thread(target=convert_seg_dataset_to_cls,
142
+ # # args=(seg_imgs_path[thread_i * n: (thread_i + 1) * n],
143
+ # # seg_labels_path[thread_i * n: (thread_i + 1) * n],
144
+ # # target_cls_data_dir, ignore_classes_idx))
145
+ # # threads += [thread]
146
+ # future = pool.submit(convert_seg_dataset_to_cls, *(seg_imgs_path[thread_i * n: (thread_i + 1) * n],
147
+ # seg_labels_path[thread_i * n: (thread_i + 1) * n],
148
+ # target_cls_data_dir, ignore_classes_idx, thread_i))
149
+ # futures += [future]
150
+
151
+ # futures += [
152
+ # pool.submit(convert_seg_dataset_to_cls, *(seg_imgs_path[(thread_i + 1) * n: ],
153
+ # seg_labels_path[(thread_i + 1) * n: ],
154
+ # target_cls_data_dir, ignore_classes_idx, thread_i))
155
+ # ]
156
+
157
+ # for f in futures:
158
+ # f.done()
159
+
160
+ # res = []
161
+ # for f in futures:
162
+ # res += [f.result()]
163
+ # print(res[-1])
164
+
165
+ # res_dist = {}
166
+ # for r in res:
167
+ # for k, v in r.items():
168
+ # if k in res_dist.keys():
169
+ # res_dist[k] += v
170
+ # else:
171
+ # res_dist[k] = v
172
+
173
+ # print('results:')
174
+ # print(res_dist)
175
+
176
+ # pool.shutdown()
177
+
178
+
179
+
180
+ # import random
181
+ # def random_crop_aug(target_dir):
182
+ # for class_dir in os.listdir(target_dir):
183
+ # class_dir = os.path.join(target_dir, class_dir)
184
+
185
+ # for img_path in os.listdir(class_dir):
186
+ # img_path = os.path.join(class_dir, img_path)
187
+
188
+ # img = Image.open(img_path)
189
+
190
+ # w, h = img.width, img.height
191
+
192
+ # for ri in range(5):
193
+ # img.crop(
194
+ # [
195
+ # random.randint(0, w // 5),
196
+ # random.randint(0, h // 5),
197
+ # random.randint(w // 5 * 4, w),
198
+ # random.randint(h // 5 * 4, h)
199
+ # ]
200
+ # ).save(
201
+ # os.path.join(os.path.dirname(img_path), f'randaug_{ri}_' + os.path.basename(img_path))
202
+ # )
203
+ # # print(img_path)
204
+ # # exit()
205
+
206
+
207
+ def post_ignore_classes(coco_ann_json_path):
208
+ # from data.datasets.object_detection.yolox_data_util.api import remap_dataset
209
+ # remap_dataset(coco_ann_json_path, [], {})
210
+ pass
211
+
212
+
213
+
214
+ if __name__ == '__main__':
215
+ # SuperviselyPerson
216
+ # root_dir = '/data/zql/datasets/supervisely_person_full_20230635/Supervisely Person Dataset'
217
+
218
+ # images_path, labels_path = [], []
219
+ # for p in os.listdir(root_dir):
220
+ # if p.startswith('ds'):
221
+ # p1 = os.path.join(root_dir, p, 'img')
222
+ # images_path += [(p, os.path.join(p1, n)) for n in os.listdir(p1)]
223
+ # for dsi, img_p in images_path:
224
+ # target_p = os.path.join(root_dir, p, dsi, img_p.split('/')[-1])
225
+ # labels_path += [target_p]
226
+ # images_path = [i[1] for i in images_path]
227
+
228
+ # target_coco_ann_path = '/data/zql/datasets/supervisely_person_for_det_task/coco_ann.json'
229
+ # if os.path.exists(target_coco_ann_path):
230
+ # os.remove(target_coco_ann_path)
231
+ # convert_seg_dataset_to_det(
232
+ # seg_imgs_path=images_path,
233
+ # seg_labels_path=labels_path,
234
+ # root_dir=root_dir,
235
+ # target_coco_ann_path=target_coco_ann_path,
236
+ # ignore_classes_idx=[0, 2],
237
+ # # num_threads=8
238
+ # thread_i=0
239
+ # )
240
+
241
+ # random_crop_aug('/data/zql/datasets/supervisely_person_for_cls_task')
242
+
243
+
244
+ # GTA5
245
+ # root_dir = '/data/zql/datasets/GTA-ls-copy/GTA5'
246
+ # images_path, labels_path = [], []
247
+ # for p in os.listdir(os.path.join(root_dir, 'images')):
248
+ # p = os.path.join(root_dir, 'images', p)
249
+ # if not p.endswith('png'):
250
+ # continue
251
+ # images_path += [p]
252
+ # labels_path += [p.replace('images', 'labels_gt')]
253
+
254
+ # target_coco_ann_path = '/data/zql/datasets/gta5_for_det_task/coco_ann.json'
255
+ # if os.path.exists(target_coco_ann_path):
256
+ # os.remove(target_coco_ann_path)
257
+
258
+ # """
259
+ # [
260
+ # 'road', 'sidewalk', 'building', 'wall',
261
+ # 'fence', 'pole', 'light', 'sign',
262
+ # 'vegetation', 'terrain', 'sky', 'people', # person
263
+ # 'rider', 'car', 'truck', 'bus', 'train',
264
+ # 'motocycle', 'bicycle'
265
+ # ]
266
+ # """
267
+ # need_classes_idx = [13, 15]
268
+ # convert_seg_dataset_to_det(
269
+ # seg_imgs_path=images_path,
270
+ # seg_labels_path=labels_path,
271
+ # root_dir=root_dir,
272
+ # target_coco_ann_path=target_coco_ann_path,
273
+ # ignore_classes_idx=[i for i in range(20) if i not in need_classes_idx],
274
+ # thread_i=0
275
+ # )
276
+
277
+ # from data.datasets.object_detection.yolox_data_util.api import remap_dataset
278
+ # new_coco_ann_json_path = remap_dataset('/data/zql/datasets/GTA-ls-copy/GTA5/coco_ann.json', [-1], {0: 0, 1:-1, 2:-1, 3: 1, 4:-1, 5:-1})
279
+ # print(new_coco_ann_json_path)
280
+
281
+ # cityscapes
282
+ # root_dir = '/data/zql/datasets/cityscape/'
283
+
284
+ # def _get_target_suffix(mode: str, target_type: str) -> str:
285
+ # if target_type == 'instance':
286
+ # return '{}_instanceIds.png'.format(mode)
287
+ # elif target_type == 'semantic':
288
+ # return '{}_labelIds.png'.format(mode)
289
+ # elif target_type == 'color':
290
+ # return '{}_color.png'.format(mode)
291
+ # else:
292
+ # return '{}_polygons.json'.format(mode)
293
+
294
+
295
+ # images_path, labels_path = [], []
296
+ # split = 'train'
297
+ # images_dir = os.path.join(root_dir, 'leftImg8bit', split)
298
+ # targets_dir = os.path.join(root_dir, 'gtFine', split)
299
+ # for city in os.listdir(images_dir):
300
+ # img_dir = os.path.join(images_dir, city)
301
+ # target_dir = os.path.join(targets_dir, city)
302
+ # for file_name in os.listdir(img_dir):
303
+ # target_types = []
304
+ # for t in ['semantic']:
305
+ # target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
306
+ # _get_target_suffix('gtFine', t))
307
+ # target_types.append(os.path.join(target_dir, target_name))
308
+
309
+ # images_path.append(os.path.join(img_dir, file_name))
310
+ # labels_path.append(target_types[0])
311
+
312
+ # # print(images_path[0: 5], '\n', labels_path[0: 5])
313
+
314
+ # target_coco_ann_path = '/data/zql/datasets/cityscape/coco_ann.json'
315
+ # # if os.path.exists(target_dir):
316
+ # # shutil.rmtree(target_dir)
317
+
318
+ # need_classes_idx = [26, 28]
319
+ # convert_seg_dataset_to_det(
320
+ # seg_imgs_path=images_path,
321
+ # seg_labels_path=labels_path,
322
+ # root_dir=root_dir,
323
+ # target_coco_ann_path=target_coco_ann_path,
324
+ # ignore_classes_idx=[i for i in range(80) if i not in need_classes_idx],
325
+ # # num_threads=8
326
+ # thread_i=0
327
+ # )
328
+
329
+ # import shutil
330
+
331
+ # ignore_target_dir = '/data/zql/datasets/cityscapes_for_cls_task_ignored'
332
+
333
+ # ignore_label = 255
334
+ # raw_idx_map_in_y_transform = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
335
+ # 3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
336
+ # 7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
337
+ # 14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
338
+ # 18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
339
+ # 28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18}
340
+ # ignore_classes_idx = [k for k, v in raw_idx_map_in_y_transform.items() if v == ignore_label]
341
+ # ignore_classes_idx = sorted(ignore_classes_idx)
342
+
343
+ # for class_dir in os.listdir(target_dir):
344
+ # if int(class_dir) in ignore_classes_idx:
345
+ # continue
346
+ # shutil.move(
347
+ # os.path.join(target_dir, class_dir),
348
+ # os.path.join(ignore_target_dir, class_dir)
349
+ # )
350
+ # else:
351
+ # shutil.move(
352
+ # os.path.join(target_dir, class_dir),
353
+ # os.path.join(target_dir, str(raw_idx_map_in_y_transform[int(class_dir)]))
354
+ # )
355
+ # continue
356
+ # print(class_dir)
357
+ # exit()
358
+
359
+
360
+
361
+ # baidu person
362
+ # root_dir = '/data/zql/datasets/baidu_person/clean_images/'
363
+
364
+ # images_path, labels_path = [], []
365
+ # for p in os.listdir(os.path.join(root_dir, 'images')):
366
+ # images_path += [os.path.join(root_dir, 'images', p)]
367
+ # labels_path += [os.path.join(root_dir, 'profiles', p.split('.')[0] + '-profile.jpg')]
368
+
369
+ # target_dir = '/data/zql/datasets/baiduperson_for_cls_task'
370
+ # # if os.path.exists(target_dir):
371
+ # # shutil.rmtree(target_dir)
372
+
373
+ # def label_after_hook(x):
374
+ # x[x > 1] = 1
375
+ # return x
376
+
377
+ # convert_seg_dataset_to_det(
378
+ # seg_imgs_path=images_path,
379
+ # seg_labels_path=labels_path,
380
+ # root_dir=root_dir,
381
+ # target_coco_ann_path='/data/zql/datasets/baidu_person/clean_images/coco_ann_zql.json',
382
+ # ignore_classes_idx=[1],
383
+ # # num_threads=8
384
+ # thread_i=1,
385
+ # min_img_size=224,
386
+ # label_after_hook=label_after_hook
387
+ # )
388
+
389
+
390
+ # from data.visualize import visualize_classes_in_object_detection
391
+ # from data import get_dataset
392
+ # d = get_dataset('CityscapesDet', '/data/zql/datasets/cityscape/', 'val', None, [], None)
393
+ # visualize_classes_in_object_detection(d, {'car': 0, 'bus': 1}, {}, 'debug.png')
394
+
395
+ # d = get_dataset('GTA5Det', '/data/zql/datasets/GTA-ls-copy/GTA5', 'val', None, [], None)
396
+ # visualize_classes_in_object_detection(d, {'car': 0, 'bus': 1}, {}, 'debug.png')
397
+
398
+ # d = get_dataset('BaiduPersonDet', '/data/zql/datasets/baidu_person/clean_images/', 'val', None, [], None)
399
+ # visualize_classes_in_object_detection(d, {'person': 0}, {}, 'debug.png')
data/dataloader.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ # domainbed/lib/fast_data_loader.py
3
+
4
+ import torch
5
+ from .datasets.ab_dataset import ABDataset
6
+
7
+
8
+ class _InfiniteSampler(torch.utils.data.Sampler):
9
+ """Wraps another Sampler to yield an infinite stream."""
10
+
11
+ def __init__(self, sampler):
12
+ self.sampler = sampler
13
+
14
+ def __iter__(self):
15
+ while True:
16
+ for batch in self.sampler:
17
+ yield batch
18
+
19
+
20
+ class InfiniteDataLoader:
21
+ def __init__(self, dataset, weights, batch_size, num_workers, collate_fn=None):
22
+ super().__init__()
23
+
24
+ if weights:
25
+ sampler = torch.utils.data.WeightedRandomSampler(
26
+ weights, replacement=True, num_samples=batch_size
27
+ )
28
+ else:
29
+ sampler = torch.utils.data.RandomSampler(dataset, replacement=True)
30
+
31
+ batch_sampler = torch.utils.data.BatchSampler(
32
+ sampler, batch_size=batch_size, drop_last=True
33
+ )
34
+
35
+ if collate_fn is not None:
36
+ self._infinite_iterator = iter(
37
+ torch.utils.data.DataLoader(
38
+ dataset,
39
+ num_workers=num_workers,
40
+ batch_sampler=_InfiniteSampler(batch_sampler),
41
+ pin_memory=False,
42
+ collate_fn=collate_fn
43
+ )
44
+ )
45
+ else:
46
+ self._infinite_iterator = iter(
47
+ torch.utils.data.DataLoader(
48
+ dataset,
49
+ num_workers=num_workers,
50
+ batch_sampler=_InfiniteSampler(batch_sampler),
51
+ pin_memory=False
52
+ )
53
+ )
54
+ self.dataset = dataset
55
+
56
+ def __iter__(self):
57
+ while True:
58
+ yield next(self._infinite_iterator)
59
+
60
+ def __len__(self):
61
+ raise ValueError
62
+
63
+
64
+ class FastDataLoader:
65
+ """
66
+ DataLoader wrapper with slightly improved speed by not respawning worker
67
+ processes at every epoch.
68
+ """
69
+
70
+ def __init__(self, dataset, batch_size, num_workers, shuffle=False, collate_fn=None):
71
+ super().__init__()
72
+
73
+ self.num_workers = num_workers
74
+
75
+ if shuffle:
76
+ sampler = torch.utils.data.RandomSampler(dataset, replacement=False)
77
+ else:
78
+ sampler = torch.utils.data.SequentialSampler(dataset)
79
+
80
+ batch_sampler = torch.utils.data.BatchSampler(
81
+ sampler,
82
+ batch_size=batch_size,
83
+ drop_last=False,
84
+ )
85
+ if collate_fn is not None:
86
+ self._infinite_iterator = iter(
87
+ torch.utils.data.DataLoader(
88
+ dataset,
89
+ num_workers=num_workers,
90
+ batch_sampler=_InfiniteSampler(batch_sampler),
91
+ pin_memory=False,
92
+ collate_fn=collate_fn
93
+ )
94
+ )
95
+ else:
96
+ self._infinite_iterator = iter(
97
+ torch.utils.data.DataLoader(
98
+ dataset,
99
+ num_workers=num_workers,
100
+ batch_sampler=_InfiniteSampler(batch_sampler),
101
+ pin_memory=False,
102
+ )
103
+ )
104
+
105
+ self.dataset = dataset
106
+ self.batch_size = batch_size
107
+ self._length = len(batch_sampler)
108
+
109
+ def __iter__(self):
110
+ for _ in range(len(self)):
111
+ yield next(self._infinite_iterator)
112
+
113
+ def __len__(self):
114
+ return self._length
115
+
116
+
117
+ def build_dataloader(dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool, collate_fn=None):
118
+ assert batch_size <= len(dataset), len(dataset)
119
+ if infinite:
120
+ dataloader = InfiniteDataLoader(
121
+ dataset, None, batch_size, num_workers=num_workers, collate_fn=collate_fn)
122
+ else:
123
+ dataloader = FastDataLoader(
124
+ dataset, batch_size, num_workers, shuffle=shuffle_when_finite, collate_fn=collate_fn)
125
+
126
+ return dataloader
127
+
128
+
129
+ def get_a_batch_dataloader(dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool):
130
+ pass
131
+
data/dataset.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from typing import Type
3
+ import torch
4
+ from torch.utils.data import TensorDataset
5
+ from torch.utils.data.dataloader import DataLoader
6
+
7
+ from .datasets.ab_dataset import ABDataset
8
+
9
+ from .datasets import * # import all datasets
10
+ from .datasets.registery import static_dataset_registery
11
+
12
+
13
+ def get_dataset(dataset_name, root_dir, split, transform=None, ignore_classes=[], idx_map=None) -> ABDataset:
14
+ dataset_cls = static_dataset_registery[dataset_name][0]
15
+ dataset = dataset_cls(root_dir, split, transform, ignore_classes, idx_map)
16
+
17
+ return dataset
18
+
19
+
20
+ def get_num_limited_dataset(dataset: ABDataset, num_samples: int, discard_label=True):
21
+ dataloader = iter(DataLoader(dataset, num_samples // 2, shuffle=True))
22
+ x, y = [], []
23
+ cur_num_samples = 0
24
+ while True:
25
+ batch = next(dataloader)
26
+ cur_x, cur_y = batch[0], batch[1]
27
+
28
+ x += [cur_x]
29
+ y += [cur_y]
30
+ cur_num_samples += cur_x.size(0)
31
+
32
+ if cur_num_samples >= num_samples:
33
+ break
34
+
35
+ x, y = torch.cat(x)[0: num_samples], torch.cat(y)[0: num_samples]
36
+ if discard_label:
37
+ new_dataset = TensorDataset(x)
38
+ else:
39
+ new_dataset = TensorDataset(x, y)
40
+
41
+ dataset.dataset = new_dataset
42
+
43
+ return dataset
data/datasets/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .image_classification import *
2
+ from .object_detection import *
3
+ from .semantic_segmentation import *
4
+ from .action_recognition import *
5
+
6
+ from .sentiment_classification import *
7
+ from .text_generation import *
8
+ from .machine_translation import *
9
+ from .pos_tagging import *
10
+
11
+ from .mm_image_classification import *
12
+ from .visual_question_answering import *
data/datasets/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (450 Bytes). View file
 
data/datasets/__pycache__/ab_dataset.cpython-38.pyc ADDED
Binary file (2.14 kB). View file
 
data/datasets/__pycache__/data_aug.cpython-38.pyc ADDED
Binary file (3.18 kB). View file
 
data/datasets/__pycache__/dataset_cache.cpython-38.pyc ADDED
Binary file (1.62 kB). View file