xiangbog commited on
Commit
d1125e4
·
1 Parent(s): bc63ebe

Upload model checkpoint

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. v2xverse_late_multiclass_2025_01_28_08_49_56/config.yaml +213 -0
  2. v2xverse_late_multiclass_2025_01_28_08_49_56/events.out.tfevents.1738072197.poliwag.engin.umich.edu +3 -0
  3. v2xverse_late_multiclass_2025_01_28_08_49_56/net_epoch_bestval_at14.pth +3 -0
  4. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/__init__.py +0 -0
  5. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/__pycache__/__init__.cpython-37.pyc +0 -0
  6. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/augmentor/__init__.py +0 -0
  7. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/augmentor/__pycache__/__init__.cpython-37.pyc +0 -0
  8. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/augmentor/__pycache__/augment_utils.cpython-37.pyc +0 -0
  9. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/augmentor/__pycache__/data_augmentor.cpython-37.pyc +0 -0
  10. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/augmentor/augment_utils.py +88 -0
  11. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/augmentor/data_augmentor.py +120 -0
  12. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__init__.py +35 -0
  13. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/__init__.cpython-37.pyc +0 -0
  14. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/early_fusion_dataset.cpython-37.pyc +0 -0
  15. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/early_multiclass_fusion_dataset.cpython-37.pyc +0 -0
  16. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/intermediate_2stage_fusion_dataset.cpython-37.pyc +0 -0
  17. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/intermediate_fusion_dataset.cpython-37.pyc +0 -0
  18. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/intermediate_heter_fusion_dataset.cpython-37.pyc +0 -0
  19. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/intermediate_multiclass_fusion_dataset.cpython-37.pyc +0 -0
  20. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/late_fusion_dataset.cpython-37.pyc +0 -0
  21. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/late_heter_fusion_dataset.cpython-37.pyc +0 -0
  22. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/late_multiclass_fusion_dataset.cpython-37.pyc +0 -0
  23. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/__pycache__/dairv2x_basedataset.cpython-37.pyc +0 -0
  24. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/__pycache__/opv2v_basedataset.cpython-37.pyc +0 -0
  25. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/__pycache__/v2xset_basedataset.cpython-37.pyc +0 -0
  26. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/__pycache__/v2xsim_basedataset.cpython-37.pyc +0 -0
  27. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/__pycache__/v2xverse_basedataset.cpython-37.pyc +0 -0
  28. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/dairv2x_basedataset.py +285 -0
  29. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/opv2v_basedataset.py +479 -0
  30. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/v2xset_basedataset.py +24 -0
  31. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/v2xsim_basedataset.py +238 -0
  32. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/v2xverse_basedataset.py +1118 -0
  33. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/early_fusion_dataset.py +414 -0
  34. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/early_multiclass_fusion_dataset.py +899 -0
  35. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/intermediate_2stage_fusion_dataset.py +603 -0
  36. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/intermediate_fusion_dataset.py +679 -0
  37. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/intermediate_heter_fusion_dataset.py +752 -0
  38. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/intermediate_multiclass_fusion_dataset.py +892 -0
  39. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/late_fusion_dataset.py +564 -0
  40. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/late_heter_fusion_dataset.py +565 -0
  41. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/late_multi_fusion_dataset.py +631 -0
  42. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/late_multiclass_fusion_dataset.py +1233 -0
  43. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__init__.py +27 -0
  44. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__pycache__/__init__.cpython-37.pyc +0 -0
  45. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__pycache__/base_postprocessor.cpython-37.pyc +0 -0
  46. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__pycache__/bev_postprocessor.cpython-37.pyc +0 -0
  47. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__pycache__/ciassd_postprocessor.cpython-37.pyc +0 -0
  48. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__pycache__/fpvrcnn_postprocessor.cpython-37.pyc +0 -0
  49. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__pycache__/uncertainty_voxel_postprocessor.cpython-37.pyc +0 -0
  50. v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__pycache__/voxel_postprocessor.cpython-37.pyc +0 -0
v2xverse_late_multiclass_2025_01_28_08_49_56/config.yaml ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ comm_range: 200
2
+ data_augment:
3
+ - ALONG_AXIS_LIST:
4
+ - x
5
+ NAME: random_world_flip
6
+ - NAME: random_world_rotation
7
+ WORLD_ROT_ANGLE:
8
+ - -0.78539816
9
+ - 0.78539816
10
+ - NAME: random_world_scaling
11
+ WORLD_SCALE_RANGE:
12
+ - 0.95
13
+ - 1.05
14
+ fusion:
15
+ args:
16
+ clip_pc: false
17
+ proj_first: false
18
+ core_method: intermediatemulticlass
19
+ dataset: v2xverse
20
+ input_source:
21
+ - lidar
22
+ label_type: lidar
23
+ loss:
24
+ args:
25
+ cls_weight: 5.0
26
+ code_weights:
27
+ - 1.0
28
+ - 1.0
29
+ - 1.0
30
+ - 1.0
31
+ - 1.0
32
+ - 1.0
33
+ - 5.0
34
+ - 5.0
35
+ loc_weight: 1.0
36
+ target_assigner_config:
37
+ box_coder: ResidualCoder
38
+ cav_lidar_range: &id004
39
+ - -36
40
+ - -12
41
+ - -22
42
+ - 36
43
+ - 12
44
+ - 14
45
+ gaussian_overlap: 0.1
46
+ max_objs: 40
47
+ min_radius: 2
48
+ out_size_factor: 2
49
+ voxel_size: &id001
50
+ - 0.125
51
+ - 0.125
52
+ - 36
53
+ core_method: center_point_loss_multiclass
54
+ lr_scheduler:
55
+ core_method: multistep
56
+ gamma: 0.1
57
+ step_size:
58
+ - 8
59
+ - 15
60
+ model:
61
+ args:
62
+ anchor_number: 3
63
+ att:
64
+ feat_dim: 64
65
+ base_bev_backbone:
66
+ compression: 0
67
+ layer_nums: &id002
68
+ - 3
69
+ - 4
70
+ - 5
71
+ layer_strides:
72
+ - 2
73
+ - 2
74
+ - 2
75
+ num_filters: &id003
76
+ - 64
77
+ - 128
78
+ - 256
79
+ num_upsample_filter:
80
+ - 128
81
+ - 128
82
+ - 128
83
+ resnet: true
84
+ upsample_strides:
85
+ - 1
86
+ - 2
87
+ - 4
88
+ voxel_size: *id001
89
+ fusion_args:
90
+ agg_operator:
91
+ feature_dim: 256
92
+ mode: MAX
93
+ downsample_rate: 2
94
+ dropout_rate: 0
95
+ in_channels: 256
96
+ layer_nums: *id002
97
+ multi_scale: false
98
+ n_head: 8
99
+ num_filters: *id003
100
+ only_attention: true
101
+ voxel_size: *id001
102
+ fusion_method: max
103
+ lidar_range: *id004
104
+ max_cav: 5
105
+ multi_class: true
106
+ out_size_factor: 2
107
+ pillar_vfe:
108
+ num_filters:
109
+ - 64
110
+ use_absolute_xyz: true
111
+ use_norm: true
112
+ with_distance: false
113
+ point_pillar_scatter:
114
+ grid_size: !!python/object/apply:numpy.core.multiarray._reconstruct
115
+ args:
116
+ - !!python/name:numpy.ndarray ''
117
+ - !!python/tuple
118
+ - 0
119
+ - !!binary |
120
+ Yg==
121
+ state: !!python/tuple
122
+ - 1
123
+ - !!python/tuple
124
+ - 3
125
+ - !!python/object/apply:numpy.dtype
126
+ args:
127
+ - i8
128
+ - 0
129
+ - 1
130
+ state: !!python/tuple
131
+ - 3
132
+ - <
133
+ - null
134
+ - null
135
+ - null
136
+ - -1
137
+ - -1
138
+ - 0
139
+ - false
140
+ - !!binary |
141
+ QAIAAAAAAADAAAAAAAAAAAEAAAAAAAAA
142
+ num_features: 64
143
+ shrink_header:
144
+ dim:
145
+ - 128
146
+ input_dim: 384
147
+ kernal_size:
148
+ - 3
149
+ padding:
150
+ - 1
151
+ stride:
152
+ - 1
153
+ supervise_fusion: false
154
+ supervise_single: true
155
+ voxel_size: *id001
156
+ core_method: point_pillar_single_multiclass
157
+ name: v2xverse_late_multiclass
158
+ noise_setting: !!python/object/apply:collections.OrderedDict
159
+ - - - add_noise
160
+ - false
161
+ optimizer:
162
+ args:
163
+ eps: 1.0e-10
164
+ weight_decay: 0.0001
165
+ core_method: Adam
166
+ lr: 0.002
167
+ postprocess:
168
+ anchor_args:
169
+ D: 1
170
+ H: 192
171
+ W: 576
172
+ cav_lidar_range: *id004
173
+ feature_stride: 2
174
+ h: 1.56
175
+ l: 3.9
176
+ num: 1
177
+ r: &id005
178
+ - 0
179
+ vd: 36
180
+ vh: 0.125
181
+ vw: 0.125
182
+ w: 1.6
183
+ core_method: VoxelPostprocessor
184
+ dir_args:
185
+ anchor_yaw: *id005
186
+ dir_offset: 0.7853
187
+ num_bins: 1
188
+ gt_range: *id004
189
+ max_num: 100
190
+ nms_thresh: 0.15
191
+ order: hwl
192
+ target_args:
193
+ neg_threshold: 0.45
194
+ pos_threshold: 0.6
195
+ score_threshold: 0.2
196
+ preprocess:
197
+ args:
198
+ max_points_per_voxel: 32
199
+ max_voxel_test: 70000
200
+ max_voxel_train: 32000
201
+ voxel_size: *id001
202
+ cav_lidar_range: *id004
203
+ core_method: SpVoxelPreprocessor
204
+ root_dir: external_paths/data_root
205
+ test_dir: external_paths/data_root
206
+ train_params:
207
+ batch_size: 4
208
+ epoches: 40
209
+ eval_freq: 1
210
+ max_cav: 5
211
+ save_freq: 1
212
+ validate_dir: external_paths/data_root
213
+ yaml_parser: load_point_pillar_params
v2xverse_late_multiclass_2025_01_28_08_49_56/events.out.tfevents.1738072197.poliwag.engin.umich.edu ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac3b8a28e7fba347631b57fb22d403037b9f1fa244f0b566d60222d5c9bf5756
3
+ size 498679515
v2xverse_late_multiclass_2025_01_28_08_49_56/net_epoch_bestval_at14.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba3fef03956eb6da6eb9721db6baf142f81f85ac84cd95324c1e37065d387b50
3
+ size 32820345
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/__init__.py ADDED
File without changes
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (158 Bytes). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/augmentor/__init__.py ADDED
File without changes
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/augmentor/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (168 Bytes). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/augmentor/__pycache__/augment_utils.cpython-37.pyc ADDED
Binary file (2.44 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/augmentor/__pycache__/data_augmentor.cpython-37.pyc ADDED
Binary file (2.96 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/augmentor/augment_utils.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Author: OpenPCDet
3
+
4
+ import numpy as np
5
+
6
+ from opencood.utils import common_utils
7
+
8
+
9
+ def random_flip_along_x(gt_boxes, points):
10
+ """
11
+ Args:
12
+ gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
13
+ points: (M, 3 + C)
14
+ Returns:
15
+ """
16
+ enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
17
+ if enable:
18
+ gt_boxes[:, 1] = -gt_boxes[:, 1]
19
+ gt_boxes[:, 6] = -gt_boxes[:, 6]
20
+ points[:, 1] = -points[:, 1]
21
+
22
+ if gt_boxes.shape[1] > 7:
23
+ gt_boxes[:, 8] = -gt_boxes[:, 8]
24
+
25
+ return gt_boxes, points
26
+
27
+
28
+ def random_flip_along_y(gt_boxes, points):
29
+ """
30
+ Args:
31
+ gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
32
+ points: (M, 3 + C)
33
+ Returns:
34
+ """
35
+ enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
36
+ if enable:
37
+ gt_boxes[:, 0] = -gt_boxes[:, 0]
38
+ gt_boxes[:, 6] = -(gt_boxes[:, 6] + np.pi)
39
+ points[:, 0] = -points[:, 0]
40
+
41
+ if gt_boxes.shape[1] > 7:
42
+ gt_boxes[:, 7] = -gt_boxes[:, 7]
43
+
44
+ return gt_boxes, points
45
+
46
+
47
+ def global_rotation(gt_boxes, points, rot_range):
48
+ """
49
+ Args:
50
+ gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
51
+ points: (M, 3 + C),
52
+ rot_range: [min, max]
53
+ Returns:
54
+ """
55
+ noise_rotation = np.random.uniform(rot_range[0],
56
+ rot_range[1])
57
+ points = common_utils.rotate_points_along_z(points[np.newaxis, :, :],
58
+ np.array([noise_rotation]))[0]
59
+
60
+ gt_boxes[:, 0:3] = \
61
+ common_utils.rotate_points_along_z(gt_boxes[np.newaxis, :, 0:3],
62
+ np.array([noise_rotation]))[0]
63
+ gt_boxes[:, 6] += noise_rotation
64
+
65
+ if gt_boxes.shape[1] > 7:
66
+ gt_boxes[:, 7:9] = common_utils.rotate_points_along_z(
67
+ np.hstack((gt_boxes[:, 7:9], np.zeros((gt_boxes.shape[0], 1))))[
68
+ np.newaxis, :, :],
69
+ np.array([noise_rotation]))[0][:, 0:2]
70
+
71
+ return gt_boxes, points
72
+
73
+
74
+ def global_scaling(gt_boxes, points, scale_range):
75
+ """
76
+ Args:
77
+ gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading]
78
+ points: (M, 3 + C),
79
+ scale_range: [min, max]
80
+ Returns:
81
+ """
82
+ if scale_range[1] - scale_range[0] < 1e-3:
83
+ return gt_boxes, points
84
+ noise_scale = np.random.uniform(scale_range[0], scale_range[1])
85
+ points[:, :3] *= noise_scale
86
+ gt_boxes[:, :6] *= noise_scale
87
+
88
+ return gt_boxes, points
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/augmentor/data_augmentor.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Class for data augmentation
4
+ """
5
+ # Author: Runsheng Xu <[email protected]>
6
+ # License: TDG-Attribution-NonCommercial-NoDistrib
7
+
8
+ from functools import partial
9
+
10
+ import numpy as np
11
+
12
+ from opencood.data_utils.augmentor import augment_utils
13
+
14
+
15
+ class DataAugmentor(object):
16
+ """
17
+ Data Augmentor.
18
+
19
+ Parameters
20
+ ----------
21
+ augment_config : list
22
+ A list of augmentation configuration.
23
+
24
+ Attributes
25
+ ----------
26
+ data_augmentor_queue : list
27
+ The list of data augmented functions.
28
+ """
29
+
30
+ def __init__(self, augment_config, train=True):
31
+ self.data_augmentor_queue = []
32
+ self.train = train
33
+
34
+ for cur_cfg in augment_config:
35
+ cur_augmentor = getattr(self, cur_cfg['NAME'])(config=cur_cfg)
36
+ self.data_augmentor_queue.append(cur_augmentor)
37
+
38
+ def random_world_flip(self, data_dict=None, config=None):
39
+ if data_dict is None:
40
+ return partial(self.random_world_flip, config=config)
41
+
42
+ gt_boxes, gt_mask, points = data_dict['object_bbx_center'], \
43
+ data_dict['object_bbx_mask'], \
44
+ data_dict['lidar_np']
45
+ gt_boxes_valid = gt_boxes[gt_mask == 1]
46
+
47
+ for cur_axis in config['ALONG_AXIS_LIST']:
48
+ assert cur_axis in ['x', 'y']
49
+ gt_boxes_valid, points = getattr(augment_utils,
50
+ 'random_flip_along_%s' % cur_axis)(
51
+ gt_boxes_valid, points,
52
+ )
53
+
54
+ gt_boxes[:gt_boxes_valid.shape[0], :] = gt_boxes_valid
55
+
56
+ data_dict['object_bbx_center'] = gt_boxes
57
+ data_dict['object_bbx_mask'] = gt_mask
58
+ data_dict['lidar_np'] = points
59
+
60
+ return data_dict
61
+
62
+ def random_world_rotation(self, data_dict=None, config=None):
63
+ if data_dict is None:
64
+ return partial(self.random_world_rotation, config=config)
65
+
66
+ rot_range = config['WORLD_ROT_ANGLE']
67
+ if not isinstance(rot_range, list):
68
+ rot_range = [-rot_range, rot_range]
69
+
70
+ gt_boxes, gt_mask, points = data_dict['object_bbx_center'], \
71
+ data_dict['object_bbx_mask'], \
72
+ data_dict['lidar_np']
73
+ gt_boxes_valid = gt_boxes[gt_mask == 1]
74
+ gt_boxes_valid, points = augment_utils.global_rotation(
75
+ gt_boxes_valid, points, rot_range=rot_range
76
+ )
77
+ gt_boxes[:gt_boxes_valid.shape[0], :] = gt_boxes_valid
78
+
79
+ data_dict['object_bbx_center'] = gt_boxes
80
+ data_dict['object_bbx_mask'] = gt_mask
81
+ data_dict['lidar_np'] = points
82
+
83
+ return data_dict
84
+
85
+ def random_world_scaling(self, data_dict=None, config=None):
86
+ if data_dict is None:
87
+ return partial(self.random_world_scaling, config=config)
88
+
89
+ gt_boxes, gt_mask, points = data_dict['object_bbx_center'], \
90
+ data_dict['object_bbx_mask'], \
91
+ data_dict['lidar_np']
92
+ gt_boxes_valid = gt_boxes[gt_mask == 1]
93
+
94
+ gt_boxes_valid, points = augment_utils.global_scaling(
95
+ gt_boxes_valid, points, config['WORLD_SCALE_RANGE']
96
+ )
97
+ gt_boxes[:gt_boxes_valid.shape[0], :] = gt_boxes_valid
98
+
99
+ data_dict['object_bbx_center'] = gt_boxes
100
+ data_dict['object_bbx_mask'] = gt_mask
101
+ data_dict['lidar_np'] = points
102
+
103
+ return data_dict
104
+
105
+ def forward(self, data_dict):
106
+ """
107
+ Args:
108
+ data_dict:
109
+ points: (N, 3 + C_in)
110
+ gt_boxes: optional, (N, 7) [x, y, z, dx, dy, dz, heading]
111
+ gt_names: optional, (N), string
112
+ ...
113
+
114
+ Returns:
115
+ """
116
+ if self.train:
117
+ for cur_augmentor in self.data_augmentor_queue:
118
+ data_dict = cur_augmentor(data_dict=data_dict)
119
+
120
+ return data_dict
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from opencood.data_utils.datasets.late_fusion_dataset import getLateFusionDataset
2
+ from opencood.data_utils.datasets.late_heter_fusion_dataset import getLateheterFusionDataset
3
+ from opencood.data_utils.datasets.late_multiclass_fusion_dataset import getLatemulticlassFusionDataset
4
+ from opencood.data_utils.datasets.early_fusion_dataset import getEarlyFusionDataset
5
+ from opencood.data_utils.datasets.intermediate_fusion_dataset import getIntermediateFusionDataset
6
+ from opencood.data_utils.datasets.intermediate_multiclass_fusion_dataset import getIntermediatemulticlassFusionDataset
7
+ from opencood.data_utils.datasets.intermediate_2stage_fusion_dataset import getIntermediate2stageFusionDataset
8
+ from opencood.data_utils.datasets.intermediate_heter_fusion_dataset import getIntermediateheterFusionDataset
9
+ from opencood.data_utils.datasets.basedataset.opv2v_basedataset import OPV2VBaseDataset
10
+ from opencood.data_utils.datasets.basedataset.v2xsim_basedataset import V2XSIMBaseDataset
11
+ from opencood.data_utils.datasets.basedataset.dairv2x_basedataset import DAIRV2XBaseDataset
12
+ from opencood.data_utils.datasets.basedataset.v2xset_basedataset import V2XSETBaseDataset
13
+ from opencood.data_utils.datasets.basedataset.v2xverse_basedataset import V2XVERSEBaseDataset
14
+ from opencood.data_utils.datasets.late_multiclass_fusion_dataset import getLatemulticlassFusionDataset
15
+ from opencood.data_utils.datasets.early_multiclass_fusion_dataset import getEarlymulticlassFusionDataset
16
+
17
+ def build_dataset(dataset_cfg, visualize=False, train=True):
18
+ fusion_name = dataset_cfg['fusion']['core_method']
19
+ dataset_name = dataset_cfg['fusion']['dataset']
20
+
21
+ assert fusion_name in ['late', 'lateheter', 'intermediate', 'intermediate2stage', 'intermediateheter', 'intermediatemulticlass', 'early', 'latemulticlass', 'earlymulticlass']
22
+ assert dataset_name in ['opv2v', 'v2xsim', 'dairv2x', 'v2xset', 'v2xverse']
23
+
24
+ fusion_dataset_func = "get" + fusion_name.capitalize() + "FusionDataset"
25
+ fusion_dataset_func = eval(fusion_dataset_func)
26
+ base_dataset_cls = dataset_name.upper() + "BaseDataset"
27
+ base_dataset_cls = eval(base_dataset_cls)
28
+
29
+ dataset = fusion_dataset_func(base_dataset_cls)(
30
+ params=dataset_cfg,
31
+ visualize=visualize,
32
+ train=train
33
+ )
34
+
35
+ return dataset
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (2.34 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/early_fusion_dataset.cpython-37.pyc ADDED
Binary file (9.46 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/early_multiclass_fusion_dataset.cpython-37.pyc ADDED
Binary file (19 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/intermediate_2stage_fusion_dataset.cpython-37.pyc ADDED
Binary file (12.6 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/intermediate_fusion_dataset.cpython-37.pyc ADDED
Binary file (14.6 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/intermediate_heter_fusion_dataset.cpython-37.pyc ADDED
Binary file (16.2 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/intermediate_multiclass_fusion_dataset.cpython-37.pyc ADDED
Binary file (19 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/late_fusion_dataset.cpython-37.pyc ADDED
Binary file (11.9 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/late_heter_fusion_dataset.cpython-37.pyc ADDED
Binary file (12.8 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/__pycache__/late_multiclass_fusion_dataset.cpython-37.pyc ADDED
Binary file (24 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/__pycache__/dairv2x_basedataset.cpython-37.pyc ADDED
Binary file (9.18 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/__pycache__/opv2v_basedataset.cpython-37.pyc ADDED
Binary file (12.3 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/__pycache__/v2xset_basedataset.cpython-37.pyc ADDED
Binary file (1.38 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/__pycache__/v2xsim_basedataset.cpython-37.pyc ADDED
Binary file (6.29 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/__pycache__/v2xverse_basedataset.cpython-37.pyc ADDED
Binary file (31.3 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/dairv2x_basedataset.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import OrderedDict
3
+ import cv2
4
+ import h5py
5
+ import torch
6
+ import numpy as np
7
+ from functools import partial
8
+ from torch.utils.data import Dataset
9
+ from PIL import Image
10
+ import random
11
+ import opencood.utils.pcd_utils as pcd_utils
12
+ from opencood.data_utils.augmentor.data_augmentor import DataAugmentor
13
+ from opencood.hypes_yaml.yaml_utils import load_yaml
14
+ from opencood.utils.pcd_utils import downsample_lidar_minimum
15
+ from opencood.utils.camera_utils import load_camera_data, load_intrinsic_DAIR_V2X
16
+ from opencood.utils.common_utils import read_json
17
+ from opencood.utils.transformation_utils import tfm_to_pose, rot_and_trans_to_trasnformation_matrix
18
+ from opencood.utils.transformation_utils import veh_side_rot_and_trans_to_trasnformation_matrix
19
+ from opencood.utils.transformation_utils import inf_side_rot_and_trans_to_trasnformation_matrix
20
+ from opencood.data_utils.pre_processor import build_preprocessor
21
+ from opencood.data_utils.post_processor import build_postprocessor
22
+
23
+ class DAIRV2XBaseDataset(Dataset):
24
+ def __init__(self, params, visualize, train=True):
25
+ self.params = params
26
+ self.visualize = visualize
27
+ self.train = train
28
+
29
+ self.pre_processor = build_preprocessor(params["preprocess"], train)
30
+ self.post_processor = build_postprocessor(params["postprocess"], train)
31
+ self.post_processor.generate_gt_bbx = self.post_processor.generate_gt_bbx_by_iou
32
+ if 'data_augment' in params: # late and early
33
+ self.data_augmentor = DataAugmentor(params['data_augment'], train)
34
+ else: # intermediate
35
+ self.data_augmentor = None
36
+
37
+ if 'clip_pc' in params['fusion']['args'] and params['fusion']['args']['clip_pc']:
38
+ self.clip_pc = True
39
+ else:
40
+ self.clip_pc = False
41
+
42
+ if 'train_params' not in params or 'max_cav' not in params['train_params']:
43
+ self.max_cav = 2
44
+ else:
45
+ self.max_cav = params['train_params']['max_cav']
46
+
47
+ self.load_lidar_file = True if 'lidar' in params['input_source'] or self.visualize else False
48
+ self.load_camera_file = True if 'camera' in params['input_source'] else False
49
+ self.load_depth_file = True if 'depth' in params['input_source'] else False
50
+
51
+ assert self.load_depth_file is False
52
+
53
+ self.label_type = params['label_type'] # 'lidar' or 'camera'
54
+ self.generate_object_center = self.generate_object_center_lidar if self.label_type == "lidar" \
55
+ else self.generate_object_center_camera
56
+
57
+ if self.load_camera_file:
58
+ self.data_aug_conf = params["fusion"]["args"]["data_aug_conf"]
59
+
60
+ if self.train:
61
+ split_dir = params['root_dir']
62
+ else:
63
+ split_dir = params['validate_dir']
64
+
65
+ self.root_dir = params['data_dir']
66
+
67
+ self.split_info = read_json(split_dir)
68
+ co_datainfo = read_json(os.path.join(self.root_dir, 'cooperative/data_info.json'))
69
+ self.co_data = OrderedDict()
70
+ for frame_info in co_datainfo:
71
+ veh_frame_id = frame_info['vehicle_image_path'].split("/")[-1].replace(".jpg", "")
72
+ self.co_data[veh_frame_id] = frame_info
73
+
74
+ if "noise_setting" not in self.params:
75
+ self.params['noise_setting'] = OrderedDict()
76
+ self.params['noise_setting']['add_noise'] = False
77
+
78
+ def reinitialize(self):
79
+ pass
80
+
81
+ def retrieve_base_data(self, idx):
82
+ """
83
+ Given the index, return the corresponding data.
84
+ NOTICE!
85
+ It is different from Intermediate Fusion and Early Fusion
86
+ Label is not cooperative and loaded for both veh side and inf side.
87
+ Parameters
88
+ ----------
89
+ idx : int
90
+ Index given by dataloader.
91
+ Returns
92
+ -------
93
+ data : dict
94
+ The dictionary contains loaded yaml params and lidar data for
95
+ each cav.
96
+ """
97
+ veh_frame_id = self.split_info[idx]
98
+ frame_info = self.co_data[veh_frame_id]
99
+ system_error_offset = frame_info["system_error_offset"]
100
+ data = OrderedDict()
101
+
102
+ data[0] = OrderedDict()
103
+ data[0]['ego'] = True
104
+ data[1] = OrderedDict()
105
+ data[1]['ego'] = False
106
+
107
+ data[0]['params'] = OrderedDict()
108
+ data[1]['params'] = OrderedDict()
109
+
110
+ # pose of agent
111
+ lidar_to_novatel = read_json(os.path.join(self.root_dir,'vehicle-side/calib/lidar_to_novatel/'+str(veh_frame_id)+'.json'))
112
+ novatel_to_world = read_json(os.path.join(self.root_dir,'vehicle-side/calib/novatel_to_world/'+str(veh_frame_id)+'.json'))
113
+ transformation_matrix = veh_side_rot_and_trans_to_trasnformation_matrix(lidar_to_novatel, novatel_to_world)
114
+ data[0]['params']['lidar_pose'] = tfm_to_pose(transformation_matrix)
115
+
116
+ inf_frame_id = frame_info['infrastructure_image_path'].split("/")[-1].replace(".jpg", "")
117
+ virtuallidar_to_world = read_json(os.path.join(self.root_dir,'infrastructure-side/calib/virtuallidar_to_world/'+str(inf_frame_id)+'.json'))
118
+ transformation_matrix = inf_side_rot_and_trans_to_trasnformation_matrix(virtuallidar_to_world, system_error_offset)
119
+ data[1]['params']['lidar_pose'] = tfm_to_pose(transformation_matrix)
120
+
121
+ data[0]['params']['vehicles_front'] = read_json(os.path.join(self.root_dir,frame_info['cooperative_label_path'].replace("label_world", "label_world_backup")))
122
+ data[0]['params']['vehicles_all'] = read_json(os.path.join(self.root_dir,frame_info['cooperative_label_path']))
123
+
124
+ data[1]['params']['vehicles_front'] = [] # we only load cooperative label in vehicle side
125
+ data[1]['params']['vehicles_all'] = [] # we only load cooperative label in vehicle side
126
+
127
+ if self.load_camera_file:
128
+ data[0]['camera_data'] = load_camera_data([os.path.join(self.root_dir, frame_info["vehicle_image_path"])])
129
+ data[0]['params']['camera0'] = OrderedDict()
130
+ data[0]['params']['camera0']['extrinsic'] = rot_and_trans_to_trasnformation_matrix( \
131
+ read_json(os.path.join(self.root_dir, 'vehicle-side/calib/lidar_to_camera/'+str(veh_frame_id)+'.json')))
132
+ data[0]['params']['camera0']['intrinsic'] = load_intrinsic_DAIR_V2X( \
133
+ read_json(os.path.join(self.root_dir, 'vehicle-side/calib/camera_intrinsic/'+str(veh_frame_id)+'.json')))
134
+
135
+ data[1]['camera_data']= load_camera_data([os.path.join(self.root_dir,frame_info["infrastructure_image_path"])])
136
+ data[1]['params']['camera0'] = OrderedDict()
137
+ data[1]['params']['camera0']['extrinsic'] = rot_and_trans_to_trasnformation_matrix( \
138
+ read_json(os.path.join(self.root_dir, 'infrastructure-side/calib/virtuallidar_to_camera/'+str(inf_frame_id)+'.json')))
139
+ data[1]['params']['camera0']['intrinsic'] = load_intrinsic_DAIR_V2X( \
140
+ read_json(os.path.join(self.root_dir, 'infrastructure-side/calib/camera_intrinsic/'+str(inf_frame_id)+'.json')))
141
+
142
+
143
+ if self.load_lidar_file or self.visualize:
144
+ data[0]['lidar_np'], _ = pcd_utils.read_pcd(os.path.join(self.root_dir,frame_info["vehicle_pointcloud_path"]))
145
+ data[1]['lidar_np'], _ = pcd_utils.read_pcd(os.path.join(self.root_dir,frame_info["infrastructure_pointcloud_path"]))
146
+
147
+
148
+ # Label for single side
149
+ data[0]['params']['vehicles_single_front'] = read_json(os.path.join(self.root_dir, \
150
+ 'vehicle-side/label/lidar_backup/{}.json'.format(veh_frame_id)))
151
+ data[0]['params']['vehicles_single_all'] = read_json(os.path.join(self.root_dir, \
152
+ 'vehicle-side/label/lidar/{}.json'.format(veh_frame_id)))
153
+ data[1]['params']['vehicles_single_front'] = read_json(os.path.join(self.root_dir, \
154
+ 'infrastructure-side/label/virtuallidar/{}.json'.format(inf_frame_id)))
155
+ data[1]['params']['vehicles_single_all'] = read_json(os.path.join(self.root_dir, \
156
+ 'infrastructure-side/label/virtuallidar/{}.json'.format(inf_frame_id)))
157
+
158
+ if getattr(self, "heterogeneous", False):
159
+ self.generate_object_center_lidar = \
160
+ partial(self.generate_object_center_single_hetero, modality='lidar')
161
+ self.generate_object_center_camera = \
162
+ partial(self.generate_object_center_single_hetero, modality='camera')
163
+
164
+ # by default
165
+ data[0]['modality_name'] = 'm1'
166
+ data[1]['modality_name'] = 'm2'
167
+ # veh cam inf lidar
168
+ data[0]['modality_name'] = 'm2'
169
+ data[1]['modality_name'] = 'm1'
170
+
171
+ if self.train: # randomly choose LiDAR or Camera to be Ego
172
+ p = np.random.rand()
173
+ if p > 0.5:
174
+ data[0], data[1] = data[1], data[0]
175
+ data[0]['ego'] = True
176
+ data[1]['ego'] = False
177
+ else:
178
+ # evaluate, the agent of ego modality should be ego
179
+ if self.adaptor.mapping_dict[data[0]['modality_name']] not in self.ego_modality and \
180
+ self.adaptor.mapping_dict[data[1]['modality_name']] in self.ego_modality:
181
+ data[0], data[1] = data[1], data[0]
182
+ data[0]['ego'] = True
183
+ data[1]['ego'] = False
184
+
185
+ data[0]['modality_name'] = self.adaptor.reassign_cav_modality(data[0]['modality_name'], 0)
186
+ data[1]['modality_name'] = self.adaptor.reassign_cav_modality(data[1]['modality_name'], 1)
187
+
188
+
189
+ return data
190
+
191
+
192
+ def __len__(self):
193
+ return len(self.split_info)
194
+
195
+ def __getitem__(self, idx):
196
+ pass
197
+
198
+
199
+ def generate_object_center_lidar(self,
200
+ cav_contents,
201
+ reference_lidar_pose):
202
+ """
203
+ reference lidar 's coordinate
204
+ """
205
+ for cav_content in cav_contents:
206
+ cav_content['params']['vehicles'] = cav_content['params']['vehicles_all']
207
+ return self.post_processor.generate_object_center_dairv2x(cav_contents,
208
+ reference_lidar_pose)
209
+
210
+ def generate_object_center_camera(self,
211
+ cav_contents,
212
+ reference_lidar_pose):
213
+ """
214
+ reference lidar 's coordinate
215
+ """
216
+ for cav_content in cav_contents:
217
+ cav_content['params']['vehicles'] = cav_content['params']['vehicles_front']
218
+ return self.post_processor.generate_object_center_dairv2x(cav_contents,
219
+ reference_lidar_pose)
220
+
221
+ ### Add new func for single side
222
+ def generate_object_center_single(self,
223
+ cav_contents,
224
+ reference_lidar_pose,
225
+ **kwargs):
226
+ """
227
+ veh or inf 's coordinate.
228
+
229
+ reference_lidar_pose is of no use.
230
+ """
231
+ suffix = "_single"
232
+ for cav_content in cav_contents:
233
+ cav_content['params']['vehicles_single'] = \
234
+ cav_content['params']['vehicles_single_front'] if self.label_type == 'camera' else \
235
+ cav_content['params']['vehicles_single_all']
236
+ return self.post_processor.generate_object_center_dairv2x_single(cav_contents, suffix)
237
+
238
+ ### Add for heterogeneous, transforming the single label from self coord. to ego coord.
239
+ def generate_object_center_single_hetero(self,
240
+ cav_contents,
241
+ reference_lidar_pose,
242
+ modality):
243
+ """
244
+ loading the object from single agent.
245
+
246
+ The same as *generate_object_center_single*, but it will transform the object to reference(ego) coordinate,
247
+ using reference_lidar_pose.
248
+ """
249
+ suffix = "_single"
250
+ for cav_content in cav_contents:
251
+ cav_content['params']['vehicles_single'] = \
252
+ cav_content['params']['vehicles_single_front'] if modality == 'camera' else \
253
+ cav_content['params']['vehicles_single_all']
254
+ return self.post_processor.generate_object_center_dairv2x_single_hetero(cav_contents, reference_lidar_pose, suffix)
255
+
256
+
257
+ def get_ext_int(self, params, camera_id):
258
+ lidar_to_camera = params["camera%d" % camera_id]['extrinsic'].astype(np.float32) # R_cw
259
+ camera_to_lidar = np.linalg.inv(lidar_to_camera) # R_wc
260
+ camera_intrinsic = params["camera%d" % camera_id]['intrinsic'].astype(np.float32
261
+ )
262
+ return camera_to_lidar, camera_intrinsic
263
+
264
+ def augment(self, lidar_np, object_bbx_center, object_bbx_mask):
265
+ """
266
+ Given the raw point cloud, augment by flipping and rotation.
267
+ Parameters
268
+ ----------
269
+ lidar_np : np.ndarray
270
+ (n, 4) shape
271
+ object_bbx_center : np.ndarray
272
+ (n, 7) shape to represent bbx's x, y, z, h, w, l, yaw
273
+ object_bbx_mask : np.ndarray
274
+ Indicate which elements in object_bbx_center are padded.
275
+ """
276
+ tmp_dict = {'lidar_np': lidar_np,
277
+ 'object_bbx_center': object_bbx_center,
278
+ 'object_bbx_mask': object_bbx_mask}
279
+ tmp_dict = self.data_augmentor.forward(tmp_dict)
280
+
281
+ lidar_np = tmp_dict['lidar_np']
282
+ object_bbx_center = tmp_dict['object_bbx_center']
283
+ object_bbx_mask = tmp_dict['object_bbx_mask']
284
+
285
+ return lidar_np, object_bbx_center, object_bbx_mask
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/opv2v_basedataset.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ from collections import OrderedDict
4
+ import cv2
5
+ import h5py
6
+ import torch
7
+ import numpy as np
8
+ from torch.utils.data import Dataset
9
+ from PIL import Image
10
+ import json
11
+ import random
12
+ import opencood.utils.pcd_utils as pcd_utils
13
+ from opencood.data_utils.augmentor.data_augmentor import DataAugmentor
14
+ from opencood.hypes_yaml.yaml_utils import load_yaml
15
+ from opencood.utils.camera_utils import load_camera_data
16
+ from opencood.utils.transformation_utils import x1_to_x2
17
+ from opencood.data_utils.pre_processor import build_preprocessor
18
+ from opencood.data_utils.post_processor import build_postprocessor
19
+
20
+ class OPV2VBaseDataset(Dataset):
21
+ def __init__(self, params, visualize, train=True):
22
+ self.params = params
23
+ self.visualize = visualize
24
+ self.train = train
25
+
26
+ self.pre_processor = build_preprocessor(params["preprocess"], train)
27
+ self.post_processor = build_postprocessor(params["postprocess"], train)
28
+ if 'data_augment' in params: # late and early
29
+ self.data_augmentor = DataAugmentor(params['data_augment'], train)
30
+ else: # intermediate
31
+ self.data_augmentor = None
32
+
33
+ if self.train:
34
+ root_dir = params['root_dir']
35
+ else:
36
+ root_dir = params['validate_dir']
37
+ self.root_dir = root_dir
38
+
39
+ print("Dataset dir:", root_dir)
40
+
41
+ if 'train_params' not in params or \
42
+ 'max_cav' not in params['train_params']:
43
+ self.max_cav = 5
44
+ else:
45
+ self.max_cav = params['train_params']['max_cav']
46
+
47
+ self.load_lidar_file = True if 'lidar' in params['input_source'] or self.visualize else False
48
+ self.load_camera_file = True if 'camera' in params['input_source'] else False
49
+ self.load_depth_file = True if 'depth' in params['input_source'] else False
50
+
51
+ self.label_type = params['label_type'] # 'lidar' or 'camera'
52
+ self.generate_object_center = self.generate_object_center_lidar if self.label_type == "lidar" \
53
+ else self.generate_object_center_camera
54
+ self.generate_object_center_single = self.generate_object_center # will it follows 'self.generate_object_center' when 'self.generate_object_center' change?
55
+
56
+ if self.load_camera_file:
57
+ self.data_aug_conf = params["fusion"]["args"]["data_aug_conf"]
58
+
59
+ # by default, we load lidar, camera and metadata. But users may
60
+ # define additional inputs/tasks
61
+ self.add_data_extension = \
62
+ params['add_data_extension'] if 'add_data_extension' \
63
+ in params else []
64
+
65
+ if "noise_setting" not in self.params:
66
+ self.params['noise_setting'] = OrderedDict()
67
+ self.params['noise_setting']['add_noise'] = False
68
+
69
+ # first load all paths of different scenarios
70
+ scenario_folders = sorted([os.path.join(root_dir, x)
71
+ for x in os.listdir(root_dir) if
72
+ os.path.isdir(os.path.join(root_dir, x))])
73
+
74
+ self.scenario_folders = scenario_folders
75
+
76
+ self.reinitialize()
77
+
78
+
79
+ def reinitialize(self):
80
+ # Structure: {scenario_id : {cav_1 : {timestamp1 : {yaml: path,
81
+ # lidar: path, cameras:list of path}}}}
82
+ self.scenario_database = OrderedDict()
83
+ self.len_record = []
84
+
85
+ # loop over all scenarios
86
+ for (i, scenario_folder) in enumerate(self.scenario_folders):
87
+ self.scenario_database.update({i: OrderedDict()})
88
+
89
+ # at least 1 cav should show up
90
+ if self.train:
91
+ cav_list = [x for x in os.listdir(scenario_folder)
92
+ if os.path.isdir(
93
+ os.path.join(scenario_folder, x))]
94
+ # cav_list = sorted(cav_list)
95
+ random.shuffle(cav_list)
96
+ else:
97
+ cav_list = sorted([x for x in os.listdir(scenario_folder)
98
+ if os.path.isdir(
99
+ os.path.join(scenario_folder, x))])
100
+ assert len(cav_list) > 0
101
+
102
+ """
103
+ roadside unit data's id is always negative, so here we want to
104
+ make sure they will be in the end of the list as they shouldn't
105
+ be ego vehicle.
106
+ """
107
+ if int(cav_list[0]) < 0:
108
+ cav_list = cav_list[1:] + [cav_list[0]]
109
+
110
+ """
111
+ make the first cav to be ego modality
112
+ """
113
+ if getattr(self, "heterogeneous", False):
114
+ scenario_name = scenario_folder.split("/")[-1]
115
+ cav_list = self.adaptor.reorder_cav_list(cav_list, scenario_name)
116
+
117
+
118
+ # loop over all CAV data
119
+ for (j, cav_id) in enumerate(cav_list):
120
+ if j > self.max_cav - 1:
121
+ print('too many cavs reinitialize')
122
+ break
123
+ self.scenario_database[i][cav_id] = OrderedDict()
124
+
125
+ # save all yaml files to the dictionary
126
+ cav_path = os.path.join(scenario_folder, cav_id)
127
+
128
+ yaml_files = \
129
+ sorted([os.path.join(cav_path, x)
130
+ for x in os.listdir(cav_path) if
131
+ x.endswith('.yaml') and 'additional' not in x])
132
+
133
+ # this timestamp is not ready
134
+ yaml_files = [x for x in yaml_files if not ("2021_08_20_21_10_24" in x and "000265" in x)]
135
+
136
+ timestamps = self.extract_timestamps(yaml_files)
137
+
138
+ for timestamp in timestamps:
139
+ self.scenario_database[i][cav_id][timestamp] = \
140
+ OrderedDict()
141
+ yaml_file = os.path.join(cav_path,
142
+ timestamp + '.yaml')
143
+ lidar_file = os.path.join(cav_path,
144
+ timestamp + '.pcd')
145
+ camera_files = self.find_camera_files(cav_path,
146
+ timestamp)
147
+ depth_files = self.find_camera_files(cav_path,
148
+ timestamp, sensor="depth")
149
+
150
+ self.scenario_database[i][cav_id][timestamp]['yaml'] = \
151
+ yaml_file
152
+ self.scenario_database[i][cav_id][timestamp]['lidar'] = \
153
+ lidar_file
154
+ self.scenario_database[i][cav_id][timestamp]['cameras'] = \
155
+ camera_files
156
+ self.scenario_database[i][cav_id][timestamp]['depths'] = \
157
+ depth_files
158
+
159
+ if getattr(self, "heterogeneous", False):
160
+ scenario_name = scenario_folder.split("/")[-1]
161
+
162
+ cav_modality = self.adaptor.reassign_cav_modality(self.modality_assignment[scenario_name][cav_id] , j)
163
+
164
+ self.scenario_database[i][cav_id][timestamp]['modality_name'] = cav_modality
165
+
166
+ self.scenario_database[i][cav_id][timestamp]['lidar'] = \
167
+ self.adaptor.switch_lidar_channels(cav_modality, lidar_file)
168
+
169
+
170
+ # load extra data
171
+ for file_extension in self.add_data_extension:
172
+ file_name = \
173
+ os.path.join(cav_path,
174
+ timestamp + '_' + file_extension)
175
+
176
+ self.scenario_database[i][cav_id][timestamp][
177
+ file_extension] = file_name
178
+
179
+ # Assume all cavs will have the same timestamps length. Thus
180
+ # we only need to calculate for the first vehicle in the
181
+ # scene.
182
+ if j == 0:
183
+ # we regard the agent with the minimum id as the ego
184
+ self.scenario_database[i][cav_id]['ego'] = True
185
+ if not self.len_record:
186
+ self.len_record.append(len(timestamps))
187
+ else:
188
+ prev_last = self.len_record[-1]
189
+ self.len_record.append(prev_last + len(timestamps))
190
+ else:
191
+ self.scenario_database[i][cav_id]['ego'] = False
192
+
193
+
194
+ def retrieve_base_data(self, idx):
195
+ """
196
+ Given the index, return the corresponding data.
197
+
198
+ Parameters
199
+ ----------
200
+ idx : int
201
+ Index given by dataloader.
202
+
203
+ Returns
204
+ -------
205
+ data : dict
206
+ The dictionary contains loaded yaml params and lidar data for
207
+ each cav.
208
+ """
209
+ # we loop the accumulated length list to see get the scenario index
210
+ scenario_index = 0
211
+ for i, ele in enumerate(self.len_record):
212
+ if idx < ele:
213
+ scenario_index = i
214
+ break
215
+ scenario_database = self.scenario_database[scenario_index]
216
+
217
+ # check the timestamp index
218
+ timestamp_index = idx if scenario_index == 0 else \
219
+ idx - self.len_record[scenario_index - 1]
220
+ # retrieve the corresponding timestamp key
221
+ timestamp_key = self.return_timestamp_key(scenario_database,
222
+ timestamp_index)
223
+ data = OrderedDict()
224
+ # load files for all CAVs
225
+ for cav_id, cav_content in scenario_database.items():
226
+ data[cav_id] = OrderedDict()
227
+ data[cav_id]['ego'] = cav_content['ego']
228
+
229
+ # load param file: json is faster than yaml
230
+ json_file = cav_content[timestamp_key]['yaml'].replace("yaml", "json")
231
+ if os.path.exists(json_file):
232
+ with open(json_file, "r") as f:
233
+ data[cav_id]['params'] = json.load(f)
234
+ else:
235
+ data[cav_id]['params'] = \
236
+ load_yaml(cav_content[timestamp_key]['yaml'])
237
+
238
+ # load camera file: hdf5 is faster than png
239
+ hdf5_file = cav_content[timestamp_key]['cameras'][0].replace("camera0.png", "imgs.hdf5")
240
+
241
+ if os.path.exists(hdf5_file):
242
+ with h5py.File(hdf5_file, "r") as f:
243
+ data[cav_id]['camera_data'] = []
244
+ data[cav_id]['depth_data'] = []
245
+ for i in range(4):
246
+ data[cav_id]['camera_data'].append(Image.fromarray(f[f'camera{i}'][()]))
247
+ data[cav_id]['depth_data'].append(Image.fromarray(f[f'depth{i}'][()]))
248
+ else:
249
+ if self.load_camera_file:
250
+ data[cav_id]['camera_data'] = \
251
+ load_camera_data(cav_content[timestamp_key]['cameras'])
252
+ if self.load_depth_file:
253
+ data[cav_id]['depth_data'] = \
254
+ load_camera_data(cav_content[timestamp_key]['depths'])
255
+
256
+ # load lidar file
257
+ if self.load_lidar_file or self.visualize:
258
+ data[cav_id]['lidar_np'] = \
259
+ pcd_utils.pcd_to_np(cav_content[timestamp_key]['lidar'])
260
+
261
+ if getattr(self, "heterogeneous", False):
262
+ data[cav_id]['modality_name'] = cav_content[timestamp_key]['modality_name']
263
+
264
+ for file_extension in self.add_data_extension:
265
+ # if not find in the current directory
266
+ # go to additional folder
267
+ if not os.path.exists(cav_content[timestamp_key][file_extension]):
268
+ cav_content[timestamp_key][file_extension] = cav_content[timestamp_key][file_extension].replace("train","additional/train")
269
+ cav_content[timestamp_key][file_extension] = cav_content[timestamp_key][file_extension].replace("validate","additional/validate")
270
+ cav_content[timestamp_key][file_extension] = cav_content[timestamp_key][file_extension].replace("test","additional/test")
271
+
272
+ if '.yaml' in file_extension:
273
+ data[cav_id][file_extension] = \
274
+ load_yaml(cav_content[timestamp_key][file_extension])
275
+ else:
276
+ data[cav_id][file_extension] = \
277
+ cv2.imread(cav_content[timestamp_key][file_extension])
278
+
279
+
280
+ return data
281
+
282
+ def __len__(self):
283
+ return self.len_record[-1]
284
+
285
+ def __getitem__(self, idx):
286
+ """
287
+ Abstract method, needs to be define by the children class.
288
+ """
289
+ pass
290
+
291
+ @staticmethod
292
+ def extract_timestamps(yaml_files):
293
+ """
294
+ Given the list of the yaml files, extract the mocked timestamps.
295
+
296
+ Parameters
297
+ ----------
298
+ yaml_files : list
299
+ The full path of all yaml files of ego vehicle
300
+
301
+ Returns
302
+ -------
303
+ timestamps : list
304
+ The list containing timestamps only.
305
+ """
306
+ timestamps = []
307
+
308
+ for file in yaml_files:
309
+ res = file.split('/')[-1]
310
+
311
+ timestamp = res.replace('.yaml', '')
312
+ timestamps.append(timestamp)
313
+
314
+ return timestamps
315
+
316
+ @staticmethod
317
+ def return_timestamp_key(scenario_database, timestamp_index):
318
+ """
319
+ Given the timestamp index, return the correct timestamp key, e.g.
320
+ 2 --> '000078'.
321
+
322
+ Parameters
323
+ ----------
324
+ scenario_database : OrderedDict
325
+ The dictionary contains all contents in the current scenario.
326
+
327
+ timestamp_index : int
328
+ The index for timestamp.
329
+
330
+ Returns
331
+ -------
332
+ timestamp_key : str
333
+ The timestamp key saved in the cav dictionary.
334
+ """
335
+ # get all timestamp keys
336
+ timestamp_keys = list(scenario_database.items())[0][1]
337
+ # retrieve the correct index
338
+ timestamp_key = list(timestamp_keys.items())[timestamp_index][0]
339
+
340
+ return timestamp_key
341
+
342
+ @staticmethod
343
+ def find_camera_files(cav_path, timestamp, sensor="camera"):
344
+ """
345
+ Retrieve the paths to all camera files.
346
+
347
+ Parameters
348
+ ----------
349
+ cav_path : str
350
+ The full file path of current cav.
351
+
352
+ timestamp : str
353
+ Current timestamp
354
+
355
+ sensor : str
356
+ "camera" or "depth"
357
+
358
+ Returns
359
+ -------
360
+ camera_files : list
361
+ The list containing all camera png file paths.
362
+ """
363
+ camera0_file = os.path.join(cav_path,
364
+ timestamp + f'_{sensor}0.png')
365
+ camera1_file = os.path.join(cav_path,
366
+ timestamp + f'_{sensor}1.png')
367
+ camera2_file = os.path.join(cav_path,
368
+ timestamp + f'_{sensor}2.png')
369
+ camera3_file = os.path.join(cav_path,
370
+ timestamp + f'_{sensor}3.png')
371
+ return [camera0_file, camera1_file, camera2_file, camera3_file]
372
+
373
+
374
+ def augment(self, lidar_np, object_bbx_center, object_bbx_mask):
375
+ """
376
+ Given the raw point cloud, augment by flipping and rotation.
377
+
378
+ Parameters
379
+ ----------
380
+ lidar_np : np.ndarray
381
+ (n, 4) shape
382
+
383
+ object_bbx_center : np.ndarray
384
+ (n, 7) shape to represent bbx's x, y, z, h, w, l, yaw
385
+
386
+ object_bbx_mask : np.ndarray
387
+ Indicate which elements in object_bbx_center are padded.
388
+ """
389
+ tmp_dict = {'lidar_np': lidar_np,
390
+ 'object_bbx_center': object_bbx_center,
391
+ 'object_bbx_mask': object_bbx_mask}
392
+ tmp_dict = self.data_augmentor.forward(tmp_dict)
393
+
394
+ lidar_np = tmp_dict['lidar_np']
395
+ object_bbx_center = tmp_dict['object_bbx_center']
396
+ object_bbx_mask = tmp_dict['object_bbx_mask']
397
+
398
+ return lidar_np, object_bbx_center, object_bbx_mask
399
+
400
+
401
+ def generate_object_center_lidar(self,
402
+ cav_contents,
403
+ reference_lidar_pose):
404
+ """
405
+ Retrieve all objects in a format of (n, 7), where 7 represents
406
+ x, y, z, l, w, h, yaw or x, y, z, h, w, l, yaw.
407
+ The object_bbx_center is in ego coordinate.
408
+
409
+ Notice: it is a wrap of postprocessor
410
+
411
+ Parameters
412
+ ----------
413
+ cav_contents : list
414
+ List of dictionary, save all cavs' information.
415
+ in fact it is used in get_item_single_car, so the list length is 1
416
+
417
+ reference_lidar_pose : list
418
+ The final target lidar pose with length 6.
419
+
420
+ Returns
421
+ -------
422
+ object_np : np.ndarray
423
+ Shape is (max_num, 7).
424
+ mask : np.ndarray
425
+ Shape is (max_num,).
426
+ object_ids : list
427
+ Length is number of bbx in current sample.
428
+ """
429
+ return self.post_processor.generate_object_center(cav_contents,
430
+ reference_lidar_pose)
431
+
432
+ def generate_object_center_camera(self,
433
+ cav_contents,
434
+ reference_lidar_pose):
435
+ """
436
+ Retrieve all objects in a format of (n, 7), where 7 represents
437
+ x, y, z, l, w, h, yaw or x, y, z, h, w, l, yaw.
438
+ The object_bbx_center is in ego coordinate.
439
+
440
+ Notice: it is a wrap of postprocessor
441
+
442
+ Parameters
443
+ ----------
444
+ cav_contents : list
445
+ List of dictionary, save all cavs' information.
446
+ in fact it is used in get_item_single_car, so the list length is 1
447
+
448
+ reference_lidar_pose : list
449
+ The final target lidar pose with length 6.
450
+
451
+ visibility_map : np.ndarray
452
+ for OPV2V, its 256*256 resolution. 0.39m per pixel. heading up.
453
+
454
+ Returns
455
+ -------
456
+ object_np : np.ndarray
457
+ Shape is (max_num, 7).
458
+ mask : np.ndarray
459
+ Shape is (max_num,).
460
+ object_ids : list
461
+ Length is number of bbx in current sample.
462
+ """
463
+ return self.post_processor.generate_visible_object_center(
464
+ cav_contents, reference_lidar_pose
465
+ )
466
+
467
+ def get_ext_int(self, params, camera_id):
468
+ camera_coords = np.array(params["camera%d" % camera_id]["cords"]).astype(
469
+ np.float32)
470
+ camera_to_lidar = x1_to_x2(
471
+ camera_coords, params["lidar_pose_clean"]
472
+ ).astype(np.float32) # T_LiDAR_camera
473
+ camera_to_lidar = camera_to_lidar @ np.array(
474
+ [[0, 0, 1, 0], [1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 0, 1]],
475
+ dtype=np.float32) # UE4 coord to opencv coord
476
+ camera_intrinsic = np.array(params["camera%d" % camera_id]["intrinsic"]).astype(
477
+ np.float32
478
+ )
479
+ return camera_to_lidar, camera_intrinsic
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/v2xset_basedataset.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from opencood.data_utils.datasets.basedataset.opv2v_basedataset import OPV2VBaseDataset
2
+
3
+ # All the same as OPV2V
4
+ class V2XSETBaseDataset(OPV2VBaseDataset):
5
+ def __init__(self, params, visulize, train=True):
6
+ super().__init__(params, visulize, train)
7
+
8
+ if self.load_camera_file is True: # '2021_09_09_13_20_58'. This scenario has only 3 camera files?
9
+ scenario_folders_new = [x for x in self.scenario_folders if '2021_09_09_13_20_58' not in x]
10
+ self.scenario_folders = scenario_folders_new
11
+ self.reinitialize()
12
+
13
+
14
+ def generate_object_center_camera(self,
15
+ cav_contents,
16
+ reference_lidar_pose):
17
+ """
18
+ Since V2XSet has not release bev_visiblity map, we can only filter object by range.
19
+
20
+ Suppose the detection range of camera is within 50m
21
+ """
22
+ return self.post_processor.generate_object_center_v2xset_camera(
23
+ cav_contents, reference_lidar_pose
24
+ )
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/v2xsim_basedataset.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Author: Yangheng Zhao <[email protected]>
2
+ import os
3
+ import pickle
4
+ from collections import OrderedDict
5
+ from typing import Dict
6
+ from abc import abstractmethod
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+
11
+ from opencood.data_utils.augmentor.data_augmentor import DataAugmentor
12
+ from opencood.utils.common_utils import read_json
13
+ from opencood.utils.transformation_utils import tfm_to_pose
14
+ from opencood.data_utils.pre_processor import build_preprocessor
15
+ from opencood.data_utils.post_processor import build_postprocessor
16
+
17
+ class V2XSIMBaseDataset(Dataset):
18
+ """
19
+ First version.
20
+ Load V2X-sim 2.0 using yifan lu's pickle file.
21
+ Only supports LiDAR data.
22
+ """
23
+
24
+ def __init__(self,
25
+ params: Dict,
26
+ visualize: bool = False,
27
+ train: bool = True):
28
+ self.params = params
29
+ self.visualize = visualize
30
+ self.train = train
31
+
32
+ self.pre_processor = build_preprocessor(params["preprocess"], train)
33
+ self.post_processor = build_postprocessor(params["postprocess"], train)
34
+ if 'data_augment' in params: # late and early
35
+ self.data_augmentor = DataAugmentor(params['data_augment'], train)
36
+ else: # intermediate
37
+ self.data_augmentor = None
38
+
39
+ if self.train:
40
+ root_dir = params['root_dir']
41
+ else:
42
+ root_dir = params['validate_dir']
43
+ self.root_dir = root_dir
44
+
45
+ print("Dataset dir:", root_dir)
46
+
47
+ if 'train_params' not in params or \
48
+ 'max_cav' not in params['train_params']:
49
+ self.max_cav = 5
50
+ else:
51
+ self.max_cav = params['train_params']['max_cav']
52
+
53
+ self.load_lidar_file = True if 'lidar' in params['input_source'] or self.visualize else False
54
+ self.load_camera_file = True if 'camera' in params['input_source'] else False
55
+ self.load_depth_file = True if 'depth' in params['input_source'] else False
56
+
57
+ self.label_type = params['label_type'] # 'lidar' or 'camera'
58
+ assert self.label_type in ['lidar', 'camera']
59
+
60
+ self.generate_object_center = self.generate_object_center_lidar if self.label_type == "lidar" \
61
+ else self.generate_object_center_camera
62
+ self.generate_object_center_single = self.generate_object_center
63
+
64
+ self.add_data_extension = \
65
+ params['add_data_extension'] if 'add_data_extension' \
66
+ in params else []
67
+
68
+ if "noise_setting" not in self.params:
69
+ self.params['noise_setting'] = OrderedDict()
70
+ self.params['noise_setting']['add_noise'] = False
71
+
72
+ with open(self.root_dir, 'rb') as f:
73
+ dataset_info = pickle.load(f)
74
+ self.dataset_info_pkl = dataset_info
75
+
76
+ # TODO param: one as ego or all as ego?
77
+ self.ego_mode = 'one' # "all"
78
+
79
+ self.reinitialize()
80
+
81
+ def reinitialize(self):
82
+ self.scene_database = OrderedDict()
83
+ if self.ego_mode == 'one':
84
+ self.len_record = len(self.dataset_info_pkl)
85
+ else:
86
+ raise NotImplementedError(self.ego_mode)
87
+
88
+ for i, scene_info in enumerate(self.dataset_info_pkl):
89
+ self.scene_database.update({i: OrderedDict()})
90
+ cav_num = scene_info['agent_num']
91
+ assert cav_num > 0
92
+
93
+ if self.train:
94
+ cav_ids = 1 + np.random.permutation(cav_num)
95
+ else:
96
+ cav_ids = list(range(1, cav_num + 1))
97
+
98
+
99
+ for j, cav_id in enumerate(cav_ids):
100
+ if j > self.max_cav - 1:
101
+ print('Too many CAVs; truncating to max_cav during reinitialize.')
102
+ break
103
+
104
+ self.scene_database[i][cav_id] = OrderedDict()
105
+
106
+ self.scene_database[i][cav_id]['ego'] = j==0
107
+
108
+ self.scene_database[i][cav_id]['lidar'] = scene_info[f'lidar_path_{cav_id}']
109
+ # need to delete this line if running in /GPFS
110
+ self.scene_database[i][cav_id]['lidar'] = \
111
+ self.scene_database[i][cav_id]['lidar'].replace("/GPFS/rhome/yifanlu/workspace/dataset/v2xsim2-complete", "dataset/V2X-Sim-2.0")
112
+
113
+ self.scene_database[i][cav_id]['params'] = OrderedDict()
114
+ self.scene_database[i][cav_id][
115
+ 'params']['lidar_pose'] = tfm_to_pose(
116
+ scene_info[f"lidar_pose_{cav_id}"]
117
+ ) # [x, y, z, roll, pitch, yaw]
118
+ self.scene_database[i][cav_id]['params'][
119
+ 'vehicles'] = scene_info[f'labels_{cav_id}'][
120
+ 'gt_boxes_global']
121
+ self.scene_database[i][cav_id]['params'][
122
+ 'object_ids'] = scene_info[f'labels_{cav_id}'][
123
+ 'gt_object_ids'].tolist()
124
+
125
+ def __len__(self) -> int:
126
+ return self.len_record
127
+
128
+ @abstractmethod
129
+ def __getitem__(self, index):
130
+ pass
131
+
132
+ def retrieve_base_data(self, idx):
133
+ """
134
+ Given the index, return the corresponding data.
135
+
136
+ Parameters
137
+ ----------
138
+ idx : int
139
+ Index given by dataloader.
140
+
141
+ Returns
142
+ -------
143
+ data : dict
144
+ The dictionary contains loaded yaml params and lidar data for
145
+ each cav.
146
+ """
147
+
148
+ data = OrderedDict()
149
+ # {
150
+ # 'cav_id0':{
151
+ # 'ego': bool,
152
+ # 'params': {
153
+ # 'lidar_pose': [x, y, z, roll, pitch, yaw],
154
+ # 'vehicles':{
155
+ # 'id': {'angle', 'center', 'extent', 'location'},
156
+ # ...
157
+ # }
158
+ # }, # contains the agent pose information and the object information
159
+ # 'camera_data':,
160
+ # 'depth_data':,
161
+ # 'lidar_np':,
162
+ # ...
163
+ # }
164
+ # 'cav_id1': ,
165
+ # ...
166
+ # }
167
+ scene = self.scene_database[idx]
168
+ for cav_id, cav_content in scene.items():
169
+ data[f'{cav_id}'] = OrderedDict()
170
+ data[f'{cav_id}']['ego'] = cav_content['ego']
171
+
172
+ data[f'{cav_id}']['params'] = cav_content['params']
173
+
174
+ # load the corresponding data into the dictionary
175
+ nbr_dims = 4 # x,y,z,intensity
176
+ scan = np.fromfile(cav_content['lidar'], dtype='float32')
177
+ points = scan.reshape((-1, 5))[:, :nbr_dims]
178
+ data[f'{cav_id}']['lidar_np'] = points
179
+
180
+ return data
181
+
182
+ def generate_object_center_lidar(self, cav_contents, reference_lidar_pose):
183
+ """
184
+ Retrieve all objects in a format of (n, 7), where 7 represents
185
+ x, y, z, l, w, h, yaw or x, y, z, h, w, l, yaw.
186
+
187
+ Notice: it is a wrapper of the postprocessor function.
188
+
189
+ Parameters
190
+ ----------
191
+ cav_contents : list
192
+ List of dictionaries that store all CAVs' information.
193
+ In fact it is used in get_item_single_car, so the list length is 1.
194
+
195
+ reference_lidar_pose : list
196
+ The final target lidar pose with length 6.
197
+
198
+ Returns
199
+ -------
200
+ object_np : np.ndarray
201
+ Shape is (max_num, 7).
202
+ mask : np.ndarray
203
+ Shape is (max_num,).
204
+ object_ids : list
205
+ Length is number of bbx in current sample.
206
+ """
207
+
208
+ return self.post_processor.generate_object_center_v2x(
209
+ cav_contents, reference_lidar_pose)
210
+
211
+ def generate_object_center_camera(self, cav_contents, reference_lidar_pose):
212
+ raise NotImplementedError()
213
+
214
+ def augment(self, lidar_np, object_bbx_center, object_bbx_mask):
215
+ """
216
+ Given the raw point cloud, augment by flipping and rotation.
217
+
218
+ Parameters
219
+ ----------
220
+ lidar_np : np.ndarray
221
+ (n, 4) shape
222
+
223
+ object_bbx_center : np.ndarray
224
+ (n, 7) shape to represent bbx's x, y, z, h, w, l, yaw
225
+
226
+ object_bbx_mask : np.ndarray
227
+ Indicate which elements in object_bbx_center are padded.
228
+ """
229
+ tmp_dict = {'lidar_np': lidar_np,
230
+ 'object_bbx_center': object_bbx_center,
231
+ 'object_bbx_mask': object_bbx_mask}
232
+ tmp_dict = self.data_augmentor.forward(tmp_dict)
233
+
234
+ lidar_np = tmp_dict['lidar_np']
235
+ object_bbx_center = tmp_dict['object_bbx_center']
236
+ object_bbx_mask = tmp_dict['object_bbx_mask']
237
+
238
+ return lidar_np, object_bbx_center, object_bbx_mask
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/basedataset/v2xverse_basedataset.py ADDED
@@ -0,0 +1,1118 @@
1
+
2
+ import os
3
+ from collections import OrderedDict
4
+ import cv2
5
+ import h5py
6
+ import torch
7
+ import torchvision
8
+ import numpy as np
9
+ from torch.utils.data import Dataset
10
+ from PIL import Image
11
+ import json
12
+ import random
13
+ import re
14
+ import math
15
+
16
+ import logging
17
+ _logger = logging.getLogger(__name__)
18
+
19
+ import opencood.utils.pcd_utils as pcd_utils
20
+ from opencood.data_utils.augmentor.data_augmentor import DataAugmentor
21
+ from opencood.hypes_yaml.yaml_utils import load_yaml
22
+ from opencood.utils.camera_utils import load_camera_data
23
+ from opencood.utils.transformation_utils import x1_to_x2
24
+ from opencood.data_utils.pre_processor import build_preprocessor
25
+ from opencood.data_utils.post_processor import build_postprocessor
26
+
27
+
28
+ class V2XVERSEBaseDataset(Dataset):
29
+ def __init__(self, params, visualize, train=True):
30
+ self.params = params
31
+ self.visualize = visualize
32
+ self.train = train
33
+
34
+ self.pre_processor = build_preprocessor(params["preprocess"], train)
35
+ self.post_processor = build_postprocessor(params["postprocess"], train)
36
+ self.data_augmentor = DataAugmentor(params['data_augment'],
37
+ train)
38
+
39
+ self.frame_gap = params.get('frame_gap',200)
40
+ self.time_delay = params.get('time_delay',0)
41
+
42
+ if 'target_assigner_config' in self.params['loss']['args']:
43
+ self.det_range = self.params['loss']['args']['target_assigner_config']['cav_lidar_range'] # [-36, -36, -22, 36, 36, 14]
44
+ else:
45
+ self.det_range = [-36, -36, -22, 36, 36, 14]
46
+
47
+ if self.time_delay % self.frame_gap != 0:
48
+ print("Time delay of v2xverse dataset should be a multiple of frame_gap !")
49
+ self.frame_delay = int(self.time_delay / self.frame_gap)
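+ # Example (hypothetical values): time_delay = 400 with the default frame_gap = 200
+ # gives frame_delay = 2, i.e. collaborators' data is read from two frames earlier.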
50
+ print(f'*** time_delay = {self.time_delay} ***')
51
+
52
+ self.test_flag = False
53
+ if self.train:
54
+ root_dir = params['root_dir']
55
+ towns = [1,2,3,4,6]
56
+ elif not visualize:
57
+ root_dir = params['validate_dir']
58
+ towns = [7,10] # [6,7,8,9,10]
59
+ else:
60
+ root_dir = params['test_dir']
61
+ towns = [5]
62
+ self.test_flag = True
63
+ self.root_dir = root_dir
64
+ self.clock = 0
65
+
66
+ print("Dataset dir:", root_dir)
67
+
68
+ if 'train_params' not in params or \
69
+ 'max_cav' not in params['train_params']:
70
+ self.max_cav = 5
71
+ else:
72
+ self.max_cav = params['train_params']['max_cav']
73
+
74
+ self.load_lidar_file = True if 'lidar' in params['input_source'] or self.visualize else False
75
+ self.load_camera_file = True if 'camera' in params['input_source'] else False
76
+ self.load_depth_file = True if 'depth' in params['input_source'] else False
77
+
78
+ self.label_type = params['label_type'] # 'lidar' or 'camera'
79
+ self.generate_object_center = self.generate_object_center_lidar if self.label_type == "lidar" \
80
+ else self.generate_object_center_camera
81
+ self.generate_object_center_single = self.generate_object_center # will it follow 'self.generate_object_center' when 'self.generate_object_center' changes?
82
+
83
+ if self.load_camera_file:
84
+ self.data_aug_conf = params["fusion"]["args"]["data_aug_conf"]
85
+
86
+ # by default, we load lidar, camera and metadata. But users may
87
+ # define additional inputs/tasks
88
+ self.add_data_extension = \
89
+ params['add_data_extension'] if 'add_data_extension' \
90
+ in params else []
91
+
92
+ if "noise_setting" not in self.params:
93
+ self.params['noise_setting'] = OrderedDict()
94
+ self.params['noise_setting']['add_noise'] = False
95
+
96
+ if root_dir is None:
97
+ print('Not loading from an existing dataset!')
98
+ return
99
+ if not os.path.exists(root_dir):
100
+ print('Dataset path does not exist!')
101
+ return
102
+
103
+ # first load all paths of different scenarios
104
+ scenario_folders = sorted([os.path.join(root_dir, x)
105
+ for x in os.listdir(root_dir) if
106
+ os.path.isdir(os.path.join(root_dir, x))])
107
+ self.scenario_folders = scenario_folders
108
+
109
+ #################################
110
+ ## v2xverse data load
111
+ #################################
112
+
113
+ self.rsu_change_frame = 25
114
+ self.route_frames = []
115
+
116
+ data_index_name = 'dataset_index.txt'
117
+ if 'index_file' in self.params:
118
+ data_index_name = self.params['index_file'] + '.txt'
119
+ print('data_index_name:', data_index_name)
120
+ dataset_indexs = self._load_text(data_index_name).split('\n')
121
+
122
+ filter_file = None
123
+ if 'filte_danger' in self.params:
124
+ if os.path.exists(os.path.join(self.root_dir,self.params['filte_danger'])):
125
+ filter_file = self._load_json(self.params['filte_danger'])
126
+
127
+ weathers = [0,1,2,3,4,5,6,7,8,9,10]
128
+ pattern = re.compile(r'weather-(\d+).*town(\d\d)')
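+ # Illustrative (hypothetical) folder name: 'weather-3_..._town05' would yield weather = 3, town = 5.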
129
+ for line in dataset_indexs:
130
+ if len(line.split()) != 3:
131
+ continue
132
+ path, frames, egos = line.split()
133
+ route_path = os.path.join(self.root_dir, path)
134
+ frames = int(frames)
135
+ res = pattern.findall(path)
136
+ if len(res) != 1:
137
+ continue
138
+ weather = int(res[0][0])
139
+ town = int(res[0][1])
140
+ if weather not in weathers or town not in towns:
141
+ continue
142
+
143
+ files = os.listdir(route_path)
144
+ ego_files = [file for file in files if file.startswith('ego')]
145
+ rsu_files = [file for file in files if file.startswith('rsu')]
146
+
147
+ # recompute rsu change frames
148
+ file_len_list = []
149
+ if len(rsu_files) > 0:
150
+ for rsu_file in ['rsu_1000', 'rsu_1001']:
151
+ if rsu_file in rsu_files:
152
+ rsu_frame_len = len(os.listdir(os.path.join(route_path,rsu_file,'measurements')))
153
+ file_len_list.append(rsu_frame_len)
154
+ self.rsu_change_frame = max(file_len_list) + 1
155
+
156
+ for j, file in enumerate(ego_files):
157
+ ego_path = os.path.join(path, file)
158
+ others_list = ego_files[:j]+ego_files[j+1:]
159
+ others_path_list = []
160
+ for others in others_list:
161
+ others_path_list.append(os.path.join(path, others))
162
+
163
+ for i in range(frames):
164
+ # reduce the ratio of frames not at junction
165
+ if filter_file is not None:
166
+ danger_frame_flag = False
167
+ for route_id in filter_file:
168
+ if route_path.endswith(filter_file[route_id]['sub_path']):
169
+ for junction_range in filter_file[route_id]['selected_frames'][file]:
170
+ if i > junction_range[0] and i < junction_range[1]+15:
171
+ danger_frame_flag = True
172
+ if (not danger_frame_flag):
173
+ continue
174
+ scene_dict = {}
175
+ scene_dict['ego'] = ego_path
176
+ scene_dict['other_egos'] = others_path_list
177
+ scene_dict['num_car'] = len(ego_files)
178
+ scene_dict['rsu'] = []
179
+ # order of rsu
180
+ if i%self.rsu_change_frame != 0 and len(rsu_files)>0:
181
+ order = int(i/self.rsu_change_frame)+1 # int(i/10)+1
182
+ rsu_path = 'rsu_{}00{}'.format(order, ego_path[-1])
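+ # Illustrative example: with rsu_change_frame = 25, frame 30 of an ego folder ending in '1'
+ # resolves to order = 2 and rsu_path = 'rsu_2001'.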
183
+ if True: # os.path.exists(os.path.join(route_path, rsu_path,'measurements','{}.json'.format(str(i).zfill(4)))):
184
+ scene_dict['rsu'].append(os.path.join(path, rsu_path))
185
+
186
+ self.route_frames.append((scene_dict, i)) # (scene_dict, i)
187
+ self.label_mode = self.params.get('label_mode', 'v2xverse')
188
+ self.first_det = False
189
+ print("Sub route dir nums: %d" % len(self.route_frames))
190
+
191
+ def _load_text(self, path):
192
+ text = open(os.path.join(self.root_dir,path), 'r').read()
193
+ return text
194
+
195
+ def _load_image(self, path):
196
+ trans_totensor = torchvision.transforms.ToTensor()
197
+ trans_toPIL = torchvision.transforms.ToPILImage()
198
+ try:
199
+ img = Image.open(os.path.join(self.root_dir,path))
200
+ img_tensor = trans_totensor(img)
201
+ img_PIL = trans_toPIL(img_tensor)
202
+ except Exception as e:
203
+ _logger.info(path)
204
+ n = path[-8:-4]
205
+ new_path = path[:-8] + "%04d.jpg" % (int(n) - 1)
206
+ img = Image.open(os.path.join(self.root_dir,new_path))
207
+ img_tensor = trans_totensor(img)
208
+ img_PIL = trans_toPIL(img_tensor)
209
+ return img_PIL
210
+
211
+ def _load_json(self, path):
212
+ try:
213
+ json_value = json.load(open(os.path.join(self.root_dir,path)))
214
+ except Exception as e:
215
+ _logger.info(path)
216
+ n = path[-9:-5]
217
+ new_path = path[:-9] + "%04d.json" % (int(n) - 1)
218
+ json_value = json.load(open(os.path.join(self.root_dir,new_path)))
219
+ return json_value
220
+
221
+ def _load_npy(self, path):
222
+ try:
223
+ array = np.load(os.path.join(self.root_dir,path), allow_pickle=True)
224
+ except Exception as e:
225
+ _logger.info(path)
226
+ n = path[-8:-4]
227
+ new_path = path[:-8] + "%04d.npy" % (int(n) - 1)
228
+ array = np.load(os.path.join(self.root_dir,new_path), allow_pickle=True)
229
+ return array
230
+
231
+ def get_one_record(self, route_dir, frame_id, agent='ego', visible_actors=None, tpe='all', extra_source=None):
232
+ '''
233
+ Parameters
234
+ ----------
235
+ scene_dict: str, index given by dataloader.
236
+ frame_id: int, frame id.
237
+
238
+ Returns
239
+ -------
240
+ data:
241
+ structure: dict{
242
+ ####################
243
+ # input to the model
244
+ ####################
245
+ 'agent': 'ego' or 'other_ego', # whether it is the ego car
246
+ 'rgb_[direction]': torch.Tensor, # direction in [left, right, center], shape (3, 128, 128)
247
+ 'rgb': torch.Tensor, front rgb image , # shape (3, 224, 224)
248
+ 'measurements': torch.Tensor, size [7]: the first 6 dims are the one-hot vector of the command, and the last dim is the car speed
249
+ 'command': int, 0-5, discrete command signal 0:left, 1:right, 2:straight,
250
+ # 3: lane follow, 4:lane change left, 5: lane change right
251
+ 'pose': np.array, shape(3,), lidar pose[gps_x, gps_y, theta]
252
+ 'detmap_pose': pose for density map
253
+ 'target_point': torch.Tensor, size[2], (x,y) coordinate in the left hand coordinate system,
254
+ where X-axis towards right side of the car
255
+ 'lidar': np.ndarray, # shape (3, 224, 224), 2D projection of lidar, range x:[-28m, 28m], y:[-28m,28m]
256
+ in the right hand coordinate system with X-axis towards left of car
257
+ ####################
258
+ # target of model
259
+ ####################
260
+ 'img_traffic': not yet used in model
261
+ 'command_waypoints': torch.Tensor, size[10,2], 10 (x,y) coordinates in the same coordinate system with target point
262
+ 'is_junction': int, 0 or 1, 1 means the car is at junction
263
+ 'traffic_light_state': int, 0 or 1
264
+ 'det_data': np.array, (400,7), flattened density map, 7 feature dims corresponds to
265
+ [prob_obj, box bias_X, box bias_Y, box_orientation, l, w, speed]
266
+ 'img_traj': not yet used in model
267
+ 'stop_sign': int, 0 or 1, existence of a stop sign
268
+ },
269
+ '''
270
+
271
+ output_record = OrderedDict()
272
+
273
+ if agent == 'ego':
274
+ output_record['ego'] = True
275
+ else:
276
+ output_record['ego'] = False
277
+
278
+ BEV = None
279
+
280
+ if route_dir is not None:
281
+ measurements = self._load_json(os.path.join(route_dir, "measurements", "%04d.json" % frame_id))
282
+ actors_data = self._load_json(os.path.join(route_dir, "actors_data", "%04d.json" % frame_id))
283
+ elif extra_source is not None:
284
+ if 'actors_data' in extra_source:
285
+ actors_data = extra_source['actors_data']
286
+ else:
287
+ actors_data = {}
288
+ measurements = extra_source['measurements']
289
+
290
+ ego_loc = np.array([measurements['x'], measurements['y']])
291
+ output_record['params'] = {}
292
+
293
+ cam_list = ['front','right','left','rear']
294
+ cam_angle_list = [0, 60, -60, 180]
295
+ for cam_id in range(4):
296
+ output_record['params']['camera{}'.format(cam_id)] = {}
297
+ output_record['params']['camera{}'.format(cam_id)]['cords'] = [measurements['x'], measurements['y'], 1.0,\
298
+ 0,measurements['theta']/np.pi*180+cam_angle_list[cam_id],0]
299
+ output_record['params']['camera{}'.format(cam_id)]['extrinsic'] = measurements['camera_{}_extrinsics'.format(cam_list[cam_id])]
300
+ output_record['params']['camera{}'.format(cam_id)]['intrinsic'] = measurements['camera_{}_intrinsics'.format(cam_list[cam_id])]
301
+
302
+ if 'speed' in measurements:
303
+ output_record['params']['ego_speed'] = measurements['speed']*3.6
304
+ else:
305
+ output_record['params']['ego_speed'] = 0
306
+
307
+ output_record['params']['lidar_pose'] = \
308
+ [measurements['lidar_pose_x'], measurements['lidar_pose_y'], 0, \
309
+ 0,measurements['theta']/np.pi*180-90,0]
310
+ self.distance_to_map_center = (self.det_range[3]-self.det_range[0])/2+self.det_range[0]
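+ # Offset from the lidar to the density-map center along the heading; with the default
+ # symmetric range [-36, -36, -22, 36, 36, 14] this evaluates to 0, so the map is centered on the lidar.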
311
+ output_record['params']['map_pose'] = \
312
+ [measurements['lidar_pose_x'] + self.distance_to_map_center*np.cos(measurements["theta"]-np.pi/2),
313
+ measurements['lidar_pose_y'] + self.distance_to_map_center*np.sin(measurements["theta"]-np.pi/2), 0, \
314
+ 0,measurements['theta']/np.pi*180-90,0]
315
+ detmap_pose_x = measurements['lidar_pose_x'] + self.distance_to_map_center*np.cos(measurements["theta"]-np.pi/2)
316
+ detmap_pose_y = measurements['lidar_pose_y'] + self.distance_to_map_center*np.sin(measurements["theta"]-np.pi/2)
317
+ detmap_theta = measurements["theta"] + np.pi/2
318
+ output_record['detmap_pose'] = np.array([-detmap_pose_y, detmap_pose_x, detmap_theta])
319
+ output_record['params']['lidar_pose_clean'] = output_record['params']['lidar_pose']
320
+ output_record['params']['plan_trajectory'] = []
321
+ output_record['params']['true_ego_pos'] = \
322
+ [measurements['lidar_pose_x'], measurements['lidar_pose_y'], 0, \
323
+ 0,measurements['theta']/np.pi*180,0]
324
+ output_record['params']['predicted_ego_pos'] = \
325
+ [measurements['lidar_pose_x'], measurements['lidar_pose_y'], 0, \
326
+ 0,measurements['theta']/np.pi*180,0]
327
+
328
+ if tpe == 'all':
329
+ if route_dir is not None:
330
+ lidar = self._load_npy(os.path.join(route_dir, "lidar", "%04d.npy" % frame_id))
331
+ output_record['rgb_front'] = self._load_image(os.path.join(route_dir, "rgb_front", "%04d.jpg" % frame_id))
332
+ output_record['rgb_left'] = self._load_image(os.path.join(route_dir, "rgb_left", "%04d.jpg" % frame_id))
333
+ output_record['rgb_right'] = self._load_image(os.path.join(route_dir, "rgb_right", "%04d.jpg" % frame_id))
334
+ output_record['rgb_rear'] = self._load_image(os.path.join(route_dir, "rgb_rear", "%04d.jpg" % frame_id))
335
+ if agent != 'rsu':
336
+ BEV = self._load_image(os.path.join(route_dir, "birdview", "%04d.jpg" % frame_id))
337
+ elif extra_source is not None:
338
+ lidar = extra_source['lidar']
339
+ if 'rgb_front' in extra_source:
340
+ output_record['rgb_front'] = extra_source['rgb_front']
341
+ output_record['rgb_left'] = extra_source['rgb_left']
342
+ output_record['rgb_right'] = extra_source['rgb_right']
343
+ output_record['rgb_rear'] = extra_source['rgb_rear']
344
+ else:
345
+ output_record['rgb_front'] = None
346
+ output_record['rgb_left'] = None
347
+ output_record['rgb_right'] = None
348
+ output_record['rgb_rear'] = None
349
+ BEV = None
350
+
351
+ output_record['lidar_np'] = lidar
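+ # The raw lidar is rotated into the right-handed frame used by the dataset
+ # (new x = old y, new y = -old x) and shifted in z by the sensor height
+ # (assuming measurements['lidar_pose_z'] is the lidar height above ground).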
352
+ lidar_transformed = np.zeros((output_record['lidar_np'].shape))
353
+ lidar_transformed[:,0] = output_record['lidar_np'][:,1]
354
+ lidar_transformed[:,1] = -output_record['lidar_np'][:,0]
355
+ lidar_transformed[:,2:] = output_record['lidar_np'][:,2:]
356
+ output_record['lidar_np'] = lidar_transformed.astype(np.float32)
357
+ output_record['lidar_np'][:, 2] += measurements['lidar_pose_z']
358
+
359
+ if visible_actors is not None:
360
+ actors_data = self.filter_actors_data_according_to_visible(actors_data, visible_actors)
361
+
362
+ ################ LSS debug TODO: clean up this function #####################
363
+ if not self.first_det:
364
+ import copy
365
+ if True: # agent=='rsu':
366
+ measurements["affected_light_id"] = -1
367
+ measurements["is_vehicle_present"] = []
368
+ measurements["is_bike_present"] = []
369
+ measurements["is_junction_vehicle_present"] = []
370
+ measurements["is_pedestrian_present"] = []
371
+ measurements["future_waypoints"] = []
372
+ cop3_range = [36,12,12,12, 0.25]
373
+ heatmap = generate_heatmap_multiclass(
374
+ copy.deepcopy(measurements), copy.deepcopy(actors_data), max_distance=36
375
+ )
376
+ self.det_data = (
377
+ generate_det_data_multiclass(
378
+ heatmap, copy.deepcopy(measurements), copy.deepcopy(actors_data), cop3_range
379
+ )
380
+ .reshape(3, int((cop3_range[0]+cop3_range[1])/cop3_range[4]
381
+ *(cop3_range[2]+cop3_range[3])/cop3_range[4]), -1) #(2, H*W,7)
382
+ .astype(np.float32)
383
+ )
384
+ self.first_det = True
385
+ if self.label_mode == 'cop3':
386
+ self.first_det = False
387
+ output_record['det_data'] = self.det_data
388
+ ##############################################################
389
+ if agent == 'rsu' :
390
+ for actor_id in actors_data.keys():
391
+ if actors_data[actor_id]['tpe'] == 0:
392
+ box = actors_data[actor_id]['box']
393
+ if abs(box[0]-0.8214) < 0.01 and abs(box[1]-0.18625) < 0.01 :
394
+ actors_data[actor_id]['tpe'] = 3
395
+
396
+ output_record['params']['vehicles'] = {}
397
+ for actor_id in actors_data.keys():
398
+
399
+ ######################
400
+ ## debug
401
+ ######################
402
+ # if agent == 'ego':
403
+ # continue
404
+
405
+ if tpe in [0, 1, 3]:
406
+ if actors_data[actor_id]['tpe'] != tpe:
407
+ continue
408
+
409
+ # exclude ego car
410
+ loc_actor = np.array(actors_data[actor_id]['loc'][0:2])
411
+ dis = np.linalg.norm(ego_loc - loc_actor)
412
+ if dis < 0.1:
413
+ continue
414
+
415
+ if not ('box' in actors_data[actor_id].keys() and 'ori' in actors_data[actor_id].keys() and 'loc' in actors_data[actor_id].keys()):
416
+ continue
417
+ output_record['params']['vehicles'][actor_id] = {}
418
+ output_record['params']['vehicles'][actor_id]['tpe'] = actors_data[actor_id]['tpe']
419
+ yaw = math.degrees(math.atan(actors_data[actor_id]['ori'][1]/actors_data[actor_id]['ori'][0]))
420
+ pitch = math.degrees(math.asin(actors_data[actor_id]['ori'][2]))
421
+ output_record['params']['vehicles'][actor_id]['angle'] = [0,yaw,pitch]
422
+ output_record['params']['vehicles'][actor_id]['center'] = [0,0,actors_data[actor_id]['box'][2]]
423
+ output_record['params']['vehicles'][actor_id]['extent'] = actors_data[actor_id]['box']
424
+ output_record['params']['vehicles'][actor_id]['location'] = [actors_data[actor_id]['loc'][0],actors_data[actor_id]['loc'][1],0]
425
+ output_record['params']['vehicles'][actor_id]['speed'] = 3.6 * math.sqrt(actors_data[actor_id]['vel'][0]**2+actors_data[actor_id]['vel'][1]**2 )
426
+
427
+ direction_list = ['front','left','right','rear']
428
+ theta_list = [0,-60,60,180]
429
+ dis_list = [0,0,0,-2.6]
430
+ camera_data_list = []
431
+ for i, direction in enumerate(direction_list):
432
+ if 'rgb_{}'.format(direction) in output_record:
433
+ camera_data_list.append(output_record['rgb_{}'.format(direction)])
434
+ dis_to_lidar = dis_list[i]
435
+ output_record['params']['camera{}'.format(i)]['cords'] = \
436
+ [measurements['x'] + dis_to_lidar*np.sin(measurements['theta']), measurements['y'] - dis_to_lidar*np.cos(measurements['theta']), 2.3,\
437
+ 0,measurements['theta']/np.pi*180 - 90 + theta_list[i],0]
438
+ output_record['params']['camera{}'.format(i)]['extrinsic'] = measurements['camera_{}_extrinsics'.format(direction_list[i])]
439
+ output_record['params']['camera{}'.format(i)]['intrinsic'] = measurements['camera_{}_intrinsics'.format(direction_list[i])]
440
+ output_record['camera_data'] = camera_data_list
441
+ bev_visibility_np = 255*np.ones((256,256,3), dtype=np.uint8)
442
+ output_record['bev_visibility.png'] = bev_visibility_np
443
+
444
+ if agent != 'rsu':
445
+ output_record['BEV'] = BEV
446
+ else:
447
+ output_record['BEV'] = None
448
+ return output_record
449
+
450
+ def filter_actors_data_according_to_visible(self, actors_data, visible_actors):
451
+ to_del_id = []
452
+ for actors_id in actors_data.keys():
453
+ if actors_id in visible_actors:
454
+ continue
455
+ to_del_id.append(actors_id)
456
+ for actors_id in to_del_id:
457
+ del actors_data[actors_id]
458
+ return actors_data
459
+
460
+ def get_visible_actors_one_term(self, route_dir, frame_id):
461
+ cur_visible_actors = []
462
+ actors_data = self._load_json(os.path.join(route_dir, "actors_data", "%04d.json" % frame_id))
463
+
464
+ for actors_id in actors_data:
465
+ if actors_data[actors_id]['tpe']==2:
466
+ continue
467
+ if not 'lidar_visible' in actors_data[actors_id]:
468
+ cur_visible_actors.append(actors_id)
469
+ print('Missing lidar_visible field!')
470
+ continue
471
+ if actors_data[actors_id]['lidar_visible']==1:
472
+ cur_visible_actors.append(actors_id)
473
+ return cur_visible_actors
474
+
475
+ def get_visible_actors(self, scene_dict, frame_id):
476
+ visible_actors = {} # id only
477
+ if self.test_flag:
478
+ visible_actors['car_0'] = None
479
+ for i, route_dir in enumerate(scene_dict['other_egos']):
480
+ visible_actors['car_{}'.format(i+1)] = None
481
+ for i, rsu_dir in enumerate(scene_dict['rsu']):
482
+ visible_actors['rsu_{}'.format(i)] = None
483
+ else:
484
+ visible_actors['car_0'] = self.get_visible_actors_one_term(scene_dict['ego'], frame_id)
485
+ if self.params['train_params']['max_cav'] > 1:
486
+ for i, route_dir in enumerate(scene_dict['other_egos']):
487
+ visible_actors['car_{}'.format(i+1)] = self.get_visible_actors_one_term(route_dir, frame_id)
488
+ for i, rsu_dir in enumerate(scene_dict['rsu']):
489
+ visible_actors['rsu_{}'.format(i)] = self.get_visible_actors_one_term(rsu_dir, frame_id)
490
+ for keys in visible_actors:
491
+ visible_actors[keys] = list(set(visible_actors[keys]))
492
+ return visible_actors
493
+
494
+ def retrieve_base_data(self, idx, tpe='all', extra_source=None, data_dir=None):
495
+ if extra_source is None:
496
+ if data_dir is not None:
497
+ scene_dict, frame_id = data_dir
498
+ else:
499
+ scene_dict, frame_id = self.route_frames[idx]
500
+ frame_id_latency = frame_id - self.frame_delay
501
+ visible_actors = None
502
+ visible_actors = self.get_visible_actors(scene_dict, frame_id)
503
+ data = OrderedDict()
504
+ data['car_0'] = self.get_one_record(scene_dict['ego'], frame_id , agent='ego', visible_actors=visible_actors['car_0'], tpe=tpe)
505
+ if self.params['train_params']['max_cav'] > 1:
506
+ for i, route_dir in enumerate(scene_dict['other_egos']):
507
+ try:
508
+ data['car_{}'.format(i+1)] = self.get_one_record(route_dir, frame_id_latency , agent='other_ego', visible_actors=visible_actors['car_{}'.format(i+1)], tpe=tpe)
509
+ except:
510
+ print('load other ego failed')
511
+ continue
512
+ if self.params['train_params']['max_cav'] > 2:
513
+ for i, rsu_dir in enumerate(scene_dict['rsu']):
514
+ try:
515
+ data['rsu_{}'.format(i)] = self.get_one_record(rsu_dir, frame_id_latency, agent='rsu', visible_actors=visible_actors['rsu_{}'.format(i)], tpe=tpe)
516
+ except:
517
+ print('load rsu failed')
518
+ continue
519
+ else:
520
+ data = OrderedDict()
521
+ scene_dict = None
522
+ frame_id = None
523
+ data['car_0'] = self.get_one_record(route_dir=None, frame_id=None , agent='ego', visible_actors=None, tpe=tpe, extra_source=extra_source['car_data'][0])
524
+ if self.params['train_params']['max_cav'] > 1:
525
+ if len(extra_source['car_data']) > 1:
526
+ for i in range(len(extra_source['car_data'])-1):
527
+ data['car_{}'.format(i+1)] = self.get_one_record(route_dir=None, frame_id=None , agent='other_ego', visible_actors=None, tpe=tpe, extra_source=extra_source['car_data'][i+1])
528
+ for i in range(len(extra_source['rsu_data'])):
529
+ data['rsu_{}'.format(i)] = self.get_one_record(route_dir=None, frame_id=None , agent='rsu', visible_actors=None, tpe=tpe, extra_source=extra_source['rsu_data'][i])
530
+ data['car_0']['scene_dict'] = scene_dict
531
+ data['car_0']['frame_id'] = frame_id
532
+ return data
533
+
534
+
535
+ def __len__(self):
536
+ return len(self.route_frames)
537
+
538
+ def __getitem__(self, idx):
539
+ """
540
+ Abstract method, needs to be defined by the child class.
541
+ """
542
+ pass
543
+
544
+ @staticmethod
545
+ def extract_timestamps(yaml_files):
546
+ """
547
+ Given the list of the yaml files, extract the mocked timestamps.
548
+
549
+ Parameters
550
+ ----------
551
+ yaml_files : list
552
+ The full path of all yaml files of ego vehicle
553
+
554
+ Returns
555
+ -------
556
+ timestamps : list
557
+ The list containing timestamps only.
558
+ """
559
+ timestamps = []
560
+
561
+ for file in yaml_files:
562
+ res = file.split('/')[-1]
563
+
564
+ timestamp = res.replace('.yaml', '')
565
+ timestamps.append(timestamp)
566
+
567
+ return timestamps
568
+
569
+ @staticmethod
570
+ def return_timestamp_key(scenario_database, timestamp_index):
571
+ """
572
+ Given the timestamp index, return the correct timestamp key, e.g.
573
+ 2 --> '000078'.
574
+
575
+ Parameters
576
+ ----------
577
+ scenario_database : OrderedDict
578
+ The dictionary contains all contents in the current scenario.
579
+
580
+ timestamp_index : int
581
+ The index for timestamp.
582
+
583
+ Returns
584
+ -------
585
+ timestamp_key : str
586
+ The timestamp key saved in the cav dictionary.
587
+ """
588
+ # get all timestamp keys
589
+ timestamp_keys = list(scenario_database.items())[0][1]
590
+ # retrieve the correct index
591
+ timestamp_key = list(timestamp_keys.items())[timestamp_index][0]
592
+
593
+ return timestamp_key
594
+
595
+ @staticmethod
596
+ def find_camera_files(cav_path, timestamp, sensor="camera"):
597
+ """
598
+ Retrieve the paths to all camera files.
599
+
600
+ Parameters
601
+ ----------
602
+ cav_path : str
603
+ The full file path of current cav.
604
+
605
+ timestamp : str
606
+ Current timestamp
607
+
608
+ sensor : str
609
+ "camera" or "depth"
610
+
611
+ Returns
612
+ -------
613
+ camera_files : list
614
+ The list containing all camera png file paths.
615
+ """
616
+ camera0_file = os.path.join(cav_path,
617
+ timestamp + f'_{sensor}0.png')
618
+ camera1_file = os.path.join(cav_path,
619
+ timestamp + f'_{sensor}1.png')
620
+ camera2_file = os.path.join(cav_path,
621
+ timestamp + f'_{sensor}2.png')
622
+ camera3_file = os.path.join(cav_path,
623
+ timestamp + f'_{sensor}3.png')
624
+ return [camera0_file, camera1_file, camera2_file, camera3_file]
625
+
626
+
627
+ def augment(self, lidar_np, object_bbx_center, object_bbx_mask):
628
+ """
629
+ Given the raw point cloud, augment by flipping and rotation.
630
+
631
+ Parameters
632
+ ----------
633
+ lidar_np : np.ndarray
634
+ (n, 4) shape
635
+
636
+ object_bbx_center : np.ndarray
637
+ (n, 7) shape to represent bbx's x, y, z, h, w, l, yaw
638
+
639
+ object_bbx_mask : np.ndarray
640
+ Indicate which elements in object_bbx_center are padded.
641
+ """
642
+ tmp_dict = {'lidar_np': lidar_np,
643
+ 'object_bbx_center': object_bbx_center,
644
+ 'object_bbx_mask': object_bbx_mask}
645
+ tmp_dict = self.data_augmentor.forward(tmp_dict)
646
+
647
+ lidar_np = tmp_dict['lidar_np']
648
+ object_bbx_center = tmp_dict['object_bbx_center']
649
+ object_bbx_mask = tmp_dict['object_bbx_mask']
650
+
651
+ return lidar_np, object_bbx_center, object_bbx_mask
652
+
653
+
654
+ def generate_object_center_lidar(self,
655
+ cav_contents,
656
+ reference_lidar_pose):
657
+ """
658
+ Retrieve all objects in a format of (n, 7), where 7 represents
659
+ x, y, z, l, w, h, yaw or x, y, z, h, w, l, yaw.
660
+ The object_bbx_center is in ego coordinate.
661
+
662
+ Notice: it is a wrapper of the postprocessor.
663
+
664
+ Parameters
665
+ ----------
666
+ cav_contents : list
667
+ List of dictionary, save all cavs' information.
668
+ in fact it is used in get_item_single_car, so the list length is 1
669
+
670
+ reference_lidar_pose : list
671
+ The final target lidar pose with length 6.
672
+
673
+ Returns
674
+ -------
675
+ object_np : np.ndarray
676
+ Shape is (max_num, 7).
677
+ mask : np.ndarray
678
+ Shape is (max_num,).
679
+ object_ids : list
680
+ Length is number of bbx in current sample.
681
+ """
682
+ return self.post_processor.generate_object_center(cav_contents,
683
+ reference_lidar_pose)
684
+
685
+ def generate_object_center_camera(self,
686
+ cav_contents,
687
+ reference_lidar_pose):
688
+ """
689
+ Retrieve all objects in a format of (n, 7), where 7 represents
690
+ x, y, z, l, w, h, yaw or x, y, z, h, w, l, yaw.
691
+ The object_bbx_center is in ego coordinate.
692
+
693
+ Notice: it is a wrapper of the postprocessor.
694
+
695
+ Parameters
696
+ ----------
697
+ cav_contents : list
698
+ List of dictionary, save all cavs' information.
699
+ in fact it is used in get_item_single_car, so the list length is 1
700
+
701
+ reference_lidar_pose : list
702
+ The final target lidar pose with length 6.
703
+
704
+ visibility_map : np.ndarray
705
+ for OPV2V, its 256*256 resolution. 0.39m per pixel. heading up.
706
+
707
+ Returns
708
+ -------
709
+ object_np : np.ndarray
710
+ Shape is (max_num, 7).
711
+ mask : np.ndarray
712
+ Shape is (max_num,).
713
+ object_ids : list
714
+ Length is number of bbx in current sample.
715
+ """
716
+ return self.post_processor.generate_visible_object_center(
717
+ cav_contents, reference_lidar_pose
718
+ )
719
+
720
+ def get_ext_int(self, params, camera_id):
721
+ if self.params['extrinsic'] == 1:
722
+ return self.get_ext_int_1(params, camera_id)
723
+ elif self.params['extrinsic'] == 2:
724
+ return self.get_ext_int_2(params, camera_id)
725
+ def get_ext_int_1(self, params, camera_id):
726
+ camera_coords = np.array(params["camera%d" % camera_id]["cords"]).astype(
727
+ np.float32)
728
+ camera_to_lidar = x1_to_x2(
729
+ camera_coords, params["lidar_pose_clean"]
730
+ ).astype(np.float32) # T_LiDAR_camera
731
+ camera_to_lidar = camera_to_lidar @ np.array(
732
+ [[0, 0, 1, 0], [1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 0, 1]],
733
+ dtype=np.float32) # UE4 coord to opencv coord
734
+ camera_intrinsic = np.array(params["camera%d" % camera_id]["intrinsic"]).astype(
735
+ np.float32
736
+ )
737
+ return camera_to_lidar, camera_intrinsic
738
+ def get_ext_int_2(self, params, camera_id):
739
+ camera_extrinsic = np.array(params["camera%d" % camera_id]["extrinsic"]).astype(
740
+ np.float32)
741
+ camera_extrinsic = camera_extrinsic @ np.array(
742
+ [[0, 0, 1, 0], [1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 0, 1]],
743
+ dtype=np.float32) # UE4 coord to opencv coord
744
+ camera_intrinsic = np.array(params["camera%d" % camera_id]["intrinsic"]).astype(
745
+ np.float32
746
+ )
747
+ return camera_extrinsic, camera_intrinsic
748
+ VALUES = [255]
749
+ EXTENT = [0]
750
+ def generate_heatmap_multiclass(measurements, actors_data, max_distance=30, pixels_per_meter=8):
751
+ actors_data_multiclass = {
752
+ 0: {}, 1: {}, 2:{}, 3:{}
753
+ }
754
+ for _id in actors_data.keys():
755
+ actors_data_multiclass[actors_data[_id]['tpe']][_id] = actors_data[_id]
756
+ heatmap_0 = generate_heatmap(measurements, actors_data_multiclass[0], max_distance, pixels_per_meter)
757
+ heatmap_1 = generate_heatmap(measurements, actors_data_multiclass[1], max_distance, pixels_per_meter)
758
+ # heatmap_2 = generate_heatmap(measurements, actors_data_multiclass[2], max_distance, pixels_per_meter) # traffic light, not used
759
+ heatmap_3 = generate_heatmap(measurements, actors_data_multiclass[3], max_distance, pixels_per_meter)
760
+ return {0: heatmap_0, 1: heatmap_1, 3: heatmap_3}
761
+
762
+ def get_yaw_angle(forward_vector):
763
+ forward_vector = forward_vector / np.linalg.norm(forward_vector)
764
+ yaw = math.acos(forward_vector[0])
765
+ if forward_vector[1] < 0:
766
+ yaw = 2 * np.pi - yaw
767
+ return yaw
768
+
769
+ def generate_heatmap(measurements, actors_data, max_distance=30, pixels_per_meter=8):
770
+ img_size = max_distance * pixels_per_meter * 2
771
+ img = np.zeros((img_size, img_size, 3), np.int32)
772
+ ego_x = measurements["lidar_pose_x"]
773
+ ego_y = measurements["lidar_pose_y"]
774
+ ego_theta = measurements["theta"]
775
+ R = np.array(
776
+ [
777
+ [np.cos(ego_theta), -np.sin(ego_theta)],
778
+ [np.sin(ego_theta), np.cos(ego_theta)],
779
+ ]
780
+ )
781
+ ego_id = None
782
+ for _id in actors_data:
783
+ color = np.array([1, 1, 1])
784
+ if actors_data[_id]["tpe"] == 2:
785
+ if int(_id) == int(measurements["affected_light_id"]):
786
+ if actors_data[_id]["sta"] == 0:
787
+ color = np.array([1, 1, 1])
788
+ else:
789
+ color = np.array([0, 0, 0])
790
+ yaw = get_yaw_angle(actors_data[_id]["ori"])
791
+ TR = np.array([[np.cos(yaw), np.sin(yaw)], [-np.sin(yaw), np.cos(yaw)]])
792
+ actors_data[_id]["loc"] = np.array(
793
+ actors_data[_id]["loc"][:2]
794
+ ) + TR.T.dot(np.array(actors_data[_id]["taigger_loc"])[:2])
795
+ actors_data[_id]["ori"] = np.array(actors_data[_id]["ori"])
796
+ actors_data[_id]["box"] = np.array(actors_data[_id]["trigger_box"]) * 2
797
+ else:
798
+ continue
799
+ raw_loc = actors_data[_id]["loc"]
800
+ if (raw_loc[0] - ego_x) ** 2 + (raw_loc[1] - ego_y) ** 2 <= 2:
801
+ ego_id = _id
802
+ color = np.array([0, 1, 1])
803
+ new_loc = R.T.dot(np.array([raw_loc[0] - ego_x, raw_loc[1] - ego_y]))
804
+ actors_data[_id]["loc"] = np.array(new_loc)
805
+ raw_ori = actors_data[_id]["ori"]
806
+ new_ori = R.T.dot(np.array([raw_ori[0], raw_ori[1]]))
807
+ actors_data[_id]["ori"] = np.array(new_ori)
808
+ actors_data[_id]["box"] = np.array(actors_data[_id]["box"])
809
+ if int(_id) in measurements["is_vehicle_present"]:
810
+ color = np.array([1, 1, 1])
811
+ elif int(_id) in measurements["is_bike_present"]:
812
+ color = np.array([1, 1, 1])
813
+ elif int(_id) in measurements["is_junction_vehicle_present"]:
814
+ color = np.array([1, 1, 1])
815
+ elif int(_id) in measurements["is_pedestrian_present"]:
816
+ color = np.array([1, 1, 1])
817
+ actors_data[_id]["color"] = color
818
+
819
+ if ego_id is not None and ego_id in actors_data:
820
+ del actors_data[ego_id] # Do not show ego car
821
+ for _id in actors_data:
822
+ if actors_data[_id]["tpe"] == 2:
823
+ continue # FIXME do not add traffic light
824
+ if int(_id) != int(measurements["affected_light_id"]):
825
+ continue
826
+ if actors_data[_id]["sta"] != 0:
827
+ continue
828
+ act_img = np.zeros((img_size, img_size, 3), np.uint8)
829
+ loc = actors_data[_id]["loc"][:2]
830
+ ori = actors_data[_id]["ori"][:2]
831
+ box = actors_data[_id]["box"]
832
+ if box[0] < 1.5:
833
+ box = box * 1.5 # FIXME enlarge the size of pedestrian and bike
834
+ color = actors_data[_id]["color"]
835
+ for i in range(len(VALUES)):
836
+ act_img = add_rect(
837
+ act_img,
838
+ loc,
839
+ ori,
840
+ box + EXTENT[i],
841
+ VALUES[i],
842
+ pixels_per_meter,
843
+ max_distance,
844
+ color,
845
+ )
846
+ act_img = np.clip(act_img, 0, 255)
847
+ img = img + act_img
848
+ img = np.clip(img, 0, 255)
849
+ img = img.astype(np.uint8)
850
+ img = img[:, :, 0]
851
+ return img
852
+
853
+ def add_rect(img, loc, ori, box, value, pixels_per_meter, max_distance, color):
854
+ img_size = max_distance * pixels_per_meter * 2
855
+ vet_ori = np.array([-ori[1], ori[0]])
856
+ hor_offset = box[0] * ori
857
+ vet_offset = box[1] * vet_ori
858
+ left_up = (loc + hor_offset + vet_offset + max_distance) * pixels_per_meter
859
+ left_down = (loc + hor_offset - vet_offset + max_distance) * pixels_per_meter
860
+ right_up = (loc - hor_offset + vet_offset + max_distance) * pixels_per_meter
861
+ right_down = (loc - hor_offset - vet_offset + max_distance) * pixels_per_meter
862
+ left_up = np.around(left_up).astype(np.int32)
863
+ left_down = np.around(left_down).astype(np.int32)
864
+ right_down = np.around(right_down).astype(np.int32)
865
+ right_up = np.around(right_up).astype(np.int32)
866
+ left_up = list(left_up)
867
+ left_down = list(left_down)
868
+ right_up = list(right_up)
869
+ right_down = list(right_down)
870
+ color = [int(x) for x in value * color]
871
+ cv2.fillConvexPoly(img, np.array([left_up, left_down, right_down, right_up]), color)
872
+ return img
873
+
874
+ def generate_det_data_multiclass(
875
+ heatmap, measurements, actors_data, det_range=[30,10,10,10, 0.8]
876
+ ):
877
+ actors_data_multiclass = {
878
+ 0: {}, 1: {}, 2: {}, 3:{}
879
+ }
880
+ for _id in actors_data.keys():
881
+ actors_data_multiclass[actors_data[_id]['tpe']][_id] = actors_data[_id]
882
+ det_data = []
883
+ for _class in range(4):
884
+ if _class != 2:
885
+ det_data.append(generate_det_data(heatmap[_class], measurements, actors_data_multiclass[_class], det_range))
886
+
887
+ return np.array(det_data)
888
+
889
+ from skimage.measure import block_reduce
890
+
891
+ def generate_det_data(
892
+ heatmap, measurements, actors_data, det_range=[30,10,10,10, 0.8]
893
+ ):
894
+ res = det_range[4]
895
+ max_distance = max(det_range)
896
+ traffic_heatmap = block_reduce(heatmap, block_size=(int(8*res), int(8*res)), func=np.mean)
897
+ traffic_heatmap = np.clip(traffic_heatmap, 0.0, 255.0)
898
+ traffic_heatmap = traffic_heatmap[:int((det_range[0]+det_range[1])/res), int((max_distance-det_range[2])/res):int((max_distance+det_range[3])/res)]
899
+ det_data = np.zeros((int((det_range[0]+det_range[1])/res), int((det_range[2]+det_range[3])/res), 7)) # (50,25,7)
900
+ vertical, horizontal = det_data.shape[:2]
901
+
902
+ ego_x = measurements["lidar_pose_x"]
903
+ ego_y = measurements["lidar_pose_y"]
904
+ ego_theta = measurements["theta"]
905
+ R = np.array(
906
+ [
907
+ [np.cos(ego_theta), -np.sin(ego_theta)],
908
+ [np.sin(ego_theta), np.cos(ego_theta)],
909
+ ]
910
+ )
911
+ need_deleted_ids = []
912
+ for _id in actors_data:
913
+ raw_loc = actors_data[_id]["loc"]
914
+ new_loc = R.T.dot(np.array([raw_loc[0] - ego_x, raw_loc[1] - ego_y]))
915
+ new_loc[1] = -new_loc[1]
916
+ actors_data[_id]["loc"] = np.array(new_loc)
917
+ raw_ori = actors_data[_id]["ori"]
918
+ new_ori = R.T.dot(np.array([raw_ori[0], raw_ori[1]]))
919
+ dis = new_loc[0] ** 2 + new_loc[1] ** 2
920
+ if (
921
+ dis <= 2
922
+ or dis >= (max_distance) ** 2 * 2
923
+ or "box" not in actors_data[_id]
924
+ or actors_data[_id]['tpe'] == 2
925
+ ):
926
+ need_deleted_ids.append(_id)
927
+ continue
928
+ actors_data[_id]["ori"] = np.array(new_ori)
929
+ actors_data[_id]["box"] = np.array(actors_data[_id]["box"])
930
+
931
+ for _id in need_deleted_ids:
932
+ del actors_data[_id]
933
+
934
+ for i in range(vertical): # 50
935
+ for j in range(horizontal): # 25
936
+ if traffic_heatmap[i][j] < 0.05 * 255.0:
937
+ continue
938
+ center_x, center_y = convert_grid_to_xy(i, j, det_range)
939
+ min_dis = 1000
940
+ min_id = None
941
+ for _id in actors_data:
942
+ loc = actors_data[_id]["loc"][:2]
943
+ ori = actors_data[_id]["ori"][:2]
944
+ box = actors_data[_id]["box"]
945
+ dis = (loc[0] - center_x) ** 2 + (loc[1] - center_y) ** 2
946
+ if dis < min_dis:
947
+ min_dis = dis
948
+ min_id = _id
949
+
950
+ if min_id is None:
951
+ continue
952
+
953
+ loc = actors_data[min_id]["loc"][:2]
954
+ ori = actors_data[min_id]["ori"][:2]
955
+ box = actors_data[min_id]["box"]
956
+ theta = (get_yaw_angle(ori) / np.pi + 2) % 2
957
+ speed = np.linalg.norm(actors_data[min_id]["vel"])
958
+
959
+ # prob = np.power(0.5 / max(0.5, np.sqrt(min_dis)), 0.5)
960
+
961
+ det_data[i][j] = np.array(
962
+ [
963
+ 0,
964
+ (loc[0] - center_x) * 3.0,
965
+ (loc[1] - center_y) * 3.0,
966
+ theta / 2.0,
967
+ box[0] / 7.0,
968
+ box[1] / 4.0,
969
+ 0,
970
+ ]
971
+ )
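+ # Cell features follow the [prob_obj, dx, dy, orientation, l, w, speed] layout;
+ # offsets are scaled by 3.0, orientation by 1/2 (theta is in units of pi, range [0, 2)),
+ # and the box extent by 1/7 and 1/4.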
972
+
973
+ heatmap = np.zeros((int((det_range[0]+det_range[1])/res), int((det_range[2]+det_range[3])/res))) # (50,25)
974
+ for _id in actors_data:
975
+ loc = actors_data[_id]["loc"][:2]
976
+ ori = actors_data[_id]["ori"][:2]
977
+ box = actors_data[_id]["box"]
978
+ try:
979
+ x,y = loc
980
+ i,j = convert_xy_to_grid(x,y,det_range)
981
+ i = int(np.around(i))
982
+ j = int(np.around(j))
983
+
984
+ if i < vertical and i > 0 and j > 0 and j < horizontal:
985
+ det_data[i][j][-1] = 1.0
986
+
987
+ ################## Gaussian Heatmap #####################
988
+ w, h = box[:2]/det_range[4]
989
+ heatmap = draw_heatmap(heatmap, h, w, j, i)
990
+ #########################################################
991
+
992
+ # theta = (get_yaw_angle(ori) / np.pi + 2) % 2
993
+ # center_x, center_y = convert_grid_to_xy(i, j, det_range)
994
+
995
+ # det_data[i][j] = np.array(
996
+ # [
997
+ # 0,
998
+ # (loc[0] - center_x) * 3.0,
999
+ # (loc[1] - center_y) * 3.0,
1000
+ # theta / 2.0,
1001
+ # box[0] / 7.0,
1002
+ # box[1] / 4.0,
1003
+ # 0,
1004
+ # ]
1005
+ # )
1006
+
1007
+ except:
1008
+ print('actor data error, skip!')
1009
+ det_data[:,:,0] = heatmap
1010
+ return det_data
1011
+
1012
+ def convert_grid_to_xy(i, j, det_range):
1013
+ x = det_range[4]*(j + 0.5) - det_range[2]
1014
+ y = det_range[0] - det_range[4]*(i+0.5)
1015
+ return x, y
1016
+
1017
+ def convert_xy_to_grid(x, y, det_range):
1018
+ j = (x + det_range[2]) / det_range[4] - 0.5
1019
+ i = (det_range[0] - y) / det_range[4] - 0.5
1020
+ return i, j
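+ # Round-trip check with det_range = [36, 12, 12, 12, 0.25]: convert_grid_to_xy(0, 0) gives
+ # (x, y) = (-11.875, 35.875), and convert_xy_to_grid(-11.875, 35.875) returns (0.0, 0.0).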
1021
+
1022
+ def draw_heatmap(heatmap, h, w, x, y):
1023
+ feature_map_size = heatmap.shape
1024
+ radius = gaussian_radius(
1025
+ (h, w),
1026
+ min_overlap=0.1)
1027
+ radius = max(2, int(radius))
1028
+
1029
+ # throw out not in range objects to avoid out of array
1030
+ # area when creating the heatmap
1031
+ if not (0 <= y < feature_map_size[0]
1032
+ and 0 <= x < feature_map_size[1]):
1033
+ return heatmap
1034
+
1035
+ heatmap = draw_gaussian(heatmap, (x,y), radius)
1036
+ return heatmap
1037
+
1038
+ def draw_gaussian(heatmap, center, radius, k=1):
1039
+ """Get gaussian masked heatmap.
1040
+
1041
+ Args:
1042
+ heatmap (np.ndarray): Heatmap to be masked.
1043
+ center (tuple): Center coord (x, y) on the heatmap.
1044
+ radius (int): Radius of the gaussian.
1045
+ k (int): Multiple of masked_gaussian. Defaults to 1.
1046
+
1047
+ Returns:
1048
+ np.ndarray: Masked heatmap.
1049
+ """
1050
+ diameter = 2 * radius + 1
1051
+ gaussian = gaussian_2d((diameter, diameter), sigma=diameter / 6)
1052
+
1053
+ x, y = int(center[0]), int(center[1])
1054
+
1055
+ height, width = heatmap.shape[0:2]
1056
+
1057
+ left, right = min(x, radius), min(width - x, radius + 1)
1058
+ top, bottom = min(y, radius), min(height - y, radius + 1)
1059
+
1060
+ masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
1061
+ masked_gaussian = gaussian[radius - top:radius + bottom,
1062
+ radius - left:radius + right]
1063
+
1064
+ if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
1065
+ # torch.max(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
1066
+ np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
1067
+ # masked_heatmap = np.max([masked_heatmap[None,], (masked_gaussian * k)[None,]], axis=0)[0]
1068
+ # heatmap[y - top:y + bottom, x - left:x + right] = masked_heatmap
1069
+ return heatmap
1070
+
1071
+ def gaussian_2d(shape, sigma=1):
1072
+ """Generate gaussian map.
1073
+
1074
+ Args:
1075
+ shape (list[int]): Shape of the map.
1076
+ sigma (float): Sigma to generate gaussian map.
1077
+ Defaults to 1.
1078
+
1079
+ Returns:
1080
+ np.ndarray: Generated gaussian map.
1081
+ """
1082
+ m, n = [(ss - 1.) / 2. for ss in shape]
1083
+ y, x = np.ogrid[-m:m + 1, -n:n + 1]
1084
+
1085
+ h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
1086
+ h[h < np.finfo(h.dtype).eps * h.max()] = 0
1087
+ return h
1088
+
1089
+ def gaussian_radius(det_size, min_overlap=0.5):
1090
+ """Get radius of gaussian.
1091
+
1092
+ Args:
1093
+ det_size (tuple[float]): Size (height, width) of the detection box.
1094
+ min_overlap (float): Gaussian_overlap. Defaults to 0.5.
1095
+
1096
+ Returns:
1097
+ float: Computed radius.
1098
+ """
1099
+ height, width = det_size
1100
+
1101
+ a1 = 1
1102
+ b1 = (height + width)
1103
+ c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
1104
+ sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
1105
+ r1 = (b1 + sq1) / (2 * a1)
1106
+
1107
+ a2 = 4
1108
+ b2 = 2 * (height + width)
1109
+ c2 = (1 - min_overlap) * width * height
1110
+ sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
1111
+ r2 = (b2 + sq2) / (2 * a2)
1112
+
1113
+ a3 = 4 * min_overlap
1114
+ b3 = -2 * min_overlap * (height + width)
1115
+ c3 = (min_overlap - 1) * width * height
1116
+ sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
1117
+ r3 = (b3 + sq3) / (2 * a3)
1118
+ return min(r1, r2, r3)
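+ # The three candidate radii correspond to the three corner-overlap cases used in
+ # CornerNet/CenterNet-style heatmap target assignment; taking the minimum keeps the
+ # requested min_overlap satisfied in every case.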
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/early_fusion_dataset.py ADDED
@@ -0,0 +1,414 @@
1
+ # early fusion dataset
2
+ import torch
3
+ import numpy as np
4
+ from opencood.utils.pcd_utils import downsample_lidar_minimum
5
+ import math
6
+ from collections import OrderedDict
7
+
8
+ from opencood.utils import box_utils
9
+ from opencood.utils.common_utils import merge_features_to_dict
10
+ from opencood.data_utils.post_processor import build_postprocessor
11
+ from opencood.data_utils.pre_processor import build_preprocessor
12
+ from opencood.hypes_yaml.yaml_utils import load_yaml
13
+ from opencood.utils.pcd_utils import \
14
+ mask_points_by_range, mask_ego_points, shuffle_points, \
15
+ downsample_lidar_minimum
16
+ from opencood.utils.transformation_utils import x1_to_x2
17
+
18
+
19
+ def getEarlyFusionDataset(cls):
20
+ class EarlyFusionDataset(cls):
21
+ """
22
+ This dataset is used for early fusion, where each CAV transmits the raw
23
+ point cloud to the ego vehicle.
24
+ """
25
+ def __init__(self, params, visualize, train=True):
26
+ super(EarlyFusionDataset, self).__init__(params, visualize, train)
27
+ self.supervise_single = True if ('supervise_single' in params['model']['args'] and params['model']['args']['supervise_single']) \
28
+ else False
29
+ assert self.supervise_single is False
30
+ self.proj_first = False if 'proj_first' not in params['fusion']['args']\
31
+ else params['fusion']['args']['proj_first']
32
+ self.anchor_box = self.post_processor.generate_anchor_box()
33
+ self.anchor_box_torch = torch.from_numpy(self.anchor_box)
34
+
35
+ self.heterogeneous = False
36
+ if 'heter' in params:
37
+ self.heterogeneous = True
38
+
39
+ def __getitem__(self, idx):
40
+ base_data_dict = self.retrieve_base_data(idx)
41
+
42
+ processed_data_dict = OrderedDict()
43
+ processed_data_dict['ego'] = {}
44
+
45
+ ego_id = -1
46
+ ego_lidar_pose = []
47
+
48
+ # first find the ego vehicle's lidar pose
49
+ for cav_id, cav_content in base_data_dict.items():
50
+ if cav_content['ego']:
51
+ ego_id = cav_id
52
+ ego_lidar_pose = cav_content['params']['lidar_pose']
53
+ break
54
+
55
+ assert ego_id != -1
56
+ assert len(ego_lidar_pose) > 0
57
+
58
+ projected_lidar_stack = []
59
+ object_stack = []
60
+ object_id_stack = []
61
+
62
+ # loop over all CAVs to process information
63
+ for cav_id, selected_cav_base in base_data_dict.items():
64
+ # check if the cav is within the communication range with ego
65
+ distance = \
66
+ math.sqrt((selected_cav_base['params']['lidar_pose'][0] -
67
+ ego_lidar_pose[0]) ** 2 + (
68
+ selected_cav_base['params'][
69
+ 'lidar_pose'][1] - ego_lidar_pose[
70
+ 1]) ** 2)
71
+ if distance > self.params['comm_range']:
72
+ continue
73
+
74
+ selected_cav_processed = self.get_item_single_car(
75
+ selected_cav_base,
76
+ ego_lidar_pose)
77
+ # all these lidar and object coordinates are projected to ego
78
+ # already.
79
+ projected_lidar_stack.append(
80
+ selected_cav_processed['projected_lidar'])
81
+ object_stack.append(selected_cav_processed['object_bbx_center'])
82
+ object_id_stack += selected_cav_processed['object_ids']
83
+
84
+ # exclude all repetitive objects
85
+ unique_indices = \
86
+ [object_id_stack.index(x) for x in set(object_id_stack)]
87
+ object_stack = np.vstack(object_stack)
88
+ object_stack = object_stack[unique_indices]
89
+
90
+ # make sure bounding boxes across all frames have the same number
91
+ object_bbx_center = \
92
+ np.zeros((self.params['postprocess']['max_num'], 7))
93
+ mask = np.zeros(self.params['postprocess']['max_num'])
94
+ object_bbx_center[:object_stack.shape[0], :] = object_stack
95
+ mask[:object_stack.shape[0]] = 1
96
+
97
+ # convert list to numpy array, (N, 4)
98
+ projected_lidar_stack = np.vstack(projected_lidar_stack)
99
+
100
+ # data augmentation
101
+ projected_lidar_stack, object_bbx_center, mask = \
102
+ self.augment(projected_lidar_stack, object_bbx_center, mask)
103
+
104
+ # filter the stacked lidar points by the configured range
105
+ projected_lidar_stack = mask_points_by_range(projected_lidar_stack,
106
+ self.params['preprocess'][
107
+ 'cav_lidar_range'])
108
+ # augmentation may push some of the bbx out of range, so filter them
109
+ object_bbx_center_valid = object_bbx_center[mask == 1]
110
+ object_bbx_center_valid, range_mask = \
111
+ box_utils.mask_boxes_outside_range_numpy(object_bbx_center_valid,
112
+ self.params['preprocess'][
113
+ 'cav_lidar_range'],
114
+ self.params['postprocess'][
115
+ 'order'],
116
+ return_mask=True
117
+ )
118
+ mask[object_bbx_center_valid.shape[0]:] = 0
119
+ object_bbx_center[:object_bbx_center_valid.shape[0]] = \
120
+ object_bbx_center_valid
121
+ object_bbx_center[object_bbx_center_valid.shape[0]:] = 0
122
+ unique_indices = list(np.array(unique_indices)[range_mask])
123
+
124
+ # pre-process the lidar to voxel/bev/downsampled lidar
125
+ lidar_dict = self.pre_processor.preprocess(projected_lidar_stack)
126
+
127
+ # generate the anchor boxes
128
+ anchor_box = self.post_processor.generate_anchor_box()
129
+
130
+ # generate targets label
131
+ label_dict = \
132
+ self.post_processor.generate_label(
133
+ gt_box_center=object_bbx_center,
134
+ anchors=anchor_box,
135
+ mask=mask)
136
+
137
+ processed_data_dict['ego'].update(
138
+ {'object_bbx_center': object_bbx_center,
139
+ 'object_bbx_mask': mask,
140
+ 'object_ids': [object_id_stack[i] for i in unique_indices],
141
+ 'anchor_box': anchor_box,
142
+ 'processed_lidar': lidar_dict,
143
+ 'label_dict': label_dict})
144
+
145
+ if self.visualize:
146
+ processed_data_dict['ego'].update({'origin_lidar':
147
+ projected_lidar_stack})
148
+
149
+ return processed_data_dict
150
+
151
+ def get_item_single_car(self, selected_cav_base, ego_pose):
152
+ """
153
+ Project the lidar and bbx to ego space first, and then do clipping.
154
+
155
+ Parameters
156
+ ----------
157
+ selected_cav_base : dict
158
+ The dictionary contains a single CAV's raw information.
159
+ ego_pose : list
160
+ The ego vehicle lidar pose under world coordinate.
161
+
162
+ Returns
163
+ -------
164
+ selected_cav_processed : dict
165
+ The dictionary contains the cav's processed information.
166
+ """
167
+ selected_cav_processed = {}
168
+
169
+ # calculate the transformation matrix
170
+ transformation_matrix = \
171
+ x1_to_x2(selected_cav_base['params']['lidar_pose'],
172
+ ego_pose)
173
+
174
+ # retrieve objects under ego coordinates
175
+ object_bbx_center, object_bbx_mask, object_ids = \
176
+ self.generate_object_center([selected_cav_base],
177
+ ego_pose)
178
+
179
+ # filter lidar
180
+ lidar_np = selected_cav_base['lidar_np']
181
+ lidar_np = shuffle_points(lidar_np)
182
+ # remove points that hit itself
183
+ lidar_np = mask_ego_points(lidar_np)
184
+ # project the lidar to ego space
185
+ lidar_np[:, :3] = \
186
+ box_utils.project_points_by_matrix_torch(lidar_np[:, :3],
187
+ transformation_matrix)
188
+
189
+ selected_cav_processed.update(
190
+ {'object_bbx_center': object_bbx_center[object_bbx_mask == 1],
191
+ 'object_ids': object_ids,
192
+ 'projected_lidar': lidar_np})
193
+
194
+ return selected_cav_processed
195
+
196
+ def collate_batch_test(self, batch):
197
+ """
198
+ Customized collate function for pytorch dataloader during testing
199
+ for the early fusion dataset.
200
+
201
+ Parameters
202
+ ----------
203
+ batch : dict
204
+
205
+ Returns
206
+ -------
207
+ batch : dict
208
+ Reformatted batch.
209
+ """
210
+ # currently, we only support batch size of 1 during testing
211
+ assert len(batch) <= 1, "Batch size 1 is required during testing!"
212
+ batch = batch[0] # only ego
213
+
214
+ output_dict = {}
215
+
216
+ for cav_id, cav_content in batch.items():
217
+ output_dict.update({cav_id: {}})
218
+ # shape: (1, max_num, 7)
219
+ object_bbx_center = \
220
+ torch.from_numpy(np.array([cav_content['object_bbx_center']]))
221
+ object_bbx_mask = \
222
+ torch.from_numpy(np.array([cav_content['object_bbx_mask']]))
223
+ object_ids = cav_content['object_ids']
224
+
225
+ # the anchor box is usually the same for all bounding boxes, thus
226
+ # we don't need the batch dimension.
227
+ if cav_content['anchor_box'] is not None:
228
+ output_dict[cav_id].update({'anchor_box':
229
+ torch.from_numpy(np.array(
230
+ cav_content[
231
+ 'anchor_box']))})
232
+ if self.visualize:
233
+ origin_lidar = [cav_content['origin_lidar']]
234
+
235
+ # processed lidar dictionary
236
+ processed_lidar_torch_dict = \
237
+ self.pre_processor.collate_batch(
238
+ [cav_content['processed_lidar']])
239
+ # label dictionary
240
+ label_torch_dict = \
241
+ self.post_processor.collate_batch([cav_content['label_dict']])
242
+
243
+ # save the transformation matrix (4, 4) to ego vehicle
244
+ transformation_matrix_torch = \
245
+ torch.from_numpy(np.identity(4)).float()
246
+ transformation_matrix_clean_torch = \
247
+ torch.from_numpy(np.identity(4)).float()
248
+
249
+ output_dict[cav_id].update({'object_bbx_center': object_bbx_center,
250
+ 'object_bbx_mask': object_bbx_mask,
251
+ 'processed_lidar': processed_lidar_torch_dict,
252
+ 'label_dict': label_torch_dict,
253
+ 'object_ids': object_ids,
254
+ 'transformation_matrix': transformation_matrix_torch,
255
+ 'transformation_matrix_clean': transformation_matrix_clean_torch})
256
+
257
+ if self.visualize:
258
+ origin_lidar = \
259
+ np.array(
260
+ downsample_lidar_minimum(pcd_np_list=origin_lidar))
261
+ origin_lidar = torch.from_numpy(origin_lidar)
262
+ output_dict[cav_id].update({'origin_lidar': origin_lidar})
263
+
264
+ return output_dict
265
+
266
+ def collate_batch_train(self, batch):
267
+ # Intermediate fusion is different from the other two
268
+ output_dict = {'ego': {}}
269
+
270
+ object_bbx_center = []
271
+ object_bbx_mask = []
272
+ object_ids = []
273
+ processed_lidar_list = []
274
+ image_inputs_list = []
275
+ # used to record different scenario
276
+ label_dict_list = []
277
+ origin_lidar = []
278
+
279
+ # heterogeneous
280
+ lidar_agent_list = []
281
+
282
+ # pairwise transformation matrix
283
+ pairwise_t_matrix_list = []
284
+
285
+ ### 2022.10.10 single gt ####
286
+ if self.supervise_single:
287
+ pos_equal_one_single = []
288
+ neg_equal_one_single = []
289
+ targets_single = []
290
+
291
+ for i in range(len(batch)):
292
+ ego_dict = batch[i]['ego']
293
+ object_bbx_center.append(ego_dict['object_bbx_center'])
294
+ object_bbx_mask.append(ego_dict['object_bbx_mask'])
295
+ object_ids.append(ego_dict['object_ids'])
296
+ if self.load_lidar_file:
297
+ processed_lidar_list.append(ego_dict['processed_lidar'])
298
+ if self.load_camera_file:
299
+ image_inputs_list.append(ego_dict['image_inputs']) # different cav_num, ego_dict['image_inputs'] is dict.
300
+
301
+ label_dict_list.append(ego_dict['label_dict'])
302
+
303
+ if self.visualize:
304
+ origin_lidar.append(ego_dict['origin_lidar'])
305
+
306
+ ### 2022.10.10 single gt ####
307
+ if self.supervise_single:
308
+ pos_equal_one_single.append(ego_dict['single_label_dict_torch']['pos_equal_one'])
309
+ neg_equal_one_single.append(ego_dict['single_label_dict_torch']['neg_equal_one'])
310
+ targets_single.append(ego_dict['single_label_dict_torch']['targets'])
311
+
312
+ # heterogeneous
313
+ if self.heterogeneous:
314
+ lidar_agent_list.append(ego_dict['lidar_agent'])
315
+
316
+ # convert to numpy, (B, max_num, 7)
317
+ object_bbx_center = torch.from_numpy(np.array(object_bbx_center))
318
+ object_bbx_mask = torch.from_numpy(np.array(object_bbx_mask))
319
+
320
+ if self.load_lidar_file:
321
+ merged_feature_dict = merge_features_to_dict(processed_lidar_list)
322
+
323
+ if self.heterogeneous:
324
+ lidar_agent = np.concatenate(lidar_agent_list)
325
+ lidar_agent_idx = lidar_agent.nonzero()[0].tolist()
326
+ for k, v in merged_feature_dict.items(): # 'voxel_features' 'voxel_num_points' 'voxel_coords'
327
+ merged_feature_dict[k] = [v[index] for index in lidar_agent_idx]
328
+
329
+ if not self.heterogeneous or (self.heterogeneous and sum(lidar_agent) != 0):
330
+ processed_lidar_torch_dict = \
331
+ self.pre_processor.collate_batch(merged_feature_dict)
332
+ output_dict['ego'].update({'processed_lidar': processed_lidar_torch_dict})
333
+
334
+ if self.load_camera_file:
335
+ merged_image_inputs_dict = merge_features_to_dict(image_inputs_list, merge='cat')
336
+
337
+ if self.heterogeneous:
338
+ camera_agent = 1 - lidar_agent
339
+ camera_agent_idx = camera_agent.nonzero()[0].tolist()
340
+ if sum(camera_agent) != 0:
341
+ for k, v in merged_image_inputs_dict.items(): # 'imgs' 'rots' 'trans' ...
342
+ merged_image_inputs_dict[k] = torch.stack([v[index] for index in camera_agent_idx])
343
+
344
+ if not self.heterogeneous or (self.heterogeneous and sum(camera_agent) != 0):
345
+ output_dict['ego'].update({'image_inputs': merged_image_inputs_dict})
346
+
347
+ label_torch_dict = \
348
+ self.post_processor.collate_batch(label_dict_list)
349
+
350
+ # for centerpoint
351
+ label_torch_dict.update({'object_bbx_center': object_bbx_center,
352
+ 'object_bbx_mask': object_bbx_mask})
353
+
354
+ # (B, max_cav)
355
+ pairwise_t_matrix = torch.from_numpy(np.array(pairwise_t_matrix_list))
356
+
357
+ # add pairwise_t_matrix to label dict
358
+
359
+ # object id is only used during inference, where batch size is 1.
360
+ # so here we only get the first element.
361
+ output_dict['ego'].update({'object_bbx_center': object_bbx_center,
362
+ 'object_bbx_mask': object_bbx_mask,
363
+ 'label_dict': label_torch_dict,
364
+ 'object_ids': object_ids[0]})
365
+
366
+
367
+ if self.visualize:
368
+ origin_lidar = \
369
+ np.array(downsample_lidar_minimum(pcd_np_list=origin_lidar))
370
+ origin_lidar = torch.from_numpy(origin_lidar)
371
+ output_dict['ego'].update({'origin_lidar': origin_lidar})
372
+
373
+ if self.supervise_single:
374
+ output_dict['ego'].update({
375
+ "label_dict_single" :
376
+ {"pos_equal_one": torch.cat(pos_equal_one_single, dim=0),
377
+ "neg_equal_one": torch.cat(neg_equal_one_single, dim=0),
378
+ "targets": torch.cat(targets_single, dim=0)}
379
+ })
380
+
381
+ if self.heterogeneous:
382
+ output_dict['ego'].update({
383
+ "lidar_agent_record": torch.from_numpy(np.concatenate(lidar_agent_list)) # [0,1,1,0,1...]
384
+ })
385
+
386
+ return output_dict
387
+
388
+ def post_process(self, data_dict, output_dict):
389
+ """
390
+ Process the outputs of the model to 2D/3D bounding box.
391
+
392
+ Parameters
393
+ ----------
394
+ data_dict : dict
395
+ The dictionary containing the origin input data of model.
396
+
397
+ output_dict :dict
398
+ The dictionary containing the output of the model.
399
+
400
+ Returns
401
+ -------
402
+ pred_box_tensor : torch.Tensor
403
+ The tensor of prediction bounding box after NMS.
404
+ gt_box_tensor : torch.Tensor
405
+ The tensor of gt bounding box.
406
+ """
407
+ pred_box_tensor, pred_score = \
408
+ self.post_processor.post_process(data_dict, output_dict)
409
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict)
410
+
411
+ return pred_box_tensor, pred_score, gt_box_tensor
412
+
413
+ return EarlyFusionDataset
414
+
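
Note: like the other dataset files in this upload, early_fusion_dataset.py uses a class-factory pattern: getEarlyFusionDataset(cls) returns an EarlyFusionDataset subclass of whichever base dataset class is passed in, so the same fusion logic can be layered on top of different base datasets. Below is a minimal, self-contained sketch of that pattern; the names (get_logging_dataset, ListDataset) are illustrative only and not part of this repository.

def get_logging_dataset(cls):
    # mirrors the getXFusionDataset(cls) factories above: build a subclass
    # of the supplied base class and return the new class object
    class LoggingDataset(cls):
        def __getitem__(self, idx):
            item = super().__getitem__(idx)
            print(f"loaded sample {idx}")
            return item
    return LoggingDataset

class ListDataset:
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

dataset_cls = get_logging_dataset(ListDataset)  # analogous to getEarlyFusionDataset(BaseDataset)
ds = dataset_cls([10, 20, 30])
assert ds[1] == 20                              # prints "loaded sample 1"
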
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/early_multiclass_fusion_dataset.py ADDED
@@ -0,0 +1,899 @@
1
+ # early multiclass fusion dataset
2
+ import random
3
+ import math
4
+ from collections import OrderedDict
5
+ import numpy as np
6
+ import torch
7
+ import copy
8
+ from icecream import ic
9
+ from PIL import Image
10
+ import pickle as pkl
11
+ from opencood.utils import box_utils as box_utils
12
+ from opencood.data_utils.pre_processor import build_preprocessor
13
+ from opencood.data_utils.post_processor import build_postprocessor
14
+ from opencood.utils.camera_utils import (
15
+ sample_augmentation,
16
+ img_transform,
17
+ normalize_img,
18
+ img_to_tensor,
19
+ )
20
+ # from opencood.utils.heter_utils import AgentSelector
21
+ from opencood.utils.common_utils import merge_features_to_dict
22
+ from opencood.utils.transformation_utils import x1_to_x2, x_to_world, get_pairwise_transformation
23
+ from opencood.utils.pose_utils import add_noise_data_dict, add_noise_data_dict_asymmetric
24
+ from opencood.utils.pcd_utils import (
25
+ mask_points_by_range,
26
+ mask_ego_points,
27
+ mask_ego_points_v2,
28
+ shuffle_points,
29
+ downsample_lidar_minimum,
30
+ )
31
+ from opencood.utils.common_utils import read_json
32
+
33
+
34
+ def getEarlymulticlassFusionDataset(cls):
35
+ """
36
+ cls: the Basedataset.
37
+ """
38
+ class EarlymulticlassFusionDataset(cls):
39
+ def __init__(self, params, visualize, train=True):
40
+ super().__init__(params, visualize, train)
41
+ # supervise single
42
+ self.supervise_single = True if ('supervise_single' in params['model']['args'] and params['model']['args']['supervise_single']) \
43
+ else False
44
+ self.proj_first = False if 'proj_first' not in params['fusion']['args']\
45
+ else params['fusion']['args']['proj_first']
46
+
47
+ self.anchor_box = self.post_processor.generate_anchor_box()
48
+ self.anchor_box_torch = torch.from_numpy(self.anchor_box)
49
+
50
+ self.heterogeneous = False
51
+ if 'heter' in params:
52
+ self.heterogeneous = True
53
+ self.selector = AgentSelector(params['heter'], self.max_cav)
54
+ self.kd_flag = params.get('kd_flag', False)
55
+ self.box_align = False
56
+ if "box_align" in params:
57
+ self.box_align = True
58
+ self.stage1_result_path = params['box_align']['train_result'] if train else params['box_align']['val_result']
59
+ self.stage1_result = read_json(self.stage1_result_path)
60
+ self.box_align_args = params['box_align']['args']
61
+ self.multiclass = params['model']['args']['multi_class']
62
+ self.online_eval_only = False
63
+
64
+ def get_item_single_car(self, selected_cav_base, ego_cav_base, base_data_dict, tpe='all', cav_id='car_0', online_eval=False):
65
+ """
66
+ Process a single CAV's information for the train/test pipeline.
67
+
68
+
69
+ Parameters
70
+ ----------
71
+ selected_cav_base : dict
72
+ The dictionary contains a single CAV's raw information.
73
+ including 'params', 'camera_data'
74
+ ego_pose : list, length 6
75
+ The ego vehicle lidar pose under world coordinate.
76
+ ego_pose_clean : list, length 6
77
+ only used for gt box generation
78
+
79
+ Returns
80
+ -------
81
+ selected_cav_processed : dict
82
+ The dictionary contains the cav's processed information.
83
+ """
84
+ selected_cav_processed = {}
85
+ ego_pose, ego_pose_clean = ego_cav_base['params']['lidar_pose'], ego_cav_base['params']['lidar_pose_clean']
86
+ selected_pose, selected_pose_clean = selected_cav_base['params']['lidar_pose'], selected_cav_base['params']['lidar_pose_clean']
87
+
88
+ # calculate the transformation matrix
89
+ transformation_matrix = \
90
+ x1_to_x2(selected_cav_base['params']['lidar_pose'],
91
+ ego_pose) # T_ego_cav
92
+ transformation_matrix_clean = \
93
+ x1_to_x2(selected_cav_base['params']['lidar_pose_clean'],
94
+ ego_pose_clean)
95
+
96
+ # lidar
97
+ if tpe == 'all':
98
+ if self.load_lidar_file or self.visualize:
99
+ # process lidar
100
+ lidar_np = selected_cav_base['lidar_np']
101
+ lidar_np = shuffle_points(lidar_np)
102
+ # remove points that hit itself
103
+ if not cav_id.startswith('rsu'):
104
+ lidar_np = mask_ego_points_v2(lidar_np)
105
+ # project the lidar to ego space
106
+ # x,y,z in ego space
107
+
108
+ project_lidar_bank = []
109
+ lidar_bank = []
110
+ for agent_id in base_data_dict:
111
+ collab_cav_base = base_data_dict[agent_id]
112
+ collab_lidar_np = collab_cav_base['lidar_np']
113
+ collab_lidar_np = shuffle_points(collab_lidar_np)
114
+ # remove points that hit itself
115
+ if not agent_id.startswith('rsu'):
116
+ collab_lidar_np = mask_ego_points_v2(collab_lidar_np)
117
+ # project the lidar to ego space
118
+ # x,y,z in ego space
119
+
120
+ # calculate the transformation matrix
121
+ transformation_matrix_for_selected = \
122
+ x1_to_x2(collab_cav_base['params']['lidar_pose'],
123
+ selected_pose) # T_ego_cav
124
+
125
+ projected_collab_lidar = \
126
+ box_utils.project_points_by_matrix_torch(collab_lidar_np[:, :3],
127
+ transformation_matrix_for_selected)
128
+ project_lidar_bank.append(projected_collab_lidar)
129
+ lidar_bank.append(collab_lidar_np)
130
+
131
+ projected_lidar = np.concatenate(project_lidar_bank, axis=0)
132
+ lidar_np = np.concatenate(lidar_bank, axis=0)
133
+
134
+ # if self.proj_first:
135
+ lidar_np[:, :3] = projected_lidar
136
+ if self.visualize:
137
+ # filter lidar
138
+ if not selected_cav_base['ego']:
139
+ projected_lidar *= 0
140
+ selected_cav_processed.update({'projected_lidar': projected_lidar})
141
+
142
+ if self.kd_flag:
143
+ lidar_proj_np = copy.deepcopy(lidar_np)
144
+ lidar_proj_np[:,:3] = projected_lidar
145
+
146
+ selected_cav_processed.update({'projected_lidar': lidar_proj_np})
147
+
148
+ processed_lidar = self.pre_processor.preprocess(lidar_np)
149
+ selected_cav_processed.update({'processed_features': processed_lidar})
150
+
151
+ if not online_eval:
152
+ # generate targets label single GT, note the reference pose is itself.
153
+ object_bbx_center, object_bbx_mask, object_ids = self.generate_object_center(
154
+ [selected_cav_base], selected_cav_base['params']['lidar_pose']
155
+ )
156
+
157
+ label_dict = {}
158
+ if tpe == 'all':
159
+ # unused label
160
+ if False:
161
+ label_dict = self.post_processor.generate_label(
162
+ gt_box_center=object_bbx_center, anchors=self.anchor_box, mask=object_bbx_mask
163
+ )
164
+ selected_cav_processed.update({
165
+ "single_label_dict": label_dict,
166
+ "single_object_bbx_center": object_bbx_center,
167
+ "single_object_bbx_mask": object_bbx_mask})
168
+
169
+ if tpe == 'all':
170
+ # camera
171
+ if self.load_camera_file:
172
+ camera_data_list = selected_cav_base["camera_data"]
173
+
174
+ params = selected_cav_base["params"]
175
+ imgs = []
176
+ rots = []
177
+ trans = []
178
+ intrins = []
179
+ extrinsics = []
180
+ post_rots = []
181
+ post_trans = []
182
+
183
+ for idx, img in enumerate(camera_data_list):
184
+ camera_to_lidar, camera_intrinsic = self.get_ext_int(params, idx)
185
+
186
+ intrin = torch.from_numpy(camera_intrinsic)
187
+ rot = torch.from_numpy(
188
+ camera_to_lidar[:3, :3]
189
+ ) # R_wc, we consider world-coord is the lidar-coord
190
+ tran = torch.from_numpy(camera_to_lidar[:3, 3]) # T_wc
191
+
192
+ post_rot = torch.eye(2)
193
+ post_tran = torch.zeros(2)
194
+
195
+ img_src = [img]
196
+
197
+ # depth
198
+ if self.load_depth_file:
199
+ depth_img = selected_cav_base["depth_data"][idx]
200
+ img_src.append(depth_img)
201
+ else:
202
+ depth_img = None
203
+
204
+ # data augmentation
205
+ resize, resize_dims, crop, flip, rotate = sample_augmentation(
206
+ self.data_aug_conf, self.train
207
+ )
208
+ img_src, post_rot2, post_tran2 = img_transform(
209
+ img_src,
210
+ post_rot,
211
+ post_tran,
212
+ resize=resize,
213
+ resize_dims=resize_dims,
214
+ crop=crop,
215
+ flip=flip,
216
+ rotate=rotate,
217
+ )
218
+ # for convenience, make augmentation matrices 3x3
219
+ post_tran = torch.zeros(3)
220
+ post_rot = torch.eye(3)
221
+ post_tran[:2] = post_tran2
222
+ post_rot[:2, :2] = post_rot2
223
+
224
+ # decouple RGB and Depth
225
+
226
+ img_src[0] = normalize_img(img_src[0])
227
+ if self.load_depth_file:
228
+ img_src[1] = img_to_tensor(img_src[1]) * 255
229
+
230
+ imgs.append(torch.cat(img_src, dim=0))
231
+ intrins.append(intrin)
232
+ extrinsics.append(torch.from_numpy(camera_to_lidar))
233
+ rots.append(rot)
234
+ trans.append(tran)
235
+ post_rots.append(post_rot)
236
+ post_trans.append(post_tran)
237
+
238
+ selected_cav_processed.update(
239
+ {
240
+ "image_inputs":
241
+ {
242
+ "imgs": torch.stack(imgs), # [Ncam, 3or4, H, W]
243
+ "intrins": torch.stack(intrins),
244
+ "extrinsics": torch.stack(extrinsics),
245
+ "rots": torch.stack(rots),
246
+ "trans": torch.stack(trans),
247
+ "post_rots": torch.stack(post_rots),
248
+ "post_trans": torch.stack(post_trans),
249
+ }
250
+ }
251
+ )
252
+
253
+ # anchor box
254
+ selected_cav_processed.update({"anchor_box": self.anchor_box})
255
+
256
+
257
+ if not online_eval:
258
+ # note the reference pose ego
259
+ object_bbx_center, object_bbx_mask, object_ids = self.generate_object_center([selected_cav_base],
260
+ ego_pose_clean)
261
+ selected_cav_processed.update(
262
+ {
263
+ "object_bbx_center": object_bbx_center[object_bbx_mask == 1],
264
+ "object_bbx_mask": object_bbx_mask,
265
+ "object_ids": object_ids,
266
+ }
267
+ )
268
+
269
+ selected_cav_processed.update(
270
+ {
271
+ 'transformation_matrix': transformation_matrix,
272
+ 'transformation_matrix_clean': transformation_matrix_clean
273
+ }
274
+ )
275
+
276
+ return selected_cav_processed
277
+
278
+ def __getitem__(self, idx, extra_source=None, data_dir=None):
279
+
280
+ if data_dir is not None:
281
+ extra_source=1
282
+
283
+ object_bbx_center_list = []
284
+ object_bbx_mask_list = []
285
+ object_id_dict = {}
286
+
287
+ object_bbx_center_list_single = []
288
+ object_bbx_mask_list_single = []
289
+
290
+
291
+ output_dict = {}
292
+ for tpe in ['all', 0, 1, 3]:
293
+ output_single_class = self.__getitem_single_class__(idx, tpe, extra_source, data_dir)
294
+ output_dict[tpe] = output_single_class
295
+ if tpe == 'all' and extra_source==None:
296
+ continue
297
+ elif tpe == 'all' and extra_source!=None:
298
+ break
299
+ object_bbx_center_list.append(output_single_class['ego']['object_bbx_center'])
300
+ object_bbx_mask_list.append(output_single_class['ego']['object_bbx_mask'])
301
+ if self.supervise_single:
302
+ object_bbx_center_list_single.append(output_single_class['ego']['single_object_bbx_center_torch'])
303
+ object_bbx_mask_list_single.append(output_single_class['ego']['single_object_bbx_mask_torch'])
304
+
305
+ object_id_dict[tpe] = output_single_class['ego']['object_ids']
306
+
307
+ if self.multiclass and extra_source==None:
308
+ output_dict['all']['ego']['object_bbx_center'] = np.stack(object_bbx_center_list, axis=0)
309
+ output_dict['all']['ego']['object_bbx_mask'] = np.stack(object_bbx_mask_list, axis=0)
310
+ if self.supervise_single:
311
+ output_dict['all']['ego']['single_object_bbx_center_torch'] = torch.stack(object_bbx_center_list_single, axis=1)
312
+ output_dict['all']['ego']['single_object_bbx_mask_torch'] = torch.stack(object_bbx_mask_list_single, axis=1)
313
+
314
+ output_dict['all']['ego']['object_ids'] = object_id_dict
315
+ # print('finish get item')
316
+ return output_dict['all']
317
+
318
+ def __getitem_single_class__(self, idx, tpe=None, extra_source=None, data_dir=None):
319
+
320
+ if extra_source is None and data_dir is None:
321
+ base_data_dict = self.retrieve_base_data(idx, tpe) ## {id:{'ego':True/False, 'params': {'lidar_pose','speed','vehicles','ego_pos',...}, 'lidar_np': array (N,4)}}
322
+ elif data_dir is not None:
323
+ base_data_dict = self.retrieve_base_data(idx=None, tpe=tpe, data_dir=data_dir)
324
+ elif extra_source is not None:
325
+ base_data_dict = self.retrieve_base_data(idx=None, tpe=tpe, extra_source=extra_source)
326
+
327
+ # base_data_dict = add_noise_data_dict(base_data_dict,self.params['noise_setting'])
328
+ base_data_dict = add_noise_data_dict_asymmetric(base_data_dict,self.params['noise_setting'])
329
+ processed_data_dict = OrderedDict()
330
+ processed_data_dict['ego'] = {}
331
+
332
+ ego_id = -1
333
+ ego_lidar_pose = []
334
+ ego_cav_base = None
335
+
336
+ # first find the ego vehicle's lidar pose
337
+ for cav_id, cav_content in base_data_dict.items():
338
+ if cav_content['ego']:
339
+ ego_id = cav_id
340
+ ego_lidar_pose = cav_content['params']['lidar_pose']
341
+ ego_cav_base = cav_content
342
+ break
343
+
344
+ assert cav_id == list(base_data_dict.keys())[
345
+ 0], "The first element in the OrderedDict must be ego"
346
+ assert ego_id != -1
347
+ assert len(ego_lidar_pose) > 0
348
+
349
+ agents_image_inputs = []
350
+ processed_features = []
351
+ object_stack = []
352
+ object_id_stack = []
353
+ single_label_list = []
354
+ single_object_bbx_center_list = []
355
+ single_object_bbx_mask_list = []
356
+ too_far = []
357
+ lidar_pose_list = []
358
+ lidar_pose_clean_list = []
359
+ cav_id_list = []
360
+ projected_lidar_clean_list = [] # disconet
361
+
362
+ if self.visualize or self.kd_flag:
363
+ projected_lidar_stack = []
364
+
365
+ # loop over all CAVs to process information
366
+ for cav_id, selected_cav_base in base_data_dict.items():
367
+ # check if the cav is within the communication range with ego
368
+ distance = \
369
+ math.sqrt((selected_cav_base['params']['lidar_pose'][0] -
370
+ ego_lidar_pose[0]) ** 2 + (
371
+ selected_cav_base['params'][
372
+ 'lidar_pose'][1] - ego_lidar_pose[
373
+ 1]) ** 2)
374
+
375
+ # if distance is too far, we will just skip this agent
376
+ if distance > self.params['comm_range']:
377
+ too_far.append(cav_id)
378
+ continue
379
+
380
+ lidar_pose_clean_list.append(selected_cav_base['params']['lidar_pose_clean'])
381
+ lidar_pose_list.append(selected_cav_base['params']['lidar_pose']) # 6dof pose
382
+ cav_id_list.append(cav_id)
383
+
384
+ for cav_id in too_far:
385
+ base_data_dict.pop(cav_id)
386
+
387
+ if self.box_align and str(idx) in self.stage1_result.keys(): # False
388
+ from opencood.models.sub_modules.box_align_v2 import box_alignment_relative_sample_np
389
+ stage1_content = self.stage1_result[str(idx)]
390
+ if stage1_content is not None:
391
+ all_agent_id_list = stage1_content['cav_id_list'] # include those out of range
392
+ all_agent_corners_list = stage1_content['pred_corner3d_np_list']
393
+ all_agent_uncertainty_list = stage1_content['uncertainty_np_list']
394
+
395
+ cur_agent_id_list = cav_id_list
396
+ cur_agent_pose = [base_data_dict[cav_id]['params']['lidar_pose'] for cav_id in cav_id_list]
397
+ cur_agnet_pose = np.array(cur_agent_pose)
398
+ cur_agent_in_all_agent = [all_agent_id_list.index(cur_agent) for cur_agent in cur_agent_id_list] # indexing current agent in `all_agent_id_list`
399
+
400
+ pred_corners_list = [np.array(all_agent_corners_list[cur_in_all_ind], dtype=np.float64)
401
+ for cur_in_all_ind in cur_agent_in_all_agent]
402
+ uncertainty_list = [np.array(all_agent_uncertainty_list[cur_in_all_ind], dtype=np.float64)
403
+ for cur_in_all_ind in cur_agent_in_all_agent]
404
+
405
+ if sum([len(pred_corners) for pred_corners in pred_corners_list]) != 0:
406
+ refined_pose = box_alignment_relative_sample_np(pred_corners_list,
407
+ cur_agnet_pose,
408
+ uncertainty_list=uncertainty_list,
409
+ **self.box_align_args)
410
+ cur_agnet_pose[:,[0,1,4]] = refined_pose
411
+
412
+ for i, cav_id in enumerate(cav_id_list):
413
+ lidar_pose_list[i] = cur_agnet_pose[i].tolist()
414
+ base_data_dict[cav_id]['params']['lidar_pose'] = cur_agnet_pose[i].tolist()
415
+
416
+ pairwise_t_matrix = \
417
+ get_pairwise_transformation(base_data_dict,
418
+ self.max_cav,
419
+ self.proj_first)
420
+
421
+ lidar_poses = np.array(lidar_pose_list).reshape(-1, 6) # [N_cav, 6]
422
+ lidar_poses_clean = np.array(lidar_pose_clean_list).reshape(-1, 6) # [N_cav, 6]
423
+
424
+ # merge preprocessed features from different cavs into the same dict
425
+ cav_num = len(cav_id_list)
426
+
427
+ # heterogeneous
428
+ if self.heterogeneous:
429
+ lidar_agent, camera_agent = self.selector.select_agent(idx)
430
+ lidar_agent = lidar_agent[:cav_num]
431
+ processed_data_dict['ego'].update({"lidar_agent": lidar_agent})
432
+
433
+ for _i, cav_id in enumerate(cav_id_list):
434
+ selected_cav_base = base_data_dict[cav_id]
435
+
436
+ # dynamic object center generator! for heterogeneous input
437
+ if (not self.visualize) and self.heterogeneous and lidar_agent[_i]:
438
+ self.generate_object_center = self.generate_object_center_lidar
439
+ elif (not self.visualize) and self.heterogeneous and (not lidar_agent[_i]):
440
+ self.generate_object_center = self.generate_object_center_camera
441
+
442
+ selected_cav_processed = self.get_item_single_car(
443
+ selected_cav_base,
444
+ ego_cav_base,
445
+ base_data_dict,
446
+ tpe,
447
+ cav_id,
448
+ extra_source!=None)
449
+
450
+ if extra_source==None:
451
+ object_stack.append(selected_cav_processed['object_bbx_center'])
452
+ object_id_stack += selected_cav_processed['object_ids']
453
+ if tpe == 'all':
454
+ if self.load_lidar_file:
455
+ processed_features.append(
456
+ selected_cav_processed['processed_features'])
457
+ if self.load_camera_file:
458
+ agents_image_inputs.append(
459
+ selected_cav_processed['image_inputs'])
460
+
461
+ if self.visualize or self.kd_flag:
462
+ projected_lidar_stack.append(
463
+ selected_cav_processed['projected_lidar'])
464
+
465
+ if self.supervise_single and extra_source==None:
466
+ single_label_list.append(selected_cav_processed['single_label_dict'])
467
+ single_object_bbx_center_list.append(selected_cav_processed['single_object_bbx_center'])
468
+ single_object_bbx_mask_list.append(selected_cav_processed['single_object_bbx_mask'])
469
+
470
+ # generate single view GT label
471
+ if self.supervise_single and extra_source==None:
472
+ single_label_dicts = {}
473
+ if tpe == 'all':
474
+ # unused label
475
+ if False:
476
+ single_label_dicts = self.post_processor.collate_batch(single_label_list)
477
+ single_object_bbx_center = torch.from_numpy(np.array(single_object_bbx_center_list))
478
+ single_object_bbx_mask = torch.from_numpy(np.array(single_object_bbx_mask_list))
479
+ processed_data_dict['ego'].update({
480
+ "single_label_dict_torch": single_label_dicts,
481
+ "single_object_bbx_center_torch": single_object_bbx_center,
482
+ "single_object_bbx_mask_torch": single_object_bbx_mask,
483
+ })
484
+
485
+ if self.kd_flag:
486
+ stack_lidar_np = np.vstack(projected_lidar_stack)
487
+ stack_lidar_np = mask_points_by_range(stack_lidar_np,
488
+ self.params['preprocess'][
489
+ 'cav_lidar_range'])
490
+ stack_feature_processed = self.pre_processor.preprocess(stack_lidar_np)
491
+ processed_data_dict['ego'].update({'teacher_processed_lidar':
492
+ stack_feature_processed})
493
+
494
+ if extra_source is None:
495
+ # exclude all repetitive objects
496
+ unique_indices = \
497
+ [object_id_stack.index(x) for x in set(object_id_stack)]
498
+ object_stack = np.vstack(object_stack)
499
+ object_stack = object_stack[unique_indices]
500
+
501
+ # make sure bounding boxes across all frames have the same number
502
+ object_bbx_center = \
503
+ np.zeros((self.params['postprocess']['max_num'], 7))
504
+ mask = np.zeros(self.params['postprocess']['max_num'])
505
+ object_bbx_center[:object_stack.shape[0], :] = object_stack
506
+ mask[:object_stack.shape[0]] = 1
507
+
508
+ processed_data_dict['ego'].update(
509
+ {'object_bbx_center': object_bbx_center, # (100,7)
510
+ 'object_bbx_mask': mask, # (100,)
511
+ 'object_ids': [object_id_stack[i] for i in unique_indices],
512
+ }
513
+ )
514
+
515
+
516
+ # generate targets label
517
+ label_dict = {}
518
+ if tpe == 'all':
519
+ # unused label
520
+ if False:
521
+ label_dict = \
522
+ self.post_processor.generate_label(
523
+ gt_box_center=object_bbx_center,
524
+ anchors=self.anchor_box,
525
+ mask=mask)
526
+
527
+ processed_data_dict['ego'].update(
528
+ {
529
+ 'anchor_box': self.anchor_box,
530
+ 'label_dict': label_dict,
531
+ 'cav_num': cav_num,
532
+ 'pairwise_t_matrix': pairwise_t_matrix,
533
+ 'lidar_poses_clean': lidar_poses_clean,
534
+ 'lidar_poses': lidar_poses})
535
+
536
+ if tpe == 'all':
537
+ if self.load_lidar_file:
538
+ merged_feature_dict = merge_features_to_dict(processed_features)
539
+ processed_data_dict['ego'].update({'processed_lidar': merged_feature_dict})
540
+ if self.load_camera_file:
541
+ merged_image_inputs_dict = merge_features_to_dict(agents_image_inputs, merge='stack')
542
+ processed_data_dict['ego'].update({'image_inputs': merged_image_inputs_dict})
543
+
544
+ if self.visualize:
545
+ processed_data_dict['ego'].update({'origin_lidar':
546
+ # projected_lidar_stack})
547
+ np.vstack(
548
+ projected_lidar_stack)})
549
+ processed_data_dict['ego'].update({'lidar_len': [len(projected_lidar_stack[i]) for i in range(len(projected_lidar_stack))]})
550
+
551
+
552
+ processed_data_dict['ego'].update({'sample_idx': idx,
553
+ 'cav_id_list': cav_id_list})
554
+
555
+ img_front_list = []
556
+ img_left_list = []
557
+ img_right_list = []
558
+ BEV_list = []
559
+
560
+ if self.visualize:
561
+ for car_id in base_data_dict:
562
+ if not base_data_dict[car_id]['ego']:
563
+ continue
564
+ if 'rgb_front' in base_data_dict[car_id] and 'rgb_left' in base_data_dict[car_id] and 'rgb_right' in base_data_dict[car_id] and 'BEV' in base_data_dict[car_id] :
565
+ img_front_list.append(base_data_dict[car_id]['rgb_front'])
566
+ img_left_list.append(base_data_dict[car_id]['rgb_left'])
567
+ img_right_list.append(base_data_dict[car_id]['rgb_right'])
568
+ BEV_list.append(base_data_dict[car_id]['BEV'])
569
+ processed_data_dict['ego'].update({'img_front': img_front_list,
570
+ 'img_left': img_left_list,
571
+ 'img_right': img_right_list,
572
+ 'BEV': BEV_list})
573
+ processed_data_dict['ego'].update({'scene_dict': base_data_dict['car_0']['scene_dict'],
574
+ 'frame_id': base_data_dict['car_0']['frame_id'],
575
+ })
576
+
577
+
578
+ return processed_data_dict
579
+
580
+
581
+ def collate_batch_train(self, batch, online_eval_only=False):
582
+ # Intermediate fusion is different from the other two
583
+ output_dict = {'ego': {}}
584
+
585
+ object_bbx_center = []
586
+ object_bbx_mask = []
587
+ object_ids = []
588
+ processed_lidar_list = []
589
+ image_inputs_list = []
590
+ # used to record different scenario
591
+ record_len = []
592
+ label_dict_list = []
593
+ lidar_pose_list = []
594
+ origin_lidar = []
595
+ lidar_len = []
596
+ lidar_pose_clean_list = []
597
+
598
+ # heterogeneous
599
+ lidar_agent_list = []
600
+
601
+ # pairwise transformation matrix
602
+ pairwise_t_matrix_list = []
603
+
604
+ # disconet
605
+ teacher_processed_lidar_list = []
606
+
607
+ # image
608
+ img_front = []
609
+ img_left = []
610
+ img_right = []
611
+ BEV = []
612
+
613
+ dict_list = []
614
+
615
+ ### 2022.10.10 single gt ####
616
+ if self.supervise_single:
617
+ pos_equal_one_single = []
618
+ neg_equal_one_single = []
619
+ targets_single = []
620
+ object_bbx_center_single = []
621
+ object_bbx_mask_single = []
622
+
623
+ for i in range(len(batch)):
624
+ ego_dict = batch[i]['ego']
625
+ if not online_eval_only:
626
+ object_bbx_center.append(ego_dict['object_bbx_center'])
627
+ object_bbx_mask.append(ego_dict['object_bbx_mask'])
628
+ object_ids.append(ego_dict['object_ids'])
629
+ else:
630
+ object_ids.append(None)
631
+ lidar_pose_list.append(ego_dict['lidar_poses']) # ego_dict['lidar_pose'] is np.ndarray [N,6]
632
+ lidar_pose_clean_list.append(ego_dict['lidar_poses_clean'])
633
+ if self.load_lidar_file:
634
+ processed_lidar_list.append(ego_dict['processed_lidar'])
635
+ if self.load_camera_file:
636
+ image_inputs_list.append(ego_dict['image_inputs']) # different cav_num, ego_dict['image_inputs'] is dict.
637
+
638
+ record_len.append(ego_dict['cav_num'])
639
+ label_dict_list.append(ego_dict['label_dict'])
640
+ pairwise_t_matrix_list.append(ego_dict['pairwise_t_matrix'])
641
+
642
+ dict_list.append([ego_dict['scene_dict'], ego_dict['frame_id']])
643
+
644
+ if self.visualize:
645
+ origin_lidar.append(ego_dict['origin_lidar'])
646
+ lidar_len.append(ego_dict['lidar_len'])
647
+ if len(ego_dict['img_front']) > 0 and len(ego_dict['img_right']) > 0 and len(ego_dict['img_left']) > 0 and len(ego_dict['BEV']) > 0:
648
+ img_front.append(ego_dict['img_front'][0])
649
+ img_left.append(ego_dict['img_left'][0])
650
+ img_right.append(ego_dict['img_right'][0])
651
+ BEV.append(ego_dict['BEV'][0])
652
+
653
+
654
+ if self.kd_flag:
655
+ teacher_processed_lidar_list.append(ego_dict['teacher_processed_lidar'])
656
+
657
+ ### 2022.10.10 single gt ####
658
+ if self.supervise_single and not online_eval_only:
659
+ # unused label
660
+ if False:
661
+ pos_equal_one_single.append(ego_dict['single_label_dict_torch']['pos_equal_one'])
662
+ neg_equal_one_single.append(ego_dict['single_label_dict_torch']['neg_equal_one'])
663
+ targets_single.append(ego_dict['single_label_dict_torch']['targets'])
664
+ object_bbx_center_single.append(ego_dict['single_object_bbx_center_torch'])
665
+ object_bbx_mask_single.append(ego_dict['single_object_bbx_mask_torch'])
666
+
667
+ # heterogeneous
668
+ if self.heterogeneous:
669
+ lidar_agent_list.append(ego_dict['lidar_agent'])
670
+
671
+ # convert to numpy, (B, max_num, 7)
672
+ if not online_eval_only:
673
+ object_bbx_center = torch.from_numpy(np.array(object_bbx_center))
674
+ object_bbx_mask = torch.from_numpy(np.array(object_bbx_mask))
675
+ else:
676
+ object_bbx_center = None
677
+ object_bbx_mask = None
678
+
679
+ if self.load_lidar_file:
680
+ merged_feature_dict = merge_features_to_dict(processed_lidar_list)
681
+
682
+ if self.heterogeneous:
683
+ lidar_agent = np.concatenate(lidar_agent_list)
684
+ lidar_agent_idx = lidar_agent.nonzero()[0].tolist()
685
+ for k, v in merged_feature_dict.items(): # 'voxel_features' 'voxel_num_points' 'voxel_coords'
686
+ merged_feature_dict[k] = [v[index] for index in lidar_agent_idx]
687
+
688
+ if not self.heterogeneous or (self.heterogeneous and sum(lidar_agent) != 0):
689
+ processed_lidar_torch_dict = \
690
+ self.pre_processor.collate_batch(merged_feature_dict)
691
+ output_dict['ego'].update({'processed_lidar': processed_lidar_torch_dict})
692
+
693
+ if self.load_camera_file:
694
+ merged_image_inputs_dict = merge_features_to_dict(image_inputs_list, merge='cat')
695
+
696
+ if self.heterogeneous:
697
+ lidar_agent = np.concatenate(lidar_agent_list)
698
+ camera_agent = 1 - lidar_agent
699
+ camera_agent_idx = camera_agent.nonzero()[0].tolist()
700
+ if sum(camera_agent) != 0:
701
+ for k, v in merged_image_inputs_dict.items(): # 'imgs' 'rots' 'trans' ...
702
+ merged_image_inputs_dict[k] = torch.stack([v[index] for index in camera_agent_idx])
703
+
704
+ if not self.heterogeneous or (self.heterogeneous and sum(camera_agent) != 0):
705
+ output_dict['ego'].update({'image_inputs': merged_image_inputs_dict})
706
+
707
+ record_len = torch.from_numpy(np.array(record_len, dtype=int))
708
+ lidar_pose = torch.from_numpy(np.concatenate(lidar_pose_list, axis=0))
709
+ lidar_pose_clean = torch.from_numpy(np.concatenate(lidar_pose_clean_list, axis=0))
710
+
711
+ # unused label
712
+ label_torch_dict = {}
713
+ if False:
714
+ label_torch_dict = \
715
+ self.post_processor.collate_batch(label_dict_list)
716
+ # for centerpoint
717
+ label_torch_dict.update({'object_bbx_center': object_bbx_center,
718
+ 'object_bbx_mask': object_bbx_mask})
719
+
720
+ # (B, max_cav)
721
+ pairwise_t_matrix = torch.from_numpy(np.array(pairwise_t_matrix_list))
722
+
723
+ # add pairwise_t_matrix to label dict
724
+ label_torch_dict['pairwise_t_matrix'] = pairwise_t_matrix
725
+ label_torch_dict['record_len'] = record_len
726
+
727
+
728
+ # object id is only used during inference, where batch size is 1.
729
+ # so here we only get the first element.
730
+ output_dict['ego'].update({'object_bbx_center': object_bbx_center,
731
+ 'object_bbx_mask': object_bbx_mask,
732
+ 'record_len': record_len,
733
+ 'label_dict': label_torch_dict,
734
+ 'object_ids': object_ids[0],
735
+ 'pairwise_t_matrix': pairwise_t_matrix,
736
+ 'lidar_pose_clean': lidar_pose_clean,
737
+ 'lidar_pose': lidar_pose,
738
+ 'anchor_box': self.anchor_box_torch})
739
+
740
+
741
+ output_dict['ego'].update({'dict_list': dict_list})
742
+
743
+ if self.visualize:
744
+ origin_lidar = torch.from_numpy(np.array(origin_lidar))
745
+ output_dict['ego'].update({'origin_lidar': origin_lidar})
746
+ lidar_len = np.array(lidar_len)
747
+ output_dict['ego'].update({'lidar_len': lidar_len})
748
+ output_dict['ego'].update({'img_front': img_front})
749
+ output_dict['ego'].update({'img_right': img_right})
750
+ output_dict['ego'].update({'img_left': img_left})
751
+ output_dict['ego'].update({'BEV': BEV})
752
+
753
+ if self.kd_flag:
754
+ teacher_processed_lidar_torch_dict = \
755
+ self.pre_processor.collate_batch(teacher_processed_lidar_list)
756
+ output_dict['ego'].update({'teacher_processed_lidar':teacher_processed_lidar_torch_dict})
757
+
758
+
759
+ if self.supervise_single and not online_eval_only:
760
+ output_dict['ego'].update({
761
+ "label_dict_single":{
762
+ # "pos_equal_one": torch.cat(pos_equal_one_single, dim=0),
763
+ # "neg_equal_one": torch.cat(neg_equal_one_single, dim=0),
764
+ # "targets": torch.cat(targets_single, dim=0),
765
+ # for centerpoint
766
+ "object_bbx_center_single": torch.cat(object_bbx_center_single, dim=0),
767
+ "object_bbx_mask_single": torch.cat(object_bbx_mask_single, dim=0)
768
+ },
769
+ "object_bbx_center_single": torch.cat(object_bbx_center_single, dim=0),
770
+ "object_bbx_mask_single": torch.cat(object_bbx_mask_single, dim=0)
771
+ })
772
+
773
+ if self.heterogeneous:
774
+ output_dict['ego'].update({
775
+ "lidar_agent_record": torch.from_numpy(np.concatenate(lidar_agent_list)) # [0,1,1,0,1...]
776
+ })
777
+
778
+ return output_dict
779
+
780
+ def collate_batch_test(self, batch, online_eval_only=False):
781
+
782
+ self.online_eval_only = online_eval_only
783
+
784
+ assert len(batch) <= 1, "Batch size 1 is required during testing!"
785
+ output_dict = self.collate_batch_train(batch, online_eval_only)
786
+ if output_dict is None:
787
+ return None
788
+
789
+ # check if anchor box in the batch
790
+ if batch[0]['ego']['anchor_box'] is not None:
791
+ output_dict['ego'].update({'anchor_box':
792
+ self.anchor_box_torch})
793
+
794
+ # save the transformation matrix (4, 4) to ego vehicle
795
+ # transformation is only used in post process (no use.)
796
+ # we all predict boxes in ego coord.
797
+ transformation_matrix_torch = \
798
+ torch.from_numpy(np.identity(4)).float()
799
+ transformation_matrix_clean_torch = \
800
+ torch.from_numpy(np.identity(4)).float()
801
+
802
+ output_dict['ego'].update({'transformation_matrix':
803
+ transformation_matrix_torch,
804
+ 'transformation_matrix_clean':
805
+ transformation_matrix_clean_torch,})
806
+
807
+ output_dict['ego'].update({
808
+ "sample_idx": batch[0]['ego']['sample_idx'],
809
+ "cav_id_list": batch[0]['ego']['cav_id_list']
810
+ })
811
+
812
+ return output_dict
813
+
814
+
815
+ def post_process(self, data_dict, output_dict):
816
+ """
817
+ Process the outputs of the model to 2D/3D bounding box.
818
+
819
+ Parameters
820
+ ----------
821
+ data_dict : dict
822
+ The dictionary containing the origin input data of model.
823
+
824
+ output_dict :dict
825
+ The dictionary containing the output of the model.
826
+
827
+ Returns
828
+ -------
829
+ pred_box_tensor : torch.Tensor
830
+ The tensor of prediction bounding box after NMS.
831
+ gt_box_tensor : torch.Tensor
832
+ The tensor of gt bounding box.
833
+ """
834
+ pred_box_tensor, pred_score = \
835
+ self.post_processor.post_process(data_dict, output_dict)
836
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict)
837
+
838
+ return pred_box_tensor, pred_score, gt_box_tensor
839
+
840
+ def post_process_multiclass(self, data_dict, output_dict, online_eval_only=False):
841
+ """
842
+ Process the outputs of the model to 2D/3D bounding box.
843
+
844
+ Parameters
845
+ ----------
846
+ data_dict : dict
847
+ The dictionary containing the origin input data of model.
848
+
849
+ output_dict :dict
850
+ The dictionary containing the output of the model.
851
+
852
+ Returns
853
+ -------
854
+ pred_box_tensor : torch.Tensor
855
+ The tensor of prediction bounding box after NMS.
856
+ gt_box_tensor : torch.Tensor
857
+ The tensor of gt bounding box.
858
+ """
859
+
860
+ if online_eval_only == False:
861
+ online_eval_only = self.online_eval_only
862
+
863
+ num_class = output_dict['ego']['cls_preds'].shape[1]
864
+
865
+
866
+ pred_box_tensor_list = []
867
+ pred_score_list = []
868
+ gt_box_tensor_list = []
869
+
870
+ num_list = [0,1,3]
871
+
872
+ for i in range(num_class):
873
+ data_dict_single = copy.deepcopy(data_dict)
874
+ output_dict_single = copy.deepcopy(output_dict)
875
+ if not online_eval_only:
876
+ data_dict_single['ego']['object_bbx_center'] = data_dict['ego']['object_bbx_center'][:,i,:,:]
877
+ data_dict_single['ego']['object_bbx_mask'] = data_dict['ego']['object_bbx_mask'][:,i,:]
878
+ data_dict_single['ego']['object_ids'] = data_dict['ego']['object_ids'][num_list[i]]
879
+
880
+ output_dict_single['ego']['cls_preds'] = output_dict['ego']['cls_preds'][:,i:i+1,:,:]
881
+ output_dict_single['ego']['reg_preds'] = output_dict['ego']['reg_preds_multiclass'][:,i,:,:]
882
+
883
+ pred_box_tensor, pred_score = \
884
+ self.post_processor.post_process(data_dict_single, output_dict_single)
885
+
886
+ if not online_eval_only:
887
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict_single)
888
+ else:
889
+ gt_box_tensor = None
890
+
891
+ pred_box_tensor_list.append(pred_box_tensor)
892
+ pred_score_list.append(pred_score)
893
+ gt_box_tensor_list.append(gt_box_tensor)
894
+
895
+ return pred_box_tensor_list, pred_score_list, gt_box_tensor_list
896
+
897
+ return EarlymulticlassFusionDataset
898
+
899
+
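
Note: post_process_multiclass above evaluates each class channel independently. It deep-copies the batch, slices cls_preds to a single channel and reg_preds_multiclass to the matching regression map, and maps class index i to its ground-truth group via num_list = [0, 1, 3]. The sketch below isolates just that slicing; the tensor shapes are assumptions chosen for illustration and are not taken from this checkpoint's config.

import torch

num_class = 3
cls_preds = torch.rand(1, num_class, 100, 252)            # (B, C, H, W) classification maps
reg_preds_multiclass = torch.rand(1, num_class, 100, 7)   # (B, C, N, 7) per-class regressions

per_class_outputs = []
for i in range(num_class):
    per_class_outputs.append({
        # keep the channel dim so a single-class post-processor still sees (B, 1, H, W)
        'cls_preds': cls_preds[:, i:i + 1, :, :],
        'reg_preds': reg_preds_multiclass[:, i, :, :],
    })

assert per_class_outputs[0]['cls_preds'].shape == (1, 1, 100, 252)
assert per_class_outputs[2]['reg_preds'].shape == (1, 100, 7)
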
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/intermediate_2stage_fusion_dataset.py ADDED
@@ -0,0 +1,603 @@
1
+ # intermediate 2-stage fusion dataset
2
+ import random
3
+ import math
4
+ from collections import OrderedDict
5
+ import numpy as np
6
+ import torch
7
+ import copy
8
+ from icecream import ic
9
+ from PIL import Image
10
+ import pickle as pkl
11
+ from opencood.utils import box_utils as box_utils
12
+ from opencood.data_utils.pre_processor import build_preprocessor
13
+ from opencood.data_utils.post_processor import build_postprocessor
14
+ from opencood.utils.camera_utils import (
15
+ sample_augmentation,
16
+ img_transform,
17
+ normalize_img,
18
+ img_to_tensor,
19
+ )
20
+ from opencood.utils.common_utils import merge_features_to_dict
21
+ from opencood.utils.transformation_utils import x1_to_x2, x_to_world, get_pairwise_transformation
22
+ from opencood.utils.pose_utils import add_noise_data_dict
23
+ from opencood.utils.pcd_utils import (
24
+ mask_points_by_range,
25
+ mask_ego_points,
26
+ shuffle_points,
27
+ downsample_lidar_minimum,
28
+ )
29
+
30
+ def getIntermediate2stageFusionDataset(cls):
31
+ """
32
+ cls: the Basedataset.
33
+ """
34
+ class Intermediate2stageFusionDataset(cls):
35
+ def __init__(self, params, visualize, train=True):
36
+ super().__init__(params, visualize, train)
37
+ # intermediate and supervise single
38
+ self.supervise_single = True if ('supervise_single' in params['model']['args'] and params['model']['args']['supervise_single']) \
39
+ else False
40
+ # it is asserted to be False, but by default single labels are still loaded for 1-stage training.
41
+ assert self.supervise_single is False
42
+
43
+ self.proj_first = False if 'proj_first' not in params['fusion']['args']\
44
+ else params['fusion']['args']['proj_first']
45
+
46
+ self.anchor_box = self.post_processor.generate_anchor_box()
47
+ self.anchor_box_torch = torch.from_numpy(self.anchor_box)
48
+
49
+ self.heterogeneous = False
50
+ if 'heter' in params:
51
+ self.heterogeneous = True
52
+
53
+ def get_item_single_car(self, selected_cav_base, ego_cav_base):
54
+ """
55
+ Process a single CAV's information for the train/test pipeline.
56
+
57
+
58
+ Parameters
59
+ ----------
60
+ selected_cav_base : dict
61
+ The dictionary contains a single CAV's raw information.
62
+ including 'params', 'camera_data'
63
+ ego_cav_base : dict
64
+ The ego CAV's base information; its 'lidar_pose' and
65
+ 'lidar_pose_clean' entries are used as the reference poses,
66
+ the clean pose only for gt box generation.
67
+
68
+ Returns
69
+ -------
70
+ selected_cav_processed : dict
71
+ The dictionary contains the cav's processed information.
72
+ """
73
+ selected_cav_processed = {}
74
+ ego_pose, ego_pose_clean = ego_cav_base['params']['lidar_pose'], ego_cav_base['params']['lidar_pose_clean']
75
+
76
+ # calculate the transformation matrix
77
+ transformation_matrix = \
78
+ x1_to_x2(selected_cav_base['params']['lidar_pose'],
79
+ ego_pose) # T_ego_cav
80
+ transformation_matrix_clean = \
81
+ x1_to_x2(selected_cav_base['params']['lidar_pose_clean'],
82
+ ego_pose_clean)
83
+
84
+ # lidar
85
+ if self.load_lidar_file or self.visualize:
86
+ # process lidar
87
+ lidar_np = selected_cav_base['lidar_np']
88
+ lidar_np = shuffle_points(lidar_np)
89
+ # remove points that hit itself
90
+ lidar_np = mask_ego_points(lidar_np)
91
+
92
+ # no projected lidar
93
+ no_project_lidar = copy.deepcopy(lidar_np)
94
+
95
+ # project the lidar to ego space
96
+ # x,y,z in ego space
97
+ projected_lidar = \
98
+ box_utils.project_points_by_matrix_torch(lidar_np[:, :3],
99
+ transformation_matrix)
100
+ if self.proj_first:
101
+ lidar_np[:, :3] = projected_lidar
102
+
103
+ if self.visualize:
104
+ # filter lidar
105
+ selected_cav_processed.update({'projected_lidar': projected_lidar})
106
+
107
+ processed_lidar = self.pre_processor.preprocess(lidar_np)
108
+ selected_cav_processed.update({'projected_lidar': projected_lidar,
109
+ 'no_projected_lidar': no_project_lidar,
110
+ 'processed_features': processed_lidar})
111
+
112
+ # generate targets label single GT, note the reference pose is itself.
113
+ object_bbx_center, object_bbx_mask, object_ids = self.generate_object_center(
114
+ [selected_cav_base], selected_cav_base['params']['lidar_pose']
115
+ )
116
+ label_dict = self.post_processor.generate_label(
117
+ gt_box_center=object_bbx_center, anchors=self.anchor_box, mask=object_bbx_mask
118
+ )
119
+ selected_cav_processed.update({"object_bbx_center_no_coop": object_bbx_center[object_bbx_mask==1],
120
+ "single_label_dict": label_dict})
121
+
122
+ # camera
123
+ if self.load_camera_file:
124
+ camera_data_list = selected_cav_base["camera_data"]
125
+
126
+ params = selected_cav_base["params"]
127
+ imgs = []
128
+ rots = []
129
+ trans = []
130
+ intrins = []
131
+ post_rots = []
132
+ post_trans = []
133
+
134
+ for idx, img in enumerate(camera_data_list):
135
+ camera_to_lidar, camera_intrinsic = self.get_ext_int(params, idx)
136
+
137
+ intrin = torch.from_numpy(camera_intrinsic)
138
+ rot = torch.from_numpy(
139
+ camera_to_lidar[:3, :3]
140
+ ) # R_wc, we consider world-coord is the lidar-coord
141
+ tran = torch.from_numpy(camera_to_lidar[:3, 3]) # T_wc
142
+
143
+ post_rot = torch.eye(2)
144
+ post_tran = torch.zeros(2)
145
+
146
+ img_src = [img]
147
+
148
+ # depth
149
+ if self.load_depth_file:
150
+ depth_img = selected_cav_base["depth_data"][idx]
151
+ img_src.append(depth_img)
152
+ else:
153
+ depth_img = None
154
+
155
+ # data augmentation
156
+ resize, resize_dims, crop, flip, rotate = sample_augmentation(
157
+ self.data_aug_conf, self.train
158
+ )
159
+ img_src, post_rot2, post_tran2 = img_transform(
160
+ img_src,
161
+ post_rot,
162
+ post_tran,
163
+ resize=resize,
164
+ resize_dims=resize_dims,
165
+ crop=crop,
166
+ flip=flip,
167
+ rotate=rotate,
168
+ )
169
+ # for convenience, make augmentation matrices 3x3
170
+ post_tran = torch.zeros(3)
171
+ post_rot = torch.eye(3)
172
+ post_tran[:2] = post_tran2
173
+ post_rot[:2, :2] = post_rot2
174
+
175
+ # decouple RGB and Depth
176
+
177
+ img_src[0] = normalize_img(img_src[0])
178
+ if self.load_depth_file:
179
+ img_src[1] = img_to_tensor(img_src[1]) * 255
180
+
181
+ imgs.append(torch.cat(img_src, dim=0))
182
+ intrins.append(intrin)
183
+ rots.append(rot)
184
+ trans.append(tran)
185
+ post_rots.append(post_rot)
186
+ post_trans.append(post_tran)
187
+
188
+ selected_cav_processed.update(
189
+ {
190
+ "image_inputs":
191
+ {
192
+ "imgs": torch.stack(imgs), # [Ncam, 3or4, H, W]
193
+ "intrins": torch.stack(intrins),
194
+ "rots": torch.stack(rots),
195
+ "trans": torch.stack(trans),
196
+ "post_rots": torch.stack(post_rots),
197
+ "post_trans": torch.stack(post_trans),
198
+ }
199
+ }
200
+ )
201
+
202
+ # anchor box
203
+ selected_cav_processed.update({"anchor_box": self.anchor_box})
204
+
205
+ # note the reference pose ego
206
+ object_bbx_center, object_bbx_mask, object_ids = self.generate_object_center([selected_cav_base],
207
+ ego_pose_clean)
208
+
209
+ selected_cav_processed.update(
210
+ {
211
+ "object_bbx_center": object_bbx_center[object_bbx_mask == 1],
212
+ "object_bbx_mask": object_bbx_mask,
213
+ "object_ids": object_ids,
214
+ 'transformation_matrix': transformation_matrix,
215
+ 'transformation_matrix_clean': transformation_matrix_clean
216
+ }
217
+ )
218
+
219
+
220
+ return selected_cav_processed
221
+
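The projection logic above reduces to composing two 6-DoF poses into T_ego_cav and applying it to homogeneous points. A minimal numpy sketch of that idea, assuming an [x, y, z, roll, yaw, pitch] pose layout and keeping only yaw for brevity (the project's x1_to_x2 handles all three angles):

import numpy as np

def pose_to_world(pose):
    # simplified: translation plus yaw-only rotation (degrees)
    x, y, z, yaw = pose[0], pose[1], pose[2], np.radians(pose[4])
    T = np.eye(4)
    T[:3, :3] = [[np.cos(yaw), -np.sin(yaw), 0.0],
                 [np.sin(yaw),  np.cos(yaw), 0.0],
                 [0.0,          0.0,         1.0]]
    T[:3, 3] = [x, y, z]
    return T

def cav_to_ego(cav_pose, ego_pose):
    # T_ego_cav = inv(T_world_ego) @ T_world_cav
    return np.linalg.inv(pose_to_world(ego_pose)) @ pose_to_world(cav_pose)

def project_points(points_xyz, T):
    # apply a 4x4 transform to (N, 3) points via homogeneous coordinates
    homo = np.hstack([points_xyz, np.ones((len(points_xyz), 1))])
    return (homo @ T.T)[:, :3]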
222
+ def __getitem__(self, idx):
223
+ base_data_dict = self.retrieve_base_data(idx)
224
+ base_data_dict = add_noise_data_dict(base_data_dict,self.params['noise_setting'])
225
+
226
+ processed_data_dict = OrderedDict()
227
+ processed_data_dict['ego'] = {}
228
+
229
+ ego_id = -1
230
+ ego_lidar_pose = []
231
+ ego_cav_base = None
232
+
233
+ # first find the ego vehicle's lidar pose
234
+ for cav_id, cav_content in base_data_dict.items():
235
+ if cav_content['ego']:
236
+ ego_id = cav_id
237
+ ego_lidar_pose = cav_content['params']['lidar_pose']
238
+ ego_cav_base = cav_content
239
+ break
240
+
241
+ assert cav_id == list(base_data_dict.keys())[
242
+ 0], "The first element in the OrderedDict must be ego"
243
+ assert ego_id != -1
244
+ assert len(ego_lidar_pose) > 0
245
+
246
+ agents_image_inputs = []
247
+ processed_features = []
248
+ object_stack = []
249
+ object_id_stack = []
250
+ single_label_list = []
251
+ too_far = []
252
+ lidar_pose_list = []
253
+ lidar_pose_clean_list = []
254
+ cav_id_list = []
255
+
256
+ projected_lidar_stack = []
257
+ no_projected_lidar_stack = []
258
+
259
+ vsa_lidar_stack = []
260
+
261
+ if self.visualize:
262
+ projected_lidar_stack = []
263
+
264
+ # loop over all CAVs to process information
265
+ for cav_id, selected_cav_base in base_data_dict.items():
266
+ # check if the cav is within the communication range with ego
267
+ distance = \
268
+ math.sqrt((selected_cav_base['params']['lidar_pose'][0] -
269
+ ego_lidar_pose[0]) ** 2 + (
270
+ selected_cav_base['params'][
271
+ 'lidar_pose'][1] - ego_lidar_pose[
272
+ 1]) ** 2)
273
+
274
+ # if distance is too far, we will just skip this agent
275
+ if distance > self.params['comm_range']:
276
+ too_far.append(cav_id)
277
+ continue
278
+
279
+ lidar_pose_clean_list.append(selected_cav_base['params']['lidar_pose_clean'])
280
+ lidar_pose_list.append(selected_cav_base['params']['lidar_pose']) # 6dof pose
281
+ cav_id_list.append(cav_id)
282
+
283
+ for cav_id in too_far:
284
+ base_data_dict.pop(cav_id)
285
+
286
+
287
+ pairwise_t_matrix = \
288
+ get_pairwise_transformation(base_data_dict,
289
+ self.max_cav,
290
+ self.proj_first)
291
+
292
+ lidar_poses = np.array(lidar_pose_list).reshape(-1, 6) # [N_cav, 6]
293
+ lidar_poses_clean = np.array(lidar_pose_clean_list).reshape(-1, 6) # [N_cav, 6]
294
+
295
+ # merge preprocessed features from different cavs into the same dict
296
+ cav_num = len(cav_id_list)
297
+
298
+ # heterogeneous
299
+ if self.heterogeneous:
300
+ lidar_agent, camera_agent = self.selector.select_agent(idx)
301
+ lidar_agent = lidar_agent[:cav_num]
302
+ processed_data_dict['ego'].update({"lidar_agent": lidar_agent})
303
+
304
+
305
+ for _i, cav_id in enumerate(cav_id_list):
306
+ selected_cav_base = base_data_dict[cav_id]
307
+
308
+ # dynamically select the object-center generator for heterogeneous input
309
+ if (not self.visualize) and self.heterogeneous and lidar_agent[_i]:
310
+ self.generate_object_center = self.generate_object_center_lidar
311
+ elif (not self.visualize) and self.heterogeneous and (not lidar_agent[_i]):
312
+ self.generate_object_center = self.generate_object_center_camera
313
+
314
+ selected_cav_processed = self.get_item_single_car(
315
+ selected_cav_base,
316
+ ego_cav_base)
317
+
318
+ object_stack.append(selected_cav_processed['object_bbx_center'])
319
+ object_id_stack += selected_cav_processed['object_ids']
320
+
321
+ if self.load_lidar_file:
322
+ processed_features.append(
323
+ selected_cav_processed['processed_features'])
324
+ if self.proj_first:
325
+ vsa_lidar_stack.append(selected_cav_processed['projected_lidar'])
326
+ else:
327
+ vsa_lidar_stack.append(selected_cav_processed['no_projected_lidar'])
328
+
329
+ if self.load_camera_file:
330
+ agents_image_inputs.append(
331
+ selected_cav_processed['image_inputs'])
332
+
333
+ if self.visualize:
334
+ projected_lidar_stack.append(
335
+ selected_cav_processed['projected_lidar'])
336
+
337
+ single_label_list.append(selected_cav_processed['single_label_dict'])
338
+
339
+ # generate single-view (no-coop) labels
340
+ label_dict_no_coop = single_label_list # [{cav1_label}, {cav2_label}...]
341
+
342
+
343
+ # exclude all repetitive objects
344
+ unique_indices = \
345
+ [object_id_stack.index(x) for x in set(object_id_stack)]
346
+ object_stack = np.vstack(object_stack)
347
+ object_stack = object_stack[unique_indices]
348
+
349
+ # make sure bounding boxes across all frames have the same number
350
+ object_bbx_center = \
351
+ np.zeros((self.params['postprocess']['max_num'], 7))
352
+ mask = np.zeros(self.params['postprocess']['max_num'])
353
+ object_bbx_center[:object_stack.shape[0], :] = object_stack
354
+ mask[:object_stack.shape[0]] = 1
355
+
356
+ if self.load_lidar_file:
357
+ merged_feature_dict = merge_features_to_dict(processed_features)
358
+ processed_data_dict['ego'].update({'processed_lidar': merged_feature_dict,
359
+ 'vsa_lidar': vsa_lidar_stack})
360
+ if self.load_camera_file:
361
+ merged_image_inputs_dict = merge_features_to_dict(agents_image_inputs, merge='stack')
362
+ processed_data_dict['ego'].update({'image_inputs': merged_image_inputs_dict})
363
+
364
+ # generate targets label
365
+ label_dict_coop = \
366
+ self.post_processor.generate_label(
367
+ gt_box_center=object_bbx_center,
368
+ anchors=self.anchor_box,
369
+ mask=mask)
370
+
371
+ label_dict = {
372
+ 'stage1': label_dict_no_coop, # list
373
+ 'stage2': label_dict_coop # dict
374
+ }
375
+
376
+ processed_data_dict['ego'].update(
377
+ {'object_bbx_center': object_bbx_center,
378
+ 'object_bbx_mask': mask,
379
+ 'object_ids': [object_id_stack[i] for i in unique_indices],
380
+ 'anchor_box': self.anchor_box,
381
+ 'label_dict': label_dict,
382
+ 'cav_num': cav_num,
383
+ 'pairwise_t_matrix': pairwise_t_matrix,
384
+ 'lidar_poses_clean': lidar_poses_clean,
385
+ 'lidar_poses': lidar_poses})
386
+
387
+
388
+ if self.visualize:
389
+ processed_data_dict['ego'].update({'origin_lidar':
390
+ np.vstack(
391
+ projected_lidar_stack)})
392
+
393
+
394
+ processed_data_dict['ego'].update({'sample_idx': idx,
395
+ 'cav_id_list': cav_id_list})
396
+
397
+ return processed_data_dict
398
+
399
+
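For orientation, the label dictionary assembled per sample above roughly follows the schematic below: 'stage1' keeps one single-view (no-coop) label dict per CAV, while 'stage2' holds the cooperative label over the merged boxes. The placeholder strings stand in for the anchor-based tensors produced by the post-processor, whose exact shapes depend on the anchor configuration.

sample_label_dict = {
    'stage1': [                       # one entry per CAV (len == cav_num)
        {'pos_equal_one': '...', 'neg_equal_one': '...', 'targets': '...'},  # ego
        {'pos_equal_one': '...', 'neg_equal_one': '...', 'targets': '...'},  # other CAV
    ],
    'stage2': {'pos_equal_one': '...', 'neg_equal_one': '...', 'targets': '...'},
}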
400
+ def collate_batch_train(self, batch):
401
+ # Intermediate fusion is different from the other two
402
+ output_dict = {'ego': {}}
403
+
404
+ object_bbx_center = []
405
+ object_bbx_mask = []
406
+ object_ids = []
407
+ processed_lidar_list = []
408
+ image_inputs_list = []
409
+ # used to record the number of CAVs in each sample
410
+ record_len = []
411
+ label_dict_no_coop_batch_list = []
412
+ label_dict_list = []
413
+ lidar_pose_list = []
414
+ origin_lidar = []
415
+ vsa_lidar = []
416
+ lidar_pose_clean_list = []
417
+
418
+ # pairwise transformation matrix
419
+ pairwise_t_matrix_list = []
420
+
421
+ # heterogeneous
422
+ lidar_agent_list = []
423
+
424
+ for i in range(len(batch)):
425
+ ego_dict = batch[i]['ego']
426
+ object_bbx_center.append(ego_dict['object_bbx_center'])
427
+ object_bbx_mask.append(ego_dict['object_bbx_mask'])
428
+ object_ids.append(ego_dict['object_ids'])
429
+ lidar_pose_list.append(ego_dict['lidar_poses']) # ego_dict['lidar_pose'] is np.ndarray [N,6]
430
+ lidar_pose_clean_list.append(ego_dict['lidar_poses_clean'])
431
+ if self.load_lidar_file:
432
+ processed_lidar_list.append(ego_dict['processed_lidar'])
433
+ vsa_lidar.append(ego_dict['vsa_lidar'])
434
+ if self.load_camera_file:
435
+ image_inputs_list.append(ego_dict['image_inputs']) # different cav_num, ego_dict['image_inputs'] is dict.
436
+
437
+ record_len.append(ego_dict['cav_num'])
438
+ label_dict_no_coop_batch_list.append(ego_dict['label_dict']['stage1'])
439
+ label_dict_list.append(ego_dict['label_dict']['stage2'])
440
+
441
+ pairwise_t_matrix_list.append(ego_dict['pairwise_t_matrix'])
442
+
443
+ if self.visualize:
444
+ origin_lidar.append(ego_dict['origin_lidar'])
445
+
446
+ # heterogeneous
447
+ if self.heterogeneous:
448
+ lidar_agent_list.append(ego_dict['lidar_agent'])
449
+
450
+
451
+ # convert to numpy, (B, max_num, 7)
452
+ object_bbx_center = torch.from_numpy(np.array(object_bbx_center))
453
+ object_bbx_mask = torch.from_numpy(np.array(object_bbx_mask))
454
+
455
+ # example: {'voxel_features':[np.array([1,2,3]]),
456
+ # np.array([3,5,6]), ...]}
457
+ if self.load_lidar_file:
458
+ merged_feature_dict = merge_features_to_dict(processed_lidar_list)
459
+ # [sum(record_len), C, H, W]
460
+ if self.heterogeneous:
461
+ lidar_agent = np.concatenate(lidar_agent_list)
462
+ lidar_agent_idx = lidar_agent.nonzero()[0].tolist()
463
+ for k, v in merged_feature_dict.items(): # 'voxel_features' 'voxel_num_points' 'voxel_coords'
464
+ merged_feature_dict[k] = [v[index] for index in lidar_agent_idx]
465
+
466
+ if not self.heterogeneous or (self.heterogeneous and sum(lidar_agent) != 0):
467
+ processed_lidar_torch_dict = \
468
+ self.pre_processor.collate_batch(merged_feature_dict)
469
+ output_dict['ego'].update({'processed_lidar': processed_lidar_torch_dict})
470
+
471
+ if self.load_camera_file:
472
+ merged_image_inputs_dict = merge_features_to_dict(image_inputs_list, merge='cat')
473
+
474
+ if self.heterogeneous:
475
+ lidar_agent = np.concatenate(lidar_agent_list)
476
+ camera_agent = 1 - lidar_agent
477
+ camera_agent_idx = camera_agent.nonzero()[0].tolist()
478
+ if sum(camera_agent) != 0:
479
+ for k, v in merged_image_inputs_dict.items(): # 'imgs' 'rots' 'trans' ...
480
+ merged_image_inputs_dict[k] = torch.stack([v[index] for index in camera_agent_idx])
481
+
482
+ if not self.heterogeneous or (self.heterogeneous and sum(camera_agent) != 0):
483
+ output_dict['ego'].update({'image_inputs': merged_image_inputs_dict})
484
+
485
+ record_len = torch.from_numpy(np.array(record_len, dtype=int))
486
+ lidar_pose = torch.from_numpy(np.concatenate(lidar_pose_list, axis=0))
487
+ lidar_pose_clean = torch.from_numpy(np.concatenate(lidar_pose_clean_list, axis=0))
488
+ label_dict_no_coop_cavs_batch_list = [label_dict for label_dict_cavs_list in
489
+ label_dict_no_coop_batch_list for label_dict in
490
+ label_dict_cavs_list]
491
+ label_no_coop_torch_dict = \
492
+ self.post_processor.collate_batch(label_dict_no_coop_cavs_batch_list)
493
+
494
+ label_torch_dict = \
495
+ self.post_processor.collate_batch(label_dict_list)
496
+
497
+ # (B, max_cav)
498
+ pairwise_t_matrix = torch.from_numpy(np.array(pairwise_t_matrix_list))
499
+
500
+ # add pairwise_t_matrix to label dict
501
+ label_torch_dict['pairwise_t_matrix'] = pairwise_t_matrix
502
+ label_torch_dict['record_len'] = record_len
503
+
504
+ # object id is only used during inference, where batch size is 1.
505
+ # so here we only get the first element.
506
+ output_dict['ego'].update({ 'object_bbx_center': object_bbx_center,
507
+ 'object_bbx_mask': object_bbx_mask,
508
+ 'record_len': record_len,
509
+ 'label_dict': {
510
+ 'stage1': label_no_coop_torch_dict,
511
+ 'stage2': label_torch_dict,
512
+ },
513
+ 'object_ids': object_ids[0],
514
+ 'pairwise_t_matrix': pairwise_t_matrix,
515
+ 'lidar_pose_clean': lidar_pose_clean,
516
+ 'lidar_pose': lidar_pose,
517
+ 'proj_first': self.proj_first,
518
+ 'anchor_box': self.anchor_box_torch})
519
+
520
+ if self.load_lidar_file:
521
+ coords = []
522
+ idx = 0
523
+ for b in range(len(batch)):
524
+ for points in vsa_lidar[b]:
525
+ assert len(points) != 0
526
+ coor_pad = np.pad(points, ((0, 0), (1, 0)),
527
+ mode="constant", constant_values=idx)
528
+ coords.append(coor_pad)
529
+ idx += 1
530
+ origin_lidar_for_vsa = np.concatenate(coords, axis=0)
531
+ origin_lidar_for_vsa = torch.from_numpy(origin_lidar_for_vsa)
532
+ output_dict['ego'].update({'origin_lidar_for_vsa': origin_lidar_for_vsa})
533
+
534
+ if self.visualize:
535
+ origin_lidar = \
536
+ np.array(downsample_lidar_minimum(pcd_np_list=origin_lidar))
537
+ origin_lidar = torch.from_numpy(origin_lidar)
538
+ output_dict['ego'].update({'origin_lidar': origin_lidar})
539
+
540
+ if self.heterogeneous:
541
+ output_dict['ego'].update({
542
+ "lidar_agent_record": torch.from_numpy(np.concatenate(lidar_agent_list)) # [0,1,1,0,1...]
543
+ })
544
+
545
+ return output_dict
546
+
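The loop above that builds 'origin_lidar_for_vsa' relies on a small indexing trick: each CAV's point array is padded with one leading column holding a running index before concatenation, so the points remain separable afterwards. A toy numpy illustration:

import numpy as np

points_cav0 = np.random.rand(5, 4)   # x, y, z, intensity
points_cav1 = np.random.rand(3, 4)

coords = []
for idx, pts in enumerate([points_cav0, points_cav1]):
    # prepend one column filled with the running CAV/batch index
    coords.append(np.pad(pts, ((0, 0), (1, 0)), mode="constant", constant_values=idx))

stacked = np.concatenate(coords, axis=0)
print(stacked.shape)      # (8, 5)
print(stacked[:, 0])      # [0. 0. 0. 0. 0. 1. 1. 1.]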
547
+ def collate_batch_test(self, batch):
548
+ assert len(batch) <= 1, "Batch size 1 is required during testing!"
549
+ output_dict = self.collate_batch_train(batch)
550
+ if output_dict is None:
551
+ return None
552
+
553
+ # attach the anchor box to the batch
554
+ output_dict['ego'].update({'anchor_box': self.anchor_box_torch})
555
+
556
+ # save the transformation matrix (4, 4) to ego vehicle
557
+ # the transformation is only used in post-processing and is effectively
558
+ # an identity, since all boxes are predicted in the ego coordinate frame.
559
+ transformation_matrix_torch = \
560
+ torch.from_numpy(np.identity(4)).float()
561
+ transformation_matrix_clean_torch = \
562
+ torch.from_numpy(np.identity(4)).float()
563
+
564
+ output_dict['ego'].update({'transformation_matrix':
565
+ transformation_matrix_torch,
566
+ 'transformation_matrix_clean':
567
+ transformation_matrix_clean_torch,})
568
+
569
+ output_dict['ego'].update({
570
+ "sample_idx": batch[0]['ego']['sample_idx'],
571
+ "cav_id_list": batch[0]['ego']['cav_id_list']
572
+ })
573
+
574
+ return output_dict
575
+
576
+
577
+ def post_process(self, data_dict, output_dict):
578
+ """
579
+ Process the outputs of the model to 2D/3D bounding box.
580
+
581
+ Parameters
582
+ ----------
583
+ data_dict : dict
584
+ The dictionary containing the origin input data of model.
585
+
586
+ output_dict :dict
587
+ The dictionary containing the output of the model.
588
+
589
+ Returns
590
+ -------
591
+ pred_box_tensor : torch.Tensor
592
+ The tensor of prediction bounding box after NMS.
593
+ gt_box_tensor : torch.Tensor
594
+ The tensor of gt bounding box.
595
+ """
596
+ pred_box_tensor, pred_score = \
597
+ self.post_processor.post_process(data_dict, output_dict)
598
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict)
599
+
600
+ return pred_box_tensor, pred_score, gt_box_tensor
601
+
602
+
603
+ return Intermediate2stageFusionDataset
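A rough usage sketch of the factory pattern used by these files: getIntermediate2stageFusionDataset(cls) specializes a base dataset class, and the resulting object supplies its own collate functions to a DataLoader. BaseDatasetCls and params below are placeholders for one of the project's base dataset classes and its YAML-derived config.

from torch.utils.data import DataLoader

# BaseDatasetCls and params are stand-ins; substitute the real base dataset
# class and the loaded config dictionary.
DatasetCls = getIntermediate2stageFusionDataset(BaseDatasetCls)
train_set = DatasetCls(params, visualize=False, train=True)

train_loader = DataLoader(train_set, batch_size=4, shuffle=True,
                          collate_fn=train_set.collate_batch_train)
test_loader = DataLoader(train_set, batch_size=1, shuffle=False,  # test collate asserts batch size 1
                         collate_fn=train_set.collate_batch_test)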
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/intermediate_fusion_dataset.py ADDED
@@ -0,0 +1,679 @@
1
+ # intermediate fusion dataset
2
+ import random
3
+ import math
4
+ from collections import OrderedDict
5
+ import numpy as np
6
+ import torch
7
+ import copy
8
+ from icecream import ic
9
+ from PIL import Image
10
+ import pickle as pkl
11
+ from opencood.utils import box_utils as box_utils
12
+ from opencood.data_utils.pre_processor import build_preprocessor
13
+ from opencood.data_utils.post_processor import build_postprocessor
14
+ from opencood.utils.camera_utils import (
15
+ sample_augmentation,
16
+ img_transform,
17
+ normalize_img,
18
+ img_to_tensor,
19
+ )
20
+ from opencood.utils.common_utils import merge_features_to_dict
21
+ from opencood.utils.transformation_utils import x1_to_x2, x_to_world, get_pairwise_transformation
22
+ from opencood.utils.pose_utils import add_noise_data_dict
23
+ from opencood.utils.pcd_utils import (
24
+ mask_points_by_range,
25
+ mask_ego_points,
26
+ shuffle_points,
27
+ downsample_lidar_minimum,
28
+ )
29
+ from opencood.utils.common_utils import read_json
30
+
31
+
32
+ def getIntermediateFusionDataset(cls):
33
+ """
34
+ cls: the Basedataset.
35
+ """
36
+ class IntermediateFusionDataset(cls):
37
+ def __init__(self, params, visualize, train=True):
38
+ super().__init__(params, visualize, train)
39
+ # intermediate and supervise single
40
+ self.supervise_single = True if ('supervise_single' in params['model']['args'] and params['model']['args']['supervise_single']) \
41
+ else False
42
+ self.proj_first = False if 'proj_first' not in params['fusion']['args']\
43
+ else params['fusion']['args']['proj_first']
44
+
45
+ self.anchor_box = self.post_processor.generate_anchor_box()
46
+ self.anchor_box_torch = torch.from_numpy(self.anchor_box)
47
+
48
+ self.heterogeneous = False
49
+ if 'heter' in params:
50
+ self.heterogeneous = True
51
+
52
+ self.kd_flag = params.get('kd_flag', False)
53
+
54
+ self.box_align = False
55
+ if "box_align" in params:
56
+ self.box_align = True
57
+ self.stage1_result_path = params['box_align']['train_result'] if train else params['box_align']['val_result']
58
+ self.stage1_result = read_json(self.stage1_result_path)
59
+ self.box_align_args = params['box_align']['args']
60
+
61
+
62
+
63
+
64
+ def get_item_single_car(self, selected_cav_base, ego_cav_base):
65
+ """
66
+ Process a single CAV's information for the train/test pipeline.
67
+
68
+
69
+ Parameters
70
+ ----------
71
+ selected_cav_base : dict
72
+ The dictionary contains a single CAV's raw information.
73
+ including 'params', 'camera_data'
74
+ ego_pose : list, length 6
75
+ The ego vehicle lidar pose under world coordinate.
76
+ ego_pose_clean : list, length 6
77
+ only used for gt box generation
78
+
79
+ Returns
80
+ -------
81
+ selected_cav_processed : dict
82
+ The dictionary contains the cav's processed information.
83
+ """
84
+ selected_cav_processed = {}
85
+ ego_pose, ego_pose_clean = ego_cav_base['params']['lidar_pose'], ego_cav_base['params']['lidar_pose_clean']
86
+
87
+ # calculate the transformation matrix
88
+ transformation_matrix = \
89
+ x1_to_x2(selected_cav_base['params']['lidar_pose'],
90
+ ego_pose) # T_ego_cav
91
+ transformation_matrix_clean = \
92
+ x1_to_x2(selected_cav_base['params']['lidar_pose_clean'],
93
+ ego_pose_clean)
94
+
95
+ # lidar
96
+ if self.load_lidar_file or self.visualize:
97
+ # process lidar
98
+ lidar_np = selected_cav_base['lidar_np']
99
+ lidar_np = shuffle_points(lidar_np)
100
+ # remove points that hit itself
101
+ lidar_np = mask_ego_points(lidar_np)
102
+ # project the lidar to ego space
103
+ # x,y,z in ego space
104
+ projected_lidar = \
105
+ box_utils.project_points_by_matrix_torch(lidar_np[:, :3],
106
+ transformation_matrix)
107
+ if self.proj_first:
108
+ lidar_np[:, :3] = projected_lidar
109
+
110
+ if self.visualize:
111
+ # filter lidar
112
+ selected_cav_processed.update({'projected_lidar': projected_lidar})
113
+
114
+ if self.kd_flag:
115
+ lidar_proj_np = copy.deepcopy(lidar_np)
116
+ lidar_proj_np[:,:3] = projected_lidar
117
+
118
+ selected_cav_processed.update({'projected_lidar': lidar_proj_np})
119
+
120
+ processed_lidar = self.pre_processor.preprocess(lidar_np)
121
+ selected_cav_processed.update({'processed_features': processed_lidar})
122
+
123
+ # generate single-view GT target labels; note the reference pose is the CAV itself.
124
+ object_bbx_center, object_bbx_mask, object_ids = self.generate_object_center(
125
+ [selected_cav_base], selected_cav_base['params']['lidar_pose']
126
+ )
127
+ label_dict = self.post_processor.generate_label(
128
+ gt_box_center=object_bbx_center, anchors=self.anchor_box, mask=object_bbx_mask
129
+ )
130
+ selected_cav_processed.update({
131
+ "single_label_dict": label_dict,
132
+ "single_object_bbx_center": object_bbx_center,
133
+ "single_object_bbx_mask": object_bbx_mask})
134
+
135
+ # camera
136
+ if self.load_camera_file:
137
+ camera_data_list = selected_cav_base["camera_data"]
138
+
139
+ params = selected_cav_base["params"]
140
+ imgs = []
141
+ rots = []
142
+ trans = []
143
+ intrins = []
144
+ extrinsics = []
145
+ post_rots = []
146
+ post_trans = []
147
+
148
+ for idx, img in enumerate(camera_data_list):
149
+ camera_to_lidar, camera_intrinsic = self.get_ext_int(params, idx)
150
+
151
+ intrin = torch.from_numpy(camera_intrinsic)
152
+ rot = torch.from_numpy(
153
+ camera_to_lidar[:3, :3]
154
+ ) # R_wc; we treat the lidar coordinate frame as the world frame
155
+ tran = torch.from_numpy(camera_to_lidar[:3, 3]) # T_wc
156
+
157
+ post_rot = torch.eye(2)
158
+ post_tran = torch.zeros(2)
159
+
160
+ img_src = [img]
161
+
162
+ # depth
163
+ if self.load_depth_file:
164
+ depth_img = selected_cav_base["depth_data"][idx]
165
+ img_src.append(depth_img)
166
+ else:
167
+ depth_img = None
168
+
169
+ # data augmentation
170
+ resize, resize_dims, crop, flip, rotate = sample_augmentation(
171
+ self.data_aug_conf, self.train
172
+ )
173
+ img_src, post_rot2, post_tran2 = img_transform(
174
+ img_src,
175
+ post_rot,
176
+ post_tran,
177
+ resize=resize,
178
+ resize_dims=resize_dims,
179
+ crop=crop,
180
+ flip=flip,
181
+ rotate=rotate,
182
+ )
183
+ # for convenience, make augmentation matrices 3x3
184
+ post_tran = torch.zeros(3)
185
+ post_rot = torch.eye(3)
186
+ post_tran[:2] = post_tran2
187
+ post_rot[:2, :2] = post_rot2
188
+
189
+ # decouple RGB and Depth
190
+
191
+ img_src[0] = normalize_img(img_src[0])
192
+ if self.load_depth_file:
193
+ img_src[1] = img_to_tensor(img_src[1]) * 255
194
+
195
+ imgs.append(torch.cat(img_src, dim=0))
196
+ intrins.append(intrin)
197
+ extrinsics.append(torch.from_numpy(camera_to_lidar))
198
+ rots.append(rot)
199
+ trans.append(tran)
200
+ post_rots.append(post_rot)
201
+ post_trans.append(post_tran)
202
+
203
+
204
+ selected_cav_processed.update(
205
+ {
206
+ "image_inputs":
207
+ {
208
+ "imgs": torch.stack(imgs), # [Ncam, 3or4, H, W]
209
+ "intrins": torch.stack(intrins),
210
+ "extrinsics": torch.stack(extrinsics),
211
+ "rots": torch.stack(rots),
212
+ "trans": torch.stack(trans),
213
+ "post_rots": torch.stack(post_rots),
214
+ "post_trans": torch.stack(post_trans),
215
+ }
216
+ }
217
+ )
218
+
219
+ # anchor box
220
+ selected_cav_processed.update({"anchor_box": self.anchor_box})
221
+
222
+ # note the reference pose is the ego pose
223
+ object_bbx_center, object_bbx_mask, object_ids = self.generate_object_center([selected_cav_base],
224
+ ego_pose_clean)
225
+
226
+ selected_cav_processed.update(
227
+ {
228
+ "object_bbx_center": object_bbx_center[object_bbx_mask == 1],
229
+ "object_bbx_mask": object_bbx_mask,
230
+ "object_ids": object_ids,
231
+ 'transformation_matrix': transformation_matrix,
232
+ 'transformation_matrix_clean': transformation_matrix_clean
233
+ }
234
+ )
235
+
236
+
237
+ return selected_cav_processed
238
+
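The "make augmentation matrices 3x3" step above simply zero-pads the 2D post-augmentation rotation/scale and translation returned by img_transform so they can be stacked alongside 3D quantities. A small torch sketch, with made-up numbers:

import torch

post_rot2 = torch.eye(2) * 0.5           # e.g. a 0.5x resize
post_tran2 = torch.tensor([10.0, -4.0])  # e.g. a crop offset in pixels

post_rot = torch.eye(3)
post_tran = torch.zeros(3)
post_rot[:2, :2] = post_rot2
post_tran[:2] = post_tran2
# a pixel (u, v) then maps to post_rot[:2, :2] @ (u, v) + post_tran[:2]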
239
+ def __getitem__(self, idx):
240
+ base_data_dict = self.retrieve_base_data(idx)
241
+ base_data_dict = add_noise_data_dict(base_data_dict,self.params['noise_setting'])
242
+
243
+ processed_data_dict = OrderedDict()
244
+ processed_data_dict['ego'] = {}
245
+
246
+ ego_id = -1
247
+ ego_lidar_pose = []
248
+ ego_cav_base = None
249
+
250
+ # first find the ego vehicle's lidar pose
251
+ for cav_id, cav_content in base_data_dict.items():
252
+ if cav_content['ego']:
253
+ ego_id = cav_id
254
+ ego_lidar_pose = cav_content['params']['lidar_pose']
255
+ ego_cav_base = cav_content
256
+ break
257
+
258
+ assert cav_id == list(base_data_dict.keys())[
259
+ 0], "The first element in the OrderedDict must be ego"
260
+ assert ego_id != -1
261
+ assert len(ego_lidar_pose) > 0
262
+
263
+ agents_image_inputs = []
264
+ processed_features = []
265
+ object_stack = []
266
+ object_id_stack = []
267
+ single_label_list = []
268
+ single_object_bbx_center_list = []
269
+ single_object_bbx_mask_list = []
270
+ too_far = []
271
+ lidar_pose_list = []
272
+ lidar_pose_clean_list = []
273
+ cav_id_list = []
274
+ projected_lidar_clean_list = [] # DiscoNet
275
+
276
+ if self.visualize or self.kd_flag:
277
+ projected_lidar_stack = []
278
+
279
+ # loop over all CAVs to process information
280
+ for cav_id, selected_cav_base in base_data_dict.items():
281
+ # check if the cav is within the communication range with ego
282
+ distance = \
283
+ math.sqrt((selected_cav_base['params']['lidar_pose'][0] -
284
+ ego_lidar_pose[0]) ** 2 + (
285
+ selected_cav_base['params'][
286
+ 'lidar_pose'][1] - ego_lidar_pose[
287
+ 1]) ** 2)
288
+
289
+ # if distance is too far, we will just skip this agent
290
+ if distance > self.params['comm_range']:
291
+ too_far.append(cav_id)
292
+ continue
293
+
294
+ lidar_pose_clean_list.append(selected_cav_base['params']['lidar_pose_clean'])
295
+ lidar_pose_list.append(selected_cav_base['params']['lidar_pose']) # 6dof pose
296
+ cav_id_list.append(cav_id)
297
+
298
+ for cav_id in too_far:
299
+ base_data_dict.pop(cav_id)
300
+
301
+ ########## Updated by Yifan Lu 2022.1.26 ############
302
+ # box align to correct pose.
303
+ # stage1_content contains all agent. Even out of comm range.
304
+ if self.box_align and str(idx) in self.stage1_result.keys():
305
+ from opencood.models.sub_modules.box_align_v2 import box_alignment_relative_sample_np
306
+ stage1_content = self.stage1_result[str(idx)]
307
+ if stage1_content is not None:
308
+ all_agent_id_list = stage1_content['cav_id_list'] # include those out of range
309
+ all_agent_corners_list = stage1_content['pred_corner3d_np_list']
310
+ all_agent_uncertainty_list = stage1_content['uncertainty_np_list']
311
+
312
+ cur_agent_id_list = cav_id_list
313
+ cur_agent_pose = [base_data_dict[cav_id]['params']['lidar_pose'] for cav_id in cav_id_list]
314
+ cur_agnet_pose = np.array(cur_agent_pose)
315
+ cur_agent_in_all_agent = [all_agent_id_list.index(cur_agent) for cur_agent in cur_agent_id_list] # indexing current agent in `all_agent_id_list`
316
+
317
+ pred_corners_list = [np.array(all_agent_corners_list[cur_in_all_ind], dtype=np.float64)
318
+ for cur_in_all_ind in cur_agent_in_all_agent]
319
+ uncertainty_list = [np.array(all_agent_uncertainty_list[cur_in_all_ind], dtype=np.float64)
320
+ for cur_in_all_ind in cur_agent_in_all_agent]
321
+
322
+ if sum([len(pred_corners) for pred_corners in pred_corners_list]) != 0:
323
+ refined_pose = box_alignment_relative_sample_np(pred_corners_list,
324
+ cur_agnet_pose,
325
+ uncertainty_list=uncertainty_list,
326
+ **self.box_align_args)
327
+ cur_agnet_pose[:,[0,1,4]] = refined_pose
328
+
329
+ for i, cav_id in enumerate(cav_id_list):
330
+ lidar_pose_list[i] = cur_agnet_pose[i].tolist()
331
+ base_data_dict[cav_id]['params']['lidar_pose'] = cur_agnet_pose[i].tolist()
332
+
333
+
334
+
335
+ pairwise_t_matrix = \
336
+ get_pairwise_transformation(base_data_dict,
337
+ self.max_cav,
338
+ self.proj_first)
339
+
340
+ lidar_poses = np.array(lidar_pose_list).reshape(-1, 6) # [N_cav, 6]
341
+ lidar_poses_clean = np.array(lidar_pose_clean_list).reshape(-1, 6) # [N_cav, 6]
342
+
343
+ # merge preprocessed features from different cavs into the same dict
344
+ cav_num = len(cav_id_list)
345
+
346
+ # heterogeneous
347
+ if self.heterogeneous:
348
+ lidar_agent, camera_agent = self.selector.select_agent(idx)
349
+ lidar_agent = lidar_agent[:cav_num]
350
+ processed_data_dict['ego'].update({"lidar_agent": lidar_agent})
351
+
352
+ for _i, cav_id in enumerate(cav_id_list):
353
+ selected_cav_base = base_data_dict[cav_id]
354
+
355
+ # dynamically select the object-center generator for heterogeneous input
356
+ if (not self.visualize) and self.heterogeneous and lidar_agent[_i]:
357
+ self.generate_object_center = self.generate_object_center_lidar
358
+ elif (not self.visualize) and self.heterogeneous and (not lidar_agent[_i]):
359
+ self.generate_object_center = self.generate_object_center_camera
360
+
361
+ selected_cav_processed = self.get_item_single_car(
362
+ selected_cav_base,
363
+ ego_cav_base)
364
+
365
+ object_stack.append(selected_cav_processed['object_bbx_center'])
366
+ object_id_stack += selected_cav_processed['object_ids']
367
+ if self.load_lidar_file:
368
+ processed_features.append(
369
+ selected_cav_processed['processed_features'])
370
+ if self.load_camera_file:
371
+ agents_image_inputs.append(
372
+ selected_cav_processed['image_inputs'])
373
+
374
+ if self.visualize or self.kd_flag:
375
+ projected_lidar_stack.append(
376
+ selected_cav_processed['projected_lidar'])
377
+
378
+ if self.supervise_single:
379
+ single_label_list.append(selected_cav_processed['single_label_dict'])
380
+ single_object_bbx_center_list.append(selected_cav_processed['single_object_bbx_center'])
381
+ single_object_bbx_mask_list.append(selected_cav_processed['single_object_bbx_mask'])
382
+
383
+ # generate single view GT label
384
+ if self.supervise_single:
385
+ single_label_dicts = self.post_processor.collate_batch(single_label_list)
386
+ single_object_bbx_center = torch.from_numpy(np.array(single_object_bbx_center_list))
387
+ single_object_bbx_mask = torch.from_numpy(np.array(single_object_bbx_mask_list))
388
+ processed_data_dict['ego'].update({
389
+ "single_label_dict_torch": single_label_dicts,
390
+ "single_object_bbx_center_torch": single_object_bbx_center,
391
+ "single_object_bbx_mask_torch": single_object_bbx_mask,
392
+ })
393
+
394
+ if self.kd_flag:
395
+ stack_lidar_np = np.vstack(projected_lidar_stack)
396
+ stack_lidar_np = mask_points_by_range(stack_lidar_np,
397
+ self.params['preprocess'][
398
+ 'cav_lidar_range'])
399
+ stack_feature_processed = self.pre_processor.preprocess(stack_lidar_np)
400
+ processed_data_dict['ego'].update({'teacher_processed_lidar':
401
+ stack_feature_processed})
402
+
403
+
404
+ # exclude all repetitive objects
405
+ unique_indices = \
406
+ [object_id_stack.index(x) for x in set(object_id_stack)]
407
+ object_stack = np.vstack(object_stack)
408
+ object_stack = object_stack[unique_indices]
409
+
410
+ # make sure bounding boxes across all frames have the same number
411
+ object_bbx_center = \
412
+ np.zeros((self.params['postprocess']['max_num'], 7))
413
+ mask = np.zeros(self.params['postprocess']['max_num'])
414
+ object_bbx_center[:object_stack.shape[0], :] = object_stack
415
+ mask[:object_stack.shape[0]] = 1
416
+
417
+ if self.load_lidar_file:
418
+ merged_feature_dict = merge_features_to_dict(processed_features)
419
+ processed_data_dict['ego'].update({'processed_lidar': merged_feature_dict})
420
+ if self.load_camera_file:
421
+ merged_image_inputs_dict = merge_features_to_dict(agents_image_inputs, merge='stack')
422
+ processed_data_dict['ego'].update({'image_inputs': merged_image_inputs_dict})
423
+
424
+
425
+ # generate targets label
426
+ label_dict = \
427
+ self.post_processor.generate_label(
428
+ gt_box_center=object_bbx_center,
429
+ anchors=self.anchor_box,
430
+ mask=mask)
431
+
432
+ processed_data_dict['ego'].update(
433
+ {'object_bbx_center': object_bbx_center,
434
+ 'object_bbx_mask': mask,
435
+ 'object_ids': [object_id_stack[i] for i in unique_indices],
436
+ 'anchor_box': self.anchor_box,
437
+ 'label_dict': label_dict,
438
+ 'cav_num': cav_num,
439
+ 'pairwise_t_matrix': pairwise_t_matrix,
440
+ 'lidar_poses_clean': lidar_poses_clean,
441
+ 'lidar_poses': lidar_poses})
442
+
443
+
444
+ if self.visualize:
445
+ processed_data_dict['ego'].update({'origin_lidar':
446
+ np.vstack(
447
+ projected_lidar_stack)})
448
+
449
+
450
+ processed_data_dict['ego'].update({'sample_idx': idx,
451
+ 'cav_id_list': cav_id_list})
452
+
453
+ return processed_data_dict
454
+
455
+
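The pairwise_t_matrix built above by get_pairwise_transformation can be pictured as a (max_cav, max_cav, 4, 4) tensor whose entry T[i, j] maps agent j's frame into agent i's frame, padded with identities for absent agents. A simplified numpy sketch with yaw-only poses (the real implementation works from the full 6-DoF poses):

import numpy as np

def yaw_pose_to_world(x, y, yaw_deg):
    yaw = np.radians(yaw_deg)
    T = np.eye(4)
    T[:2, :2] = [[np.cos(yaw), -np.sin(yaw)], [np.sin(yaw), np.cos(yaw)]]
    T[:2, 3] = [x, y]
    return T

max_cav = 5
poses = [(0.0, 0.0, 0.0), (12.0, 3.0, 30.0)]          # two agents in range
pairwise = np.tile(np.eye(4), (max_cav, max_cav, 1, 1))
world = [yaw_pose_to_world(*p) for p in poses]
for i in range(len(poses)):
    for j in range(len(poses)):
        pairwise[i, j] = np.linalg.inv(world[i]) @ world[j]  # T_i_j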
456
+ def collate_batch_train(self, batch):
457
+ # Intermediate fusion is different from the other two
458
+ output_dict = {'ego': {}}
459
+
460
+ object_bbx_center = []
461
+ object_bbx_mask = []
462
+ object_ids = []
463
+ processed_lidar_list = []
464
+ image_inputs_list = []
465
+ # used to record the number of CAVs in each sample
466
+ record_len = []
467
+ label_dict_list = []
468
+ lidar_pose_list = []
469
+ origin_lidar = []
470
+ lidar_pose_clean_list = []
471
+
472
+ # heterogeneous
473
+ lidar_agent_list = []
474
+
475
+ # pairwise transformation matrix
476
+ pairwise_t_matrix_list = []
477
+
478
+ # DiscoNet
479
+ teacher_processed_lidar_list = []
480
+
481
+ ### 2022.10.10 single gt ####
482
+ if self.supervise_single:
483
+ pos_equal_one_single = []
484
+ neg_equal_one_single = []
485
+ targets_single = []
486
+ object_bbx_center_single = []
487
+ object_bbx_mask_single = []
488
+
489
+ for i in range(len(batch)):
490
+ ego_dict = batch[i]['ego']
491
+ object_bbx_center.append(ego_dict['object_bbx_center'])
492
+ object_bbx_mask.append(ego_dict['object_bbx_mask'])
493
+ object_ids.append(ego_dict['object_ids'])
494
+ lidar_pose_list.append(ego_dict['lidar_poses']) # ego_dict['lidar_pose'] is np.ndarray [N,6]
495
+ lidar_pose_clean_list.append(ego_dict['lidar_poses_clean'])
496
+ if self.load_lidar_file:
497
+ processed_lidar_list.append(ego_dict['processed_lidar'])
498
+ if self.load_camera_file:
499
+ image_inputs_list.append(ego_dict['image_inputs']) # different cav_num, ego_dict['image_inputs'] is dict.
500
+
501
+ record_len.append(ego_dict['cav_num'])
502
+ label_dict_list.append(ego_dict['label_dict'])
503
+ pairwise_t_matrix_list.append(ego_dict['pairwise_t_matrix'])
504
+
505
+ if self.visualize:
506
+ origin_lidar.append(ego_dict['origin_lidar'])
507
+
508
+ if self.kd_flag:
509
+ teacher_processed_lidar_list.append(ego_dict['teacher_processed_lidar'])
510
+
511
+ ### 2022.10.10 single gt ####
512
+ if self.supervise_single:
513
+ pos_equal_one_single.append(ego_dict['single_label_dict_torch']['pos_equal_one'])
514
+ neg_equal_one_single.append(ego_dict['single_label_dict_torch']['neg_equal_one'])
515
+ targets_single.append(ego_dict['single_label_dict_torch']['targets'])
516
+ object_bbx_center_single.append(ego_dict['single_object_bbx_center_torch'])
517
+ object_bbx_mask_single.append(ego_dict['single_object_bbx_mask_torch'])
518
+
519
+ # heterogeneous
520
+ if self.heterogeneous:
521
+ lidar_agent_list.append(ego_dict['lidar_agent'])
522
+
523
+ # convert to numpy, (B, max_num, 7)
524
+ object_bbx_center = torch.from_numpy(np.array(object_bbx_center))
525
+ object_bbx_mask = torch.from_numpy(np.array(object_bbx_mask))
526
+
527
+ if self.load_lidar_file:
528
+ merged_feature_dict = merge_features_to_dict(processed_lidar_list)
529
+
530
+ if self.heterogeneous:
531
+ lidar_agent = np.concatenate(lidar_agent_list)
532
+ lidar_agent_idx = lidar_agent.nonzero()[0].tolist()
533
+ for k, v in merged_feature_dict.items(): # 'voxel_features' 'voxel_num_points' 'voxel_coords'
534
+ merged_feature_dict[k] = [v[index] for index in lidar_agent_idx]
535
+
536
+ if not self.heterogeneous or (self.heterogeneous and sum(lidar_agent) != 0):
537
+ processed_lidar_torch_dict = \
538
+ self.pre_processor.collate_batch(merged_feature_dict)
539
+ output_dict['ego'].update({'processed_lidar': processed_lidar_torch_dict})
540
+
541
+ if self.load_camera_file:
542
+ merged_image_inputs_dict = merge_features_to_dict(image_inputs_list, merge='cat')
543
+
544
+ if self.heterogeneous:
545
+ lidar_agent = np.concatenate(lidar_agent_list)
546
+ camera_agent = 1 - lidar_agent
547
+ camera_agent_idx = camera_agent.nonzero()[0].tolist()
548
+ if sum(camera_agent) != 0:
549
+ for k, v in merged_image_inputs_dict.items(): # 'imgs' 'rots' 'trans' ...
550
+ merged_image_inputs_dict[k] = torch.stack([v[index] for index in camera_agent_idx])
551
+
552
+ if not self.heterogeneous or (self.heterogeneous and sum(camera_agent) != 0):
553
+ output_dict['ego'].update({'image_inputs': merged_image_inputs_dict})
554
+
555
+ record_len = torch.from_numpy(np.array(record_len, dtype=int))
556
+ lidar_pose = torch.from_numpy(np.concatenate(lidar_pose_list, axis=0))
557
+ lidar_pose_clean = torch.from_numpy(np.concatenate(lidar_pose_clean_list, axis=0))
558
+ label_torch_dict = \
559
+ self.post_processor.collate_batch(label_dict_list)
560
+
561
+ # for centerpoint
562
+ label_torch_dict.update({'object_bbx_center': object_bbx_center,
563
+ 'object_bbx_mask': object_bbx_mask})
564
+
565
+ # (B, max_cav)
566
+ pairwise_t_matrix = torch.from_numpy(np.array(pairwise_t_matrix_list))
567
+
568
+ # add pairwise_t_matrix to label dict
569
+ label_torch_dict['pairwise_t_matrix'] = pairwise_t_matrix
570
+ label_torch_dict['record_len'] = record_len
571
+
572
+
573
+ # object id is only used during inference, where batch size is 1.
574
+ # so here we only get the first element.
575
+ output_dict['ego'].update({'object_bbx_center': object_bbx_center,
576
+ 'object_bbx_mask': object_bbx_mask,
577
+ 'record_len': record_len,
578
+ 'label_dict': label_torch_dict,
579
+ 'object_ids': object_ids[0],
580
+ 'pairwise_t_matrix': pairwise_t_matrix,
581
+ 'lidar_pose_clean': lidar_pose_clean,
582
+ 'lidar_pose': lidar_pose,
583
+ 'anchor_box': self.anchor_box_torch})
584
+
585
+
586
+ if self.visualize:
587
+ origin_lidar = \
588
+ np.array(downsample_lidar_minimum(pcd_np_list=origin_lidar))
589
+ origin_lidar = torch.from_numpy(origin_lidar)
590
+ output_dict['ego'].update({'origin_lidar': origin_lidar})
591
+
592
+ if self.kd_flag:
593
+ teacher_processed_lidar_torch_dict = \
594
+ self.pre_processor.collate_batch(teacher_processed_lidar_list)
595
+ output_dict['ego'].update({'teacher_processed_lidar':teacher_processed_lidar_torch_dict})
596
+
597
+
598
+ if self.supervise_single:
599
+ output_dict['ego'].update({
600
+ "label_dict_single":{
601
+ "pos_equal_one": torch.cat(pos_equal_one_single, dim=0),
602
+ "neg_equal_one": torch.cat(neg_equal_one_single, dim=0),
603
+ "targets": torch.cat(targets_single, dim=0),
604
+ # for centerpoint
605
+ "object_bbx_center_single": torch.cat(object_bbx_center_single, dim=0),
606
+ "object_bbx_mask_single": torch.cat(object_bbx_mask_single, dim=0)
607
+ },
608
+ "object_bbx_center_single": torch.cat(object_bbx_center_single, dim=0),
609
+ "object_bbx_mask_single": torch.cat(object_bbx_mask_single, dim=0)
610
+ })
611
+
612
+ if self.heterogeneous:
613
+ output_dict['ego'].update({
614
+ "lidar_agent_record": torch.from_numpy(np.concatenate(lidar_agent_list)) # [0,1,1,0,1...]
615
+ })
616
+
617
+ return output_dict
618
+
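The heterogeneous branches above split the per-agent feature lists by modality: a 0/1 lidar_agent vector selects which agents feed the lidar backbone, and its complement selects the camera inputs. A toy illustration of that masking:

import numpy as np
import torch

lidar_agent = np.array([1, 0, 1, 0])                    # per-agent modality flags
per_agent_feat = [torch.full((2, 2), float(i)) for i in range(4)]

lidar_idx = lidar_agent.nonzero()[0].tolist()           # [0, 2]
camera_idx = (1 - lidar_agent).nonzero()[0].tolist()    # [1, 3]

lidar_feats = [per_agent_feat[i] for i in lidar_idx]
camera_feats = torch.stack([per_agent_feat[i] for i in camera_idx])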
619
+ def collate_batch_test(self, batch):
620
+ assert len(batch) <= 1, "Batch size 1 is required during testing!"
621
+ output_dict = self.collate_batch_train(batch)
622
+ if output_dict is None:
623
+ return None
624
+
625
+ # check if anchor box in the batch
626
+ if batch[0]['ego']['anchor_box'] is not None:
627
+ output_dict['ego'].update({'anchor_box':
628
+ self.anchor_box_torch})
629
+
630
+ # save the transformation matrix (4, 4) to ego vehicle
631
+ # the transformation is only used in post-processing and is effectively
632
+ # an identity, since all boxes are predicted in the ego coordinate frame.
633
+ transformation_matrix_torch = \
634
+ torch.from_numpy(np.identity(4)).float()
635
+ transformation_matrix_clean_torch = \
636
+ torch.from_numpy(np.identity(4)).float()
637
+
638
+ output_dict['ego'].update({'transformation_matrix':
639
+ transformation_matrix_torch,
640
+ 'transformation_matrix_clean':
641
+ transformation_matrix_clean_torch,})
642
+
643
+ output_dict['ego'].update({
644
+ "sample_idx": batch[0]['ego']['sample_idx'],
645
+ "cav_id_list": batch[0]['ego']['cav_id_list']
646
+ })
647
+
648
+ return output_dict
649
+
650
+
651
+ def post_process(self, data_dict, output_dict):
652
+ """
653
+ Process the outputs of the model to 2D/3D bounding box.
654
+
655
+ Parameters
656
+ ----------
657
+ data_dict : dict
658
+ The dictionary containing the origin input data of model.
659
+
660
+ output_dict :dict
661
+ The dictionary containing the output of the model.
662
+
663
+ Returns
664
+ -------
665
+ pred_box_tensor : torch.Tensor
666
+ The tensor of prediction bounding box after NMS.
667
+ gt_box_tensor : torch.Tensor
668
+ The tensor of gt bounding box.
669
+ """
670
+ pred_box_tensor, pred_score = \
671
+ self.post_processor.post_process(data_dict, output_dict)
672
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict)
673
+
674
+ return pred_box_tensor, pred_score, gt_box_tensor
675
+
676
+
677
+ return IntermediateFusionDataset
678
+
679
+
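When kd_flag is set, __getitem__ above builds a "teacher" input by stacking every CAV's ego-projected points, cropping them to the configured lidar range, and voxelizing the result once. The range mask below is a simplified stand-in for mask_points_by_range, just to make the shape of the operation concrete:

import numpy as np

def mask_points_by_range_simple(points, limits):
    x_min, y_min, z_min, x_max, y_max, z_max = limits
    keep = ((points[:, 0] >= x_min) & (points[:, 0] <= x_max) &
            (points[:, 1] >= y_min) & (points[:, 1] <= y_max) &
            (points[:, 2] >= z_min) & (points[:, 2] <= z_max))
    return points[keep]

cav_points = [np.random.rand(100, 4) * 10, np.random.rand(80, 4) * 10]
stacked = np.vstack(cav_points)                      # all CAVs in the ego frame
teacher_points = mask_points_by_range_simple(stacked, (-2, -2, -2, 8, 8, 8))
# teacher_points would then go through the same pre_processor.preprocess(...)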
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/intermediate_heter_fusion_dataset.py ADDED
@@ -0,0 +1,752 @@
1
+ '''
2
+ intermediate heter fusion dataset
3
+
4
+ Note that for the DAIR-V2X dataset,
5
+ each agent should retrieve its own objects and merge them by IoU,
6
+ instead of using the cooperative label.
7
+ '''
8
+
9
+ import random
10
+ import math
11
+ from collections import OrderedDict
12
+ import numpy as np
13
+ import torch
14
+ import copy
15
+ from icecream import ic
16
+ from PIL import Image
17
+ import pickle as pkl
18
+ from opencood.utils import box_utils as box_utils
19
+ from opencood.data_utils.pre_processor import build_preprocessor
20
+ from opencood.data_utils.post_processor import build_postprocessor
21
+ from opencood.utils.camera_utils import (
22
+ sample_augmentation,
23
+ img_transform,
24
+ normalize_img,
25
+ img_to_tensor,
26
+ )
27
+ from opencood.utils.common_utils import merge_features_to_dict, compute_iou, convert_format
28
+ from opencood.utils.transformation_utils import x1_to_x2, x_to_world, get_pairwise_transformation
29
+ from opencood.utils.pose_utils import add_noise_data_dict
30
+ from opencood.data_utils.pre_processor import build_preprocessor
31
+ from opencood.utils.pcd_utils import (
32
+ mask_points_by_range,
33
+ mask_ego_points,
34
+ shuffle_points,
35
+ downsample_lidar_minimum,
36
+ )
37
+ from opencood.utils.common_utils import read_json
38
+ from opencood.utils.heter_utils import Adaptor
39
+
40
+
41
+ def getIntermediateheterFusionDataset(cls):
42
+ """
43
+ cls: the Basedataset.
44
+ """
45
+ class IntermediateheterFusionDataset(cls):
46
+ def __init__(self, params, visualize, train=True):
47
+ super().__init__(params, visualize, train)
48
+ # intermediate and supervise single
49
+ self.supervise_single = True if ('supervise_single' in params['model']['args'] and params['model']['args']['supervise_single']) \
50
+ else False
51
+ self.proj_first = False if 'proj_first' not in params['fusion']['args']\
52
+ else params['fusion']['args']['proj_first']
53
+
54
+ self.anchor_box = self.post_processor.generate_anchor_box()
55
+ self.anchor_box_torch = torch.from_numpy(self.anchor_box)
56
+
57
+ self.heterogeneous = True
58
+ self.modality_assignment = read_json(params['heter']['assignment_path'])
59
+ self.ego_modality = params['heter']['ego_modality'] # "m1" or "m1&m2" or "m3"
60
+
61
+ self.modality_name_list = list(params['heter']['modality_setting'].keys())
62
+ self.sensor_type_dict = OrderedDict()
63
+
64
+ lidar_channels_dict = params['heter'].get('lidar_channels_dict', OrderedDict())
65
+ mapping_dict = params['heter']['mapping_dict']
66
+ cav_preference = params['heter'].get("cav_preference", None)
67
+
68
+ self.adaptor = Adaptor(self.ego_modality,
69
+ self.modality_name_list,
70
+ self.modality_assignment,
71
+ lidar_channels_dict,
72
+ mapping_dict,
73
+ cav_preference,
74
+ train)
75
+
76
+ for modality_name, modal_setting in params['heter']['modality_setting'].items():
77
+ self.sensor_type_dict[modality_name] = modal_setting['sensor_type']
78
+ if modal_setting['sensor_type'] == 'lidar':
79
+ setattr(self, f"pre_processor_{modality_name}", build_preprocessor(modal_setting['preprocess'], train))
80
+
81
+ elif modal_setting['sensor_type'] == 'camera':
82
+ setattr(self, f"data_aug_conf_{modality_name}", modal_setting['data_aug_conf'])
83
+
84
+ else:
85
+ raise("Not support this type of sensor")
86
+
87
+ self.reinitialize()
88
+
89
+
90
+ self.kd_flag = params.get('kd_flag', False)
91
+
92
+ self.box_align = False
93
+ if "box_align" in params:
94
+ self.box_align = True
95
+ self.stage1_result_path = params['box_align']['train_result'] if train else params['box_align']['val_result']
96
+ self.stage1_result = read_json(self.stage1_result_path)
97
+ self.box_align_args = params['box_align']['args']
98
+
99
+
100
+
101
+ def get_item_single_car(self, selected_cav_base, ego_cav_base):
102
+ """
103
+ Process a single CAV's information for the train/test pipeline.
104
+
105
+
106
+ Parameters
107
+ ----------
108
+ selected_cav_base : dict
109
+ The dictionary contains a single CAV's raw information.
110
+ including 'params', 'camera_data'
111
+ ego_pose : list, length 6
112
+ The ego vehicle lidar pose under world coordinate.
113
+ ego_pose_clean : list, length 6
114
+ only used for gt box generation
115
+
116
+ Returns
117
+ -------
118
+ selected_cav_processed : dict
119
+ The dictionary contains the cav's processed information.
120
+ """
121
+ selected_cav_processed = {}
122
+ ego_pose, ego_pose_clean = ego_cav_base['params']['lidar_pose'], ego_cav_base['params']['lidar_pose_clean']
123
+
124
+ # calculate the transformation matrix
125
+ transformation_matrix = \
126
+ x1_to_x2(selected_cav_base['params']['lidar_pose'],
127
+ ego_pose) # T_ego_cav
128
+ transformation_matrix_clean = \
129
+ x1_to_x2(selected_cav_base['params']['lidar_pose_clean'],
130
+ ego_pose_clean)
131
+
132
+ modality_name = selected_cav_base['modality_name']
133
+ sensor_type = self.sensor_type_dict[modality_name]
134
+
135
+ # lidar
136
+ if sensor_type == "lidar" or self.visualize:
137
+ # process lidar
138
+ lidar_np = selected_cav_base['lidar_np']
139
+ lidar_np = shuffle_points(lidar_np)
140
+ # remove points that hit itself
141
+ lidar_np = mask_ego_points(lidar_np)
142
+ # project the lidar to ego space
143
+ # x,y,z in ego space
144
+ projected_lidar = \
145
+ box_utils.project_points_by_matrix_torch(lidar_np[:, :3],
146
+ transformation_matrix)
147
+ if self.proj_first:
148
+ lidar_np[:, :3] = projected_lidar
149
+
150
+ if self.visualize:
151
+ # filter lidar
152
+ selected_cav_processed.update({'projected_lidar': projected_lidar})
153
+
154
+ if self.kd_flag:
155
+ lidar_proj_np = copy.deepcopy(lidar_np)
156
+ lidar_proj_np[:,:3] = projected_lidar
157
+
158
+ selected_cav_processed.update({'projected_lidar': lidar_proj_np})
159
+
160
+ if sensor_type == "lidar":
161
+ processed_lidar = eval(f"self.pre_processor_{modality_name}").preprocess(lidar_np)
162
+ selected_cav_processed.update({f'processed_features_{modality_name}': processed_lidar})
163
+
164
+ # generate targets label single GT, note the reference pose is itself.
165
+ object_bbx_center, object_bbx_mask, object_ids = self.generate_object_center(
166
+ [selected_cav_base], selected_cav_base['params']['lidar_pose']
167
+ )
168
+ label_dict = self.post_processor.generate_label(
169
+ gt_box_center=object_bbx_center, anchors=self.anchor_box, mask=object_bbx_mask
170
+ )
171
+ selected_cav_processed.update({
172
+ "single_label_dict": label_dict,
173
+ "single_object_bbx_center": object_bbx_center,
174
+ "single_object_bbx_mask": object_bbx_mask})
175
+
176
+ # camera
177
+ if sensor_type == "camera":
178
+ camera_data_list = selected_cav_base["camera_data"]
179
+ params = selected_cav_base["params"]
180
+ imgs = []
181
+ rots = []
182
+ trans = []
183
+ intrins = []
184
+ extrinsics = []
185
+ post_rots = []
186
+ post_trans = []
187
+
188
+ for idx, img in enumerate(camera_data_list):
189
+ camera_to_lidar, camera_intrinsic = self.get_ext_int(params, idx)
190
+
191
+ intrin = torch.from_numpy(camera_intrinsic)
192
+ rot = torch.from_numpy(
193
+ camera_to_lidar[:3, :3]
194
+ ) # R_wc, we consider world-coord is the lidar-coord
195
+ tran = torch.from_numpy(camera_to_lidar[:3, 3]) # T_wc
196
+
197
+ post_rot = torch.eye(2)
198
+ post_tran = torch.zeros(2)
199
+
200
+ img_src = [img]
201
+
202
+ # depth
203
+ if self.load_depth_file:
204
+ depth_img = selected_cav_base["depth_data"][idx]
205
+ img_src.append(depth_img)
206
+ else:
207
+ depth_img = None
208
+
209
+ # data augmentation
210
+ resize, resize_dims, crop, flip, rotate = sample_augmentation(
211
+ eval(f"self.data_aug_conf_{modality_name}"), self.train
212
+ )
213
+ img_src, post_rot2, post_tran2 = img_transform(
214
+ img_src,
215
+ post_rot,
216
+ post_tran,
217
+ resize=resize,
218
+ resize_dims=resize_dims,
219
+ crop=crop,
220
+ flip=flip,
221
+ rotate=rotate,
222
+ )
223
+ # for convenience, make augmentation matrices 3x3
224
+ post_tran = torch.zeros(3)
225
+ post_rot = torch.eye(3)
226
+ post_tran[:2] = post_tran2
227
+ post_rot[:2, :2] = post_rot2
228
+
229
+ # decouple RGB and Depth
230
+
231
+ img_src[0] = normalize_img(img_src[0])
232
+ if self.load_depth_file:
233
+ img_src[1] = img_to_tensor(img_src[1]) * 255
234
+
235
+ imgs.append(torch.cat(img_src, dim=0))
236
+ intrins.append(intrin)
237
+ extrinsics.append(torch.from_numpy(camera_to_lidar))
238
+ rots.append(rot)
239
+ trans.append(tran)
240
+ post_rots.append(post_rot)
241
+ post_trans.append(post_tran)
242
+
243
+
244
+ selected_cav_processed.update(
245
+ {
246
+ f"image_inputs_{modality_name}":
247
+ {
248
+ "imgs": torch.stack(imgs), # [Ncam, 3or4, H, W]
249
+ "intrins": torch.stack(intrins),
250
+ "extrinsics": torch.stack(extrinsics),
251
+ "rots": torch.stack(rots),
252
+ "trans": torch.stack(trans),
253
+ "post_rots": torch.stack(post_rots),
254
+ "post_trans": torch.stack(post_trans),
255
+ }
256
+ }
257
+ )
258
+
259
+ # anchor box
260
+ selected_cav_processed.update({"anchor_box": self.anchor_box})
261
+
262
+ # note the reference pose ego
263
+ object_bbx_center, object_bbx_mask, object_ids = self.generate_object_center([selected_cav_base],
264
+ ego_pose_clean)
265
+
266
+ selected_cav_processed.update(
267
+ {
268
+ "object_bbx_center": object_bbx_center[object_bbx_mask == 1],
269
+ "object_bbx_mask": object_bbx_mask,
270
+ "object_ids": object_ids,
271
+ 'transformation_matrix': transformation_matrix,
272
+ 'transformation_matrix_clean': transformation_matrix_clean
273
+ }
274
+ )
275
+
276
+
277
+ return selected_cav_processed
278
+
279
+ def __getitem__(self, idx):
280
+ base_data_dict = self.retrieve_base_data(idx)
281
+ base_data_dict = add_noise_data_dict(base_data_dict,self.params['noise_setting'])
282
+
283
+ processed_data_dict = OrderedDict()
284
+ processed_data_dict['ego'] = {}
285
+
286
+ ego_id = -1
287
+ ego_lidar_pose = []
288
+ ego_cav_base = None
289
+
290
+ # first find the ego vehicle's lidar pose
291
+ for cav_id, cav_content in base_data_dict.items():
292
+ if cav_content['ego']:
293
+ ego_id = cav_id
294
+ ego_lidar_pose = cav_content['params']['lidar_pose']
295
+ ego_cav_base = cav_content
296
+ break
297
+
298
+ assert cav_id == list(base_data_dict.keys())[
299
+ 0], "The first element in the OrderedDict must be ego"
300
+ assert ego_id != -1
301
+ assert len(ego_lidar_pose) > 0
302
+
303
+
304
+ input_list_m1 = [] # can contain lidar or camera
305
+ input_list_m2 = []
306
+ input_list_m3 = []
307
+ input_list_m4 = []
308
+
309
+ agent_modality_list = []
310
+ object_stack = []
311
+ object_id_stack = []
312
+ single_label_list = []
313
+ single_object_bbx_center_list = []
314
+ single_object_bbx_mask_list = []
315
+ exclude_agent = []
316
+ lidar_pose_list = []
317
+ lidar_pose_clean_list = []
318
+ cav_id_list = []
319
+ projected_lidar_clean_list = [] # disconet
320
+
321
+ if self.visualize or self.kd_flag:
322
+ projected_lidar_stack = []
323
+
324
+ # loop over all CAVs to process information
325
+ for cav_id, selected_cav_base in base_data_dict.items():
326
+ # check if the cav is within the communication range with ego
327
+ distance = \
328
+ math.sqrt((selected_cav_base['params']['lidar_pose'][0] -
329
+ ego_lidar_pose[0]) ** 2 + (
330
+ selected_cav_base['params'][
331
+ 'lidar_pose'][1] - ego_lidar_pose[
332
+ 1]) ** 2)
333
+
334
+ # if distance is too far, we will just skip this agent
335
+ if distance > self.params['comm_range']:
336
+ exclude_agent.append(cav_id)
337
+ continue
338
+
339
+ # if modality not match
340
+ if self.adaptor.unmatched_modality(selected_cav_base['modality_name']):
341
+ exclude_agent.append(cav_id)
342
+ continue
343
+
344
+ lidar_pose_clean_list.append(selected_cav_base['params']['lidar_pose_clean'])
345
+ lidar_pose_list.append(selected_cav_base['params']['lidar_pose']) # 6dof pose
346
+ cav_id_list.append(cav_id)
347
+
348
+ if len(cav_id_list) == 0:
349
+ return None
350
+
351
+ for cav_id in exclude_agent:
352
+ base_data_dict.pop(cav_id)
353
+
354
+ ########## Updated by Yifan Lu 2022.1.26 ############
355
+ # box align to correct pose.
356
+ # stage1_content contains all agent. Even out of comm range.
357
+ if self.box_align and str(idx) in self.stage1_result.keys():
358
+ from opencood.models.sub_modules.box_align_v2 import box_alignment_relative_sample_np
359
+ stage1_content = self.stage1_result[str(idx)]
360
+ if stage1_content is not None:
361
+ all_agent_id_list = stage1_content['cav_id_list'] # include those out of range
362
+ all_agent_corners_list = stage1_content['pred_corner3d_np_list']
363
+ all_agent_uncertainty_list = stage1_content['uncertainty_np_list']
364
+
365
+ cur_agent_id_list = cav_id_list
366
+ cur_agent_pose = [base_data_dict[cav_id]['params']['lidar_pose'] for cav_id in cav_id_list]
367
+ cur_agnet_pose = np.array(cur_agent_pose)
368
+ cur_agent_in_all_agent = [all_agent_id_list.index(cur_agent) for cur_agent in cur_agent_id_list] # indexing current agent in `all_agent_id_list`
369
+
370
+ pred_corners_list = [np.array(all_agent_corners_list[cur_in_all_ind], dtype=np.float64)
371
+ for cur_in_all_ind in cur_agent_in_all_agent]
372
+ uncertainty_list = [np.array(all_agent_uncertainty_list[cur_in_all_ind], dtype=np.float64)
373
+ for cur_in_all_ind in cur_agent_in_all_agent]
374
+
375
+ if sum([len(pred_corners) for pred_corners in pred_corners_list]) != 0:
376
+ refined_pose = box_alignment_relative_sample_np(pred_corners_list,
377
+ cur_agnet_pose,
378
+ uncertainty_list=uncertainty_list,
379
+ **self.box_align_args)
380
+ cur_agnet_pose[:,[0,1,4]] = refined_pose
381
+
382
+ for i, cav_id in enumerate(cav_id_list):
383
+ lidar_pose_list[i] = cur_agnet_pose[i].tolist()
384
+ base_data_dict[cav_id]['params']['lidar_pose'] = cur_agnet_pose[i].tolist()
385
+
386
+
387
+
388
+ pairwise_t_matrix = \
389
+ get_pairwise_transformation(base_data_dict,
390
+ self.max_cav,
391
+ self.proj_first)
392
+
393
+ lidar_poses = np.array(lidar_pose_list).reshape(-1, 6) # [N_cav, 6]
394
+ lidar_poses_clean = np.array(lidar_pose_clean_list).reshape(-1, 6) # [N_cav, 6]
395
+
396
+ # merge preprocessed features from different cavs into the same dict
397
+ cav_num = len(cav_id_list)
398
+
399
+ for _i, cav_id in enumerate(cav_id_list):
400
+ selected_cav_base = base_data_dict[cav_id]
401
+ modality_name = selected_cav_base['modality_name']
402
+ sensor_type = self.sensor_type_dict[selected_cav_base['modality_name']]
403
+
404
+ # dynamic object center generator! for heterogeneous input
405
+ if not self.visualize:
406
+ self.generate_object_center = eval(f"self.generate_object_center_{sensor_type}")
407
+ # need discussion. In test phase, use lidar label.
408
+ else:
409
+ self.generate_object_center = self.generate_object_center_lidar
410
+
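+ # the eval() call dispatches to generate_object_center_lidar or
+ # generate_object_center_camera, so GT boxes are produced with the label
+ # source that matches each agent's sensor modality.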
411
+ selected_cav_processed = self.get_item_single_car(
412
+ selected_cav_base,
413
+ ego_cav_base)
414
+
415
+ object_stack.append(selected_cav_processed['object_bbx_center'])
416
+ object_id_stack += selected_cav_processed['object_ids']
417
+
418
+
419
+ if sensor_type == "lidar":
420
+ eval(f"input_list_{modality_name}").append(selected_cav_processed[f"processed_features_{modality_name}"])
421
+ elif sensor_type == "camera":
422
+ eval(f"input_list_{modality_name}").append(selected_cav_processed[f"image_inputs_{modality_name}"])
423
+ else:
424
+ raise ValueError(f"Unknown sensor type: {sensor_type}")
425
+
426
+ agent_modality_list.append(modality_name)
427
+
428
+ if self.visualize or self.kd_flag:
429
+ projected_lidar_stack.append(
430
+ selected_cav_processed['projected_lidar'])
431
+
432
+ if self.supervise_single or self.heterogeneous:
433
+ single_label_list.append(selected_cav_processed['single_label_dict'])
434
+ single_object_bbx_center_list.append(selected_cav_processed['single_object_bbx_center'])
435
+ single_object_bbx_mask_list.append(selected_cav_processed['single_object_bbx_mask'])
436
+
437
+ # generate single view GT label
438
+ if self.supervise_single or self.heterogeneous:
439
+ single_label_dicts = self.post_processor.collate_batch(single_label_list)
440
+ single_object_bbx_center = torch.from_numpy(np.array(single_object_bbx_center_list))
441
+ single_object_bbx_mask = torch.from_numpy(np.array(single_object_bbx_mask_list))
442
+ processed_data_dict['ego'].update({
443
+ "single_label_dict_torch": single_label_dicts,
444
+ "single_object_bbx_center_torch": single_object_bbx_center,
445
+ "single_object_bbx_mask_torch": single_object_bbx_mask,
446
+ })
447
+
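+ # the ego-projected point clouds of all agents are stacked, range-masked and
+ # run through the pre-processor once more to build the "teacher" input used
+ # for knowledge distillation (the DiscoNet-style kd_flag path).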
448
+ if self.kd_flag:
449
+ stack_lidar_np = np.vstack(projected_lidar_stack)
450
+ stack_lidar_np = mask_points_by_range(stack_lidar_np,
451
+ self.params['preprocess'][
452
+ 'cav_lidar_range'])
453
+ stack_feature_processed = self.pre_processor.preprocess(stack_lidar_np)
454
+ processed_data_dict['ego'].update({'teacher_processed_lidar':
455
+ stack_feature_processed})
456
+
457
+
458
+ # exclude all repetitive objects, DAIR-V2X
459
+ if self.params['fusion']['dataset'] == 'dairv2x':
460
+ if len(object_stack) == 1:
461
+ object_stack = object_stack[0]
462
+ else:
463
+ ego_boxes_np = object_stack[0]
464
+ cav_boxes_np = object_stack[1]
465
+ order = self.params['postprocess']['order']
466
+ ego_corners_np = box_utils.boxes_to_corners_3d(ego_boxes_np, order)
467
+ cav_corners_np = box_utils.boxes_to_corners_3d(cav_boxes_np, order)
468
+ ego_polygon_list = list(convert_format(ego_corners_np))
469
+ cav_polygon_list = list(convert_format(cav_corners_np))
470
+ iou_thresh = 0.05
471
+
472
+
473
+ gt_boxes_from_cav = []
474
+ for i in range(len(cav_polygon_list)):
475
+ cav_polygon = cav_polygon_list[i]
476
+ ious = compute_iou(cav_polygon, ego_polygon_list)
477
+ if (ious > iou_thresh).any():
478
+ continue
479
+ gt_boxes_from_cav.append(cav_boxes_np[i])
480
+
481
+ if len(gt_boxes_from_cav):
482
+ object_stack_from_cav = np.stack(gt_boxes_from_cav)
483
+ object_stack = np.vstack([ego_boxes_np, object_stack_from_cav])
484
+ else:
485
+ object_stack = ego_boxes_np
486
+
487
+ unique_indices = np.arange(object_stack.shape[0])
488
+ object_id_stack = np.arange(object_stack.shape[0])
489
+ else:
490
+ # exclude all repetitive objects, OPV2V-H
491
+ unique_indices = \
492
+ [object_id_stack.index(x) for x in set(object_id_stack)]
493
+ object_stack = np.vstack(object_stack)
494
+ object_stack = object_stack[unique_indices]
495
+
496
+ # make sure the number of bounding boxes is the same across all frames (pad to max_num)
497
+ object_bbx_center = \
498
+ np.zeros((self.params['postprocess']['max_num'], 7))
499
+ mask = np.zeros(self.params['postprocess']['max_num'])
500
+ object_bbx_center[:object_stack.shape[0], :] = object_stack
501
+ mask[:object_stack.shape[0]] = 1
502
+
503
+ for modality_name in self.modality_name_list:
504
+ if self.sensor_type_dict[modality_name] == "lidar":
505
+ merged_feature_dict = merge_features_to_dict(eval(f"input_list_{modality_name}"))
506
+ processed_data_dict['ego'].update({f'input_{modality_name}': merged_feature_dict}) # maybe None
507
+ elif self.sensor_type_dict[modality_name] == "camera":
508
+ merged_image_inputs_dict = merge_features_to_dict(eval(f"input_list_{modality_name}"), merge='stack')
509
+ processed_data_dict['ego'].update({f'input_{modality_name}': merged_image_inputs_dict}) # maybe None
510
+
511
+ processed_data_dict['ego'].update({'agent_modality_list': agent_modality_list})
512
+
513
+ # generate targets label
514
+ label_dict = \
515
+ self.post_processor.generate_label(
516
+ gt_box_center=object_bbx_center,
517
+ anchors=self.anchor_box,
518
+ mask=mask)
519
+
520
+ processed_data_dict['ego'].update(
521
+ {'object_bbx_center': object_bbx_center,
522
+ 'object_bbx_mask': mask,
523
+ 'object_ids': [object_id_stack[i] for i in unique_indices],
524
+ 'anchor_box': self.anchor_box,
525
+ 'label_dict': label_dict,
526
+ 'cav_num': cav_num,
527
+ 'pairwise_t_matrix': pairwise_t_matrix,
528
+ 'lidar_poses_clean': lidar_poses_clean,
529
+ 'lidar_poses': lidar_poses})
530
+
531
+
532
+ if self.visualize:
533
+ processed_data_dict['ego'].update({'origin_lidar':
534
+ np.vstack(
535
+ projected_lidar_stack)})
536
+
537
+
538
+ processed_data_dict['ego'].update({'sample_idx': idx,
539
+ 'cav_id_list': cav_id_list})
540
+
541
+ return processed_data_dict
542
+
543
+
544
+ def collate_batch_train(self, batch):
545
+ # Intermediate fusion is different from the other two
546
+ output_dict = {'ego': {}}
547
+
548
+ object_bbx_center = []
549
+ object_bbx_mask = []
550
+ object_ids = []
551
+ inputs_list_m1 = []
552
+ inputs_list_m2 = []
553
+ inputs_list_m3 = []
554
+ inputs_list_m4 = []
555
+ agent_modality_list = []
556
+ # used to record the number of agents in each scenario
557
+ record_len = []
558
+ label_dict_list = []
559
+ lidar_pose_list = []
560
+ origin_lidar = []
561
+ lidar_pose_clean_list = []
562
+
563
+ # pairwise transformation matrix
564
+ pairwise_t_matrix_list = []
565
+
566
+ # disconet
567
+ teacher_processed_lidar_list = []
568
+
569
+ ### 2022.10.10 single gt ####
570
+ if self.supervise_single or self.heterogeneous:
571
+ pos_equal_one_single = []
572
+ neg_equal_one_single = []
573
+ targets_single = []
574
+ object_bbx_center_single = []
575
+ object_bbx_mask_single = []
576
+
577
+ for i in range(len(batch)):
578
+ ego_dict = batch[i]['ego']
579
+ object_bbx_center.append(ego_dict['object_bbx_center'])
580
+ object_bbx_mask.append(ego_dict['object_bbx_mask'])
581
+ object_ids.append(ego_dict['object_ids'])
582
+ lidar_pose_list.append(ego_dict['lidar_poses']) # ego_dict['lidar_pose'] is np.ndarray [N,6]
583
+ lidar_pose_clean_list.append(ego_dict['lidar_poses_clean'])
584
+
585
+ for modality_name in self.modality_name_list:
586
+ if ego_dict[f'input_{modality_name}'] is not None:
587
+ eval(f"inputs_list_{modality_name}").append(ego_dict[f'input_{modality_name}']) # OrderedDict() if empty?
588
+
589
+ agent_modality_list.extend(ego_dict['agent_modality_list'])
590
+
591
+ record_len.append(ego_dict['cav_num'])
592
+ label_dict_list.append(ego_dict['label_dict'])
593
+ pairwise_t_matrix_list.append(ego_dict['pairwise_t_matrix'])
594
+
595
+ if self.visualize:
596
+ origin_lidar.append(ego_dict['origin_lidar'])
597
+
598
+ if self.kd_flag:
599
+ teacher_processed_lidar_list.append(ego_dict['teacher_processed_lidar'])
600
+
601
+ ### 2022.10.10 single gt ####
602
+ if self.supervise_single or self.heterogeneous:
603
+ pos_equal_one_single.append(ego_dict['single_label_dict_torch']['pos_equal_one'])
604
+ neg_equal_one_single.append(ego_dict['single_label_dict_torch']['neg_equal_one'])
605
+ targets_single.append(ego_dict['single_label_dict_torch']['targets'])
606
+ object_bbx_center_single.append(ego_dict['single_object_bbx_center_torch'])
607
+ object_bbx_mask_single.append(ego_dict['single_object_bbx_mask_torch'])
608
+
609
+
610
+ # convert to numpy, (B, max_num, 7)
611
+ object_bbx_center = torch.from_numpy(np.array(object_bbx_center))
612
+ object_bbx_mask = torch.from_numpy(np.array(object_bbx_mask))
613
+
614
+
615
+ # 2023.2.5
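+ # each modality is collated separately: lidar features go through the matching
+ # self.pre_processor_<modality_name>.collate_batch (resolved via eval), while
+ # camera inputs are concatenated along the agent dimension.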
616
+ for modality_name in self.modality_name_list:
617
+ if len(eval(f"inputs_list_{modality_name}")) != 0:
618
+ if self.sensor_type_dict[modality_name] == "lidar":
619
+ merged_feature_dict = merge_features_to_dict(eval(f"inputs_list_{modality_name}"))
620
+ processed_lidar_torch_dict = eval(f"self.pre_processor_{modality_name}").collate_batch(merged_feature_dict)
621
+ output_dict['ego'].update({f'inputs_{modality_name}': processed_lidar_torch_dict})
622
+
623
+ elif self.sensor_type_dict[modality_name] == "camera":
624
+ merged_image_inputs_dict = merge_features_to_dict(eval(f"inputs_list_{modality_name}"), merge='cat')
625
+ output_dict['ego'].update({f'inputs_{modality_name}': merged_image_inputs_dict})
626
+
627
+
628
+ output_dict['ego'].update({"agent_modality_list": agent_modality_list})
629
+
630
+ record_len = torch.from_numpy(np.array(record_len, dtype=int))
631
+ lidar_pose = torch.from_numpy(np.concatenate(lidar_pose_list, axis=0))
632
+ lidar_pose_clean = torch.from_numpy(np.concatenate(lidar_pose_clean_list, axis=0))
633
+ label_torch_dict = \
634
+ self.post_processor.collate_batch(label_dict_list)
635
+
636
+ # for centerpoint
637
+ label_torch_dict.update({'object_bbx_center': object_bbx_center,
638
+ 'object_bbx_mask': object_bbx_mask})
639
+
640
+ # (B, max_cav)
641
+ pairwise_t_matrix = torch.from_numpy(np.array(pairwise_t_matrix_list))
642
+
643
+ # add pairwise_t_matrix to label dict
644
+ label_torch_dict['pairwise_t_matrix'] = pairwise_t_matrix
645
+ label_torch_dict['record_len'] = record_len
646
+
647
+
648
+ # object id is only used during inference, where batch size is 1.
649
+ # so here we only get the first element.
650
+ output_dict['ego'].update({'object_bbx_center': object_bbx_center,
651
+ 'object_bbx_mask': object_bbx_mask,
652
+ 'record_len': record_len,
653
+ 'label_dict': label_torch_dict,
654
+ 'object_ids': object_ids[0],
655
+ 'pairwise_t_matrix': pairwise_t_matrix,
656
+ 'lidar_pose_clean': lidar_pose_clean,
657
+ 'lidar_pose': lidar_pose,
658
+ 'anchor_box': self.anchor_box_torch})
659
+
660
+
661
+ if self.visualize:
662
+ origin_lidar = \
663
+ np.array(downsample_lidar_minimum(pcd_np_list=origin_lidar))
664
+ origin_lidar = torch.from_numpy(origin_lidar)
665
+ output_dict['ego'].update({'origin_lidar': origin_lidar})
666
+
667
+ if self.kd_flag:
668
+ teacher_processed_lidar_torch_dict = \
669
+ self.pre_processor.collate_batch(teacher_processed_lidar_list)
670
+ output_dict['ego'].update({'teacher_processed_lidar':teacher_processed_lidar_torch_dict})
671
+
672
+
673
+ if self.supervise_single or self.heterogeneous:
674
+ output_dict['ego'].update({
675
+ "label_dict_single":{
676
+ "pos_equal_one": torch.cat(pos_equal_one_single, dim=0),
677
+ "neg_equal_one": torch.cat(neg_equal_one_single, dim=0),
678
+ "targets": torch.cat(targets_single, dim=0),
679
+ # for centerpoint
680
+ "object_bbx_center_single": torch.cat(object_bbx_center_single, dim=0),
681
+ "object_bbx_mask_single": torch.cat(object_bbx_mask_single, dim=0)
682
+ },
683
+ "object_bbx_center_single": torch.cat(object_bbx_center_single, dim=0),
684
+ "object_bbx_mask_single": torch.cat(object_bbx_mask_single, dim=0)
685
+ })
686
+
687
+ return output_dict
688
+
689
+ def collate_batch_test(self, batch):
690
+ assert len(batch) <= 1, "Batch size 1 is required during testing!"
691
+ if batch[0] is None:
692
+ return None
693
+ output_dict = self.collate_batch_train(batch)
694
+ if output_dict is None:
695
+ return None
696
+
697
+ # check if the anchor box is in the batch
698
+ if batch[0]['ego']['anchor_box'] is not None:
699
+ output_dict['ego'].update({'anchor_box':
700
+ self.anchor_box_torch})
701
+
702
+ # save the transformation matrix (4, 4) to ego vehicle
703
+ # the transformation is only used in post-processing (effectively unused,
704
+ # since all boxes are already predicted in ego coordinates).
705
+ transformation_matrix_torch = \
706
+ torch.from_numpy(np.identity(4)).float()
707
+ transformation_matrix_clean_torch = \
708
+ torch.from_numpy(np.identity(4)).float()
709
+
710
+ output_dict['ego'].update({'transformation_matrix':
711
+ transformation_matrix_torch,
712
+ 'transformation_matrix_clean':
713
+ transformation_matrix_clean_torch,})
714
+
715
+ output_dict['ego'].update({
716
+ "sample_idx": batch[0]['ego']['sample_idx'],
717
+ "cav_id_list": batch[0]['ego']['cav_id_list'],
718
+ "agent_modality_list": batch[0]['ego']['agent_modality_list']
719
+ })
720
+
721
+ return output_dict
722
+
723
+
724
+ def post_process(self, data_dict, output_dict):
725
+ """
726
+ Process the outputs of the model into 2D/3D bounding boxes.
727
+
728
+ Parameters
729
+ ----------
730
+ data_dict : dict
731
+ The dictionary containing the original input data of the model.
732
+
733
+ output_dict : dict
734
+ The dictionary containing the output of the model.
735
+
736
+ Returns
737
+ -------
738
+ pred_box_tensor : torch.Tensor
739
+ The tensor of prediction bounding box after NMS.
740
+ gt_box_tensor : torch.Tensor
741
+ The tensor of gt bounding box.
742
+ """
743
+ pred_box_tensor, pred_score = \
744
+ self.post_processor.post_process(data_dict, output_dict)
745
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict)
746
+
747
+ return pred_box_tensor, pred_score, gt_box_tensor
748
+
749
+
750
+ return IntermediateheterFusionDataset
751
+
752
+
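+ # Minimal usage sketch (not part of the original file; the base dataset class
+ # and the factory name are assumptions following the get<...>FusionDataset
+ # pattern used elsewhere in this upload):
+ #     HeterDataset = getIntermediateheterFusionDataset(SomeBaseDataset)
+ #     dataset = HeterDataset(params, visualize=False, train=True)
+ #     batch = dataset.collate_batch_train([dataset[0]])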
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/intermediate_multiclass_fusion_dataset.py ADDED
@@ -0,0 +1,892 @@
1
+ # intermediate fusion dataset
2
+ import random
3
+ import math
4
+ from collections import OrderedDict
5
+ import numpy as np
6
+ import torch
7
+ import copy
8
+ from icecream import ic
9
+ from PIL import Image
10
+ import pickle as pkl
11
+ from opencood.utils import box_utils as box_utils
12
+ from opencood.data_utils.pre_processor import build_preprocessor
13
+ from opencood.data_utils.post_processor import build_postprocessor
14
+ from opencood.utils.camera_utils import (
15
+ sample_augmentation,
16
+ img_transform,
17
+ normalize_img,
18
+ img_to_tensor,
19
+ )
20
+ # from opencood.utils.heter_utils import AgentSelector
21
+ from opencood.utils.common_utils import merge_features_to_dict
22
+ from opencood.utils.transformation_utils import x1_to_x2, x_to_world, get_pairwise_transformation, get_pairwise_transformation_asymmetric
23
+ from opencood.utils.pose_utils import add_noise_data_dict, add_noise_data_dict_asymmetric
24
+ from opencood.utils.pcd_utils import (
25
+ mask_points_by_range,
26
+ mask_ego_points,
27
+ mask_ego_points_v2,
28
+ shuffle_points,
29
+ downsample_lidar_minimum,
30
+ )
31
+ from opencood.utils.common_utils import read_json
32
+
33
+
34
+ def getIntermediatemulticlassFusionDataset(cls):
35
+ """
36
+ cls: the Basedataset.
37
+ """
38
+ class IntermediatemulticlassFusionDataset(cls):
39
+ def __init__(self, params, visualize, train=True):
40
+ super().__init__(params, visualize, train)
41
+ # intermediate and supervise single
42
+ self.supervise_single = True if ('supervise_single' in params['model']['args'] and params['model']['args']['supervise_single']) \
43
+ else False
44
+ self.proj_first = False if 'proj_first' not in params['fusion']['args']\
45
+ else params['fusion']['args']['proj_first']
46
+
47
+ self.anchor_box = self.post_processor.generate_anchor_box()
48
+ self.anchor_box_torch = torch.from_numpy(self.anchor_box)
49
+
50
+ self.heterogeneous = False
51
+ if 'heter' in params:
52
+ self.heterogeneous = True
53
+ self.selector = AgentSelector(params['heter'], self.max_cav)
54
+
55
+ self.kd_flag = params.get('kd_flag', False)
56
+
57
+ self.box_align = False
58
+ if "box_align" in params:
59
+ self.box_align = True
60
+ self.stage1_result_path = params['box_align']['train_result'] if train else params['box_align']['val_result']
61
+ self.stage1_result = read_json(self.stage1_result_path)
62
+ self.box_align_args = params['box_align']['args']
63
+
64
+ self.multiclass = params['model']['args']['multi_class']
65
+ self.online_eval_only = False
66
+
67
+ def get_item_single_car(self, selected_cav_base, ego_cav_base, tpe='all', cav_id='car_0', online_eval=False):
68
+ """
69
+ Process a single CAV's information for the train/test pipeline.
70
+
71
+
72
+ Parameters
73
+ ----------
74
+ selected_cav_base : dict
75
+ The dictionary contains a single CAV's raw information.
76
+ including 'params', 'camera_data'
77
+ ego_pose : list, length 6
78
+ The ego vehicle lidar pose under world coordinate.
79
+ ego_pose_clean : list, length 6
80
+ only used for gt box generation
81
+
82
+ Returns
83
+ -------
84
+ selected_cav_processed : dict
85
+ The dictionary contains the cav's processed information.
86
+ """
87
+ selected_cav_processed = {}
88
+ ego_pose, ego_pose_clean = ego_cav_base['params']['lidar_pose'], ego_cav_base['params']['lidar_pose_clean']
89
+
90
+ # calculate the transformation matrix
91
+ transformation_matrix = \
92
+ x1_to_x2(selected_cav_base['params']['lidar_pose'],
93
+ ego_pose) # T_ego_cav
94
+ transformation_matrix_clean = \
95
+ x1_to_x2(selected_cav_base['params']['lidar_pose_clean'],
96
+ ego_pose_clean)
97
+
98
+ # lidar
99
+ if tpe == 'all':
100
+ if self.load_lidar_file or self.visualize:
101
+ # process lidar
102
+ lidar_np = selected_cav_base['lidar_np']
103
+ lidar_np = shuffle_points(lidar_np)
104
+ # remove points that hit the vehicle itself
105
+ if not cav_id.startswith('rsu'):
106
+ lidar_np = mask_ego_points_v2(lidar_np)
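+ # roadside units (cav_id starting with 'rsu') keep all points, presumably
+ # because the self-hit mask only makes sense for vehicle-mounted lidars.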
107
+ # project the lidar to ego space
108
+ # x,y,z in ego space
109
+ projected_lidar = \
110
+ box_utils.project_points_by_matrix_torch(lidar_np[:, :3],
111
+ transformation_matrix)
112
+ if self.proj_first:
113
+ lidar_np[:, :3] = projected_lidar
114
+
115
+ if self.visualize:
116
+ # filter lidar
117
+ selected_cav_processed.update({'projected_lidar': projected_lidar})
118
+
119
+ if self.kd_flag:
120
+ lidar_proj_np = copy.deepcopy(lidar_np)
121
+ lidar_proj_np[:,:3] = projected_lidar
122
+
123
+ selected_cav_processed.update({'projected_lidar': lidar_proj_np})
124
+
125
+ processed_lidar = self.pre_processor.preprocess(lidar_np)
126
+ selected_cav_processed.update({'processed_features': processed_lidar})
127
+
128
+ if True: # not online_eval:
129
+ # generate targets label single GT, note the reference pose is itself.
130
+ object_bbx_center, object_bbx_mask, object_ids = self.generate_object_center(
131
+ [selected_cav_base], selected_cav_base['params']['lidar_pose']
132
+ )
133
+ label_dict = {}
134
+ if tpe == 'all':
135
+ # unused label
136
+ if False:
137
+ label_dict = self.post_processor.generate_label(
138
+ gt_box_center=object_bbx_center, anchors=self.anchor_box, mask=object_bbx_mask
139
+ )
140
+ selected_cav_processed.update({
141
+ "single_label_dict": label_dict,
142
+ "single_object_bbx_center": object_bbx_center,
143
+ "single_object_bbx_mask": object_bbx_mask})
144
+
145
+ if tpe == 'all':
146
+ # camera
147
+ if self.load_camera_file:
148
+ camera_data_list = selected_cav_base["camera_data"]
149
+
150
+ params = selected_cav_base["params"]
151
+ imgs = []
152
+ rots = []
153
+ trans = []
154
+ intrins = []
155
+ extrinsics = []
156
+ post_rots = []
157
+ post_trans = []
158
+
159
+ for idx, img in enumerate(camera_data_list):
160
+ camera_to_lidar, camera_intrinsic = self.get_ext_int(params, idx)
161
+
162
+ intrin = torch.from_numpy(camera_intrinsic)
163
+ rot = torch.from_numpy(
164
+ camera_to_lidar[:3, :3]
165
+ ) # R_wc, we consider world-coord is the lidar-coord
166
+ tran = torch.from_numpy(camera_to_lidar[:3, 3]) # T_wc
167
+
168
+ post_rot = torch.eye(2)
169
+ post_tran = torch.zeros(2)
170
+
171
+ img_src = [img]
172
+
173
+ # depth
174
+ if self.load_depth_file:
175
+ depth_img = selected_cav_base["depth_data"][idx]
176
+ img_src.append(depth_img)
177
+ else:
178
+ depth_img = None
179
+
180
+ # data augmentation
181
+ resize, resize_dims, crop, flip, rotate = sample_augmentation(
182
+ self.data_aug_conf, self.train
183
+ )
184
+ img_src, post_rot2, post_tran2 = img_transform(
185
+ img_src,
186
+ post_rot,
187
+ post_tran,
188
+ resize=resize,
189
+ resize_dims=resize_dims,
190
+ crop=crop,
191
+ flip=flip,
192
+ rotate=rotate,
193
+ )
194
+ # for convenience, make augmentation matrices 3x3
195
+ post_tran = torch.zeros(3)
196
+ post_rot = torch.eye(3)
197
+ post_tran[:2] = post_tran2
198
+ post_rot[:2, :2] = post_rot2
199
+
200
+ # decouple RGB and Depth
201
+
202
+ img_src[0] = normalize_img(img_src[0])
203
+ if self.load_depth_file:
204
+ img_src[1] = img_to_tensor(img_src[1]) * 255
205
+
206
+ imgs.append(torch.cat(img_src, dim=0))
207
+ intrins.append(intrin)
208
+ extrinsics.append(torch.from_numpy(camera_to_lidar))
209
+ rots.append(rot)
210
+ trans.append(tran)
211
+ post_rots.append(post_rot)
212
+ post_trans.append(post_tran)
213
+
214
+
215
+ selected_cav_processed.update(
216
+ {
217
+ "image_inputs":
218
+ {
219
+ "imgs": torch.stack(imgs), # [Ncam, 3or4, H, W]
220
+ "intrins": torch.stack(intrins),
221
+ "extrinsics": torch.stack(extrinsics),
222
+ "rots": torch.stack(rots),
223
+ "trans": torch.stack(trans),
224
+ "post_rots": torch.stack(post_rots),
225
+ "post_trans": torch.stack(post_trans),
226
+ }
227
+ }
228
+ )
229
+
230
+ # anchor box
231
+ selected_cav_processed.update({"anchor_box": self.anchor_box})
232
+
233
+ if True: # not online_eval:
234
+ # note the reference pose ego
235
+ object_bbx_center, object_bbx_mask, object_ids = self.generate_object_center([selected_cav_base],
236
+ ego_pose_clean)
237
+ selected_cav_processed.update(
238
+ {
239
+ "object_bbx_center": object_bbx_center[object_bbx_mask == 1],
240
+ "object_bbx_mask": object_bbx_mask,
241
+ "object_ids": object_ids,
242
+ }
243
+ )
244
+ selected_cav_processed.update(
245
+ {
246
+ 'transformation_matrix': transformation_matrix,
247
+ 'transformation_matrix_clean': transformation_matrix_clean
248
+ }
249
+ )
250
+
251
+
252
+ return selected_cav_processed
253
+
254
+ def __getitem__(self, idx, extra_source=None, data_dir=None, plan_without_perception_gt=True):
255
+ if (data_dir is not None) and (plan_without_perception_gt):
256
+ extra_source=1
257
+ object_bbx_center_list = []
258
+ object_bbx_mask_list = []
259
+ object_id_dict = {}
260
+
261
+ object_bbx_center_list_single = []
262
+ object_bbx_mask_list_single = []
263
+
264
+
265
+ output_dict = {}
266
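+ # one pass with tpe='all' builds the fused sensor inputs; the passes with
+ # tpe 0, 1 and 3 only collect per-class GT boxes, which are stacked into
+ # multi-class tensors below (the exact class-id mapping depends on the base
+ # dataset config and is an assumption here).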
+ for tpe in ['all', 0, 1, 3]:
267
+ output_single_class = self.__getitem_single_class__(idx, tpe, extra_source, data_dir)
268
+ output_dict[tpe] = output_single_class
269
+ if tpe == 'all':
270
+ continue
271
+ elif tpe == 'all' and extra_source!=None:
272
+ break
273
+ object_bbx_center_list.append(output_single_class['ego']['object_bbx_center'])
274
+ object_bbx_mask_list.append(output_single_class['ego']['object_bbx_mask'])
275
+ if self.supervise_single:
276
+ object_bbx_center_list_single.append(output_single_class['ego']['single_object_bbx_center_torch'])
277
+ object_bbx_mask_list_single.append(output_single_class['ego']['single_object_bbx_mask_torch'])
278
+
279
+ object_id_dict[tpe] = output_single_class['ego']['object_ids']
280
+
281
+ if True: # self.multiclass and extra_source==None:
282
+ output_dict['all']['ego']['object_bbx_center'] = np.stack(object_bbx_center_list, axis=0)
283
+ output_dict['all']['ego']['object_bbx_mask'] = np.stack(object_bbx_mask_list, axis=0)
284
+ if self.supervise_single:
285
+ output_dict['all']['ego']['single_object_bbx_center_torch'] = torch.stack(object_bbx_center_list_single, axis=1)
286
+ output_dict['all']['ego']['single_object_bbx_mask_torch'] = torch.stack(object_bbx_mask_list_single, axis=1)
287
+
288
+ output_dict['all']['ego']['object_ids'] = object_id_dict
289
+ # print('finish get item')
290
+ return output_dict['all']
291
+
292
+ def __getitem_single_class__(self, idx, tpe=None, extra_source=None, data_dir=None):
293
+
294
+ if extra_source is None and data_dir is None:
295
+ base_data_dict = self.retrieve_base_data(idx, tpe)
296
+ elif data_dir is not None:
297
+ base_data_dict = self.retrieve_base_data(idx=None, tpe=tpe, data_dir=data_dir)
298
+ elif extra_source is not None:
299
+ base_data_dict = self.retrieve_base_data(idx=None, tpe=tpe, extra_source=extra_source)
300
+
301
+ base_data_dict = add_noise_data_dict_asymmetric(base_data_dict,self.params['noise_setting'])
302
+ processed_data_dict = OrderedDict()
303
+ processed_data_dict['ego'] = {}
304
+
305
+ ego_id = -1
306
+ ego_lidar_pose = []
307
+ ego_cav_base = None
308
+
309
+ # first find the ego vehicle's lidar pose
310
+ for cav_id, cav_content in base_data_dict.items():
311
+ if cav_content['ego']:
312
+ ego_id = cav_id
313
+ ego_lidar_pose = cav_content['params']['lidar_pose']
314
+ ego_cav_base = cav_content
315
+ break
316
+
317
+ assert cav_id == list(base_data_dict.keys())[
318
+ 0], "The first element in the OrderedDict must be ego"
319
+ assert ego_id != -1
320
+ assert len(ego_lidar_pose) > 0
321
+
322
+ agents_image_inputs = []
323
+ processed_features = []
324
+ object_stack = []
325
+ object_id_stack = []
326
+ single_label_list = []
327
+ single_object_bbx_center_list = []
328
+ single_object_bbx_mask_list = []
329
+ too_far = []
330
+ lidar_pose_list = []
331
+ lidar_pose_clean_list = []
332
+ cav_id_list = []
333
+ projected_lidar_clean_list = [] # disconet
334
+
335
+ if self.visualize or self.kd_flag:
336
+ projected_lidar_stack = []
337
+
338
+ # loop over all CAVs to process information
339
+ for cav_id, selected_cav_base in base_data_dict.items():
340
+ # check if the cav is within the communication range with ego
341
+ distance = \
342
+ math.sqrt((selected_cav_base['params']['lidar_pose'][0] -
343
+ ego_lidar_pose[0]) ** 2 + (
344
+ selected_cav_base['params'][
345
+ 'lidar_pose'][1] - ego_lidar_pose[
346
+ 1]) ** 2)
347
+
348
+ # if distance is too far, we will just skip this agent
349
+ if distance > self.params['comm_range']:
350
+ too_far.append(cav_id)
351
+ continue
352
+
353
+ lidar_pose_clean_list.append(selected_cav_base['params']['lidar_pose_clean'])
354
+ lidar_pose_list.append(selected_cav_base['params']['lidar_pose']) # 6dof pose
355
+ cav_id_list.append(cav_id)
356
+
357
+ for cav_id in too_far:
358
+ base_data_dict.pop(cav_id)
359
+
360
+ ########## Updated by Yifan Lu 2022.1.26 ############
361
+ # box align to correct pose.
362
+ # stage1_content contains all agent. Even out of comm range.
363
+ if self.box_align and str(idx) in self.stage1_result.keys(): # False
364
+ from opencood.models.sub_modules.box_align_v2 import box_alignment_relative_sample_np
365
+ stage1_content = self.stage1_result[str(idx)]
366
+ if stage1_content is not None:
367
+ all_agent_id_list = stage1_content['cav_id_list'] # include those out of range
368
+ all_agent_corners_list = stage1_content['pred_corner3d_np_list']
369
+ all_agent_uncertainty_list = stage1_content['uncertainty_np_list']
370
+
371
+ cur_agent_id_list = cav_id_list
372
+ cur_agent_pose = [base_data_dict[cav_id]['params']['lidar_pose'] for cav_id in cav_id_list]
373
+ cur_agnet_pose = np.array(cur_agent_pose)
374
+ cur_agent_in_all_agent = [all_agent_id_list.index(cur_agent) for cur_agent in cur_agent_id_list] # indexing current agent in `all_agent_id_list`
375
+
376
+ pred_corners_list = [np.array(all_agent_corners_list[cur_in_all_ind], dtype=np.float64)
377
+ for cur_in_all_ind in cur_agent_in_all_agent]
378
+ uncertainty_list = [np.array(all_agent_uncertainty_list[cur_in_all_ind], dtype=np.float64)
379
+ for cur_in_all_ind in cur_agent_in_all_agent]
380
+
381
+ if sum([len(pred_corners) for pred_corners in pred_corners_list]) != 0:
382
+ refined_pose = box_alignment_relative_sample_np(pred_corners_list,
383
+ cur_agnet_pose,
384
+ uncertainty_list=uncertainty_list,
385
+ **self.box_align_args)
386
+ cur_agnet_pose[:,[0,1,4]] = refined_pose
387
+
388
+ for i, cav_id in enumerate(cav_id_list):
389
+ lidar_pose_list[i] = cur_agnet_pose[i].tolist()
390
+ base_data_dict[cav_id]['params']['lidar_pose'] = cur_agnet_pose[i].tolist()
391
+
392
+
393
+
394
+ pairwise_t_matrix = \
395
+ get_pairwise_transformation_asymmetric(base_data_dict,
396
+ self.max_cav,
397
+ self.proj_first)
398
+
399
+ lidar_poses = np.array(lidar_pose_list).reshape(-1, 6) # [N_cav, 6]
400
+ lidar_poses_clean = np.array(lidar_pose_clean_list).reshape(-1, 6) # [N_cav, 6]
401
+
402
+ # merge preprocessed features from different cavs into the same dict
403
+ cav_num = len(cav_id_list)
404
+
405
+ # heterogeneous
406
+ if self.heterogeneous:
407
+ lidar_agent, camera_agent = self.selector.select_agent(idx)
408
+ lidar_agent = lidar_agent[:cav_num]
409
+ processed_data_dict['ego'].update({"lidar_agent": lidar_agent})
410
+
411
+ for _i, cav_id in enumerate(cav_id_list):
412
+ selected_cav_base = base_data_dict[cav_id]
413
+
414
+ # dynamic object center generator! for heterogeneous input
415
+ if (not self.visualize) and self.heterogeneous and lidar_agent[_i]:
416
+ self.generate_object_center = self.generate_object_center_lidar
417
+ elif (not self.visualize) and self.heterogeneous and (not lidar_agent[_i]):
418
+ self.generate_object_center = self.generate_object_center_camera
419
+
420
+ selected_cav_processed = self.get_item_single_car(
421
+ selected_cav_base,
422
+ ego_cav_base,
423
+ tpe,
424
+ cav_id,
425
+ extra_source!=None)
426
+
427
+ if True: #extra_source==None:
428
+ object_stack.append(selected_cav_processed['object_bbx_center'])
429
+ object_id_stack += selected_cav_processed['object_ids']
430
+ if tpe == 'all':
431
+ if self.load_lidar_file:
432
+ processed_features.append(
433
+ selected_cav_processed['processed_features'])
434
+ if self.load_camera_file:
435
+ agents_image_inputs.append(
436
+ selected_cav_processed['image_inputs'])
437
+
438
+ if self.visualize or self.kd_flag:
439
+ projected_lidar_stack.append(
440
+ selected_cav_processed['projected_lidar'])
441
+
442
+ if True: #self.supervise_single and extra_source==None:
443
+ single_label_list.append(selected_cav_processed['single_label_dict'])
444
+ single_object_bbx_center_list.append(selected_cav_processed['single_object_bbx_center'])
445
+ single_object_bbx_mask_list.append(selected_cav_processed['single_object_bbx_mask'])
446
+
447
+ # generate single view GT label
448
+ if True: # self.supervise_single and extra_source==None:
449
+ single_label_dicts = {}
450
+ if tpe == 'all':
451
+ # unused label
452
+ if False:
453
+ single_label_dicts = self.post_processor.collate_batch(single_label_list)
454
+ single_object_bbx_center = torch.from_numpy(np.array(single_object_bbx_center_list))
455
+ single_object_bbx_mask = torch.from_numpy(np.array(single_object_bbx_mask_list))
456
+ processed_data_dict['ego'].update({
457
+ "single_label_dict_torch": single_label_dicts,
458
+ "single_object_bbx_center_torch": single_object_bbx_center,
459
+ "single_object_bbx_mask_torch": single_object_bbx_mask,
460
+ })
461
+
462
+ if self.kd_flag:
463
+ stack_lidar_np = np.vstack(projected_lidar_stack)
464
+ stack_lidar_np = mask_points_by_range(stack_lidar_np,
465
+ self.params['preprocess'][
466
+ 'cav_lidar_range'])
467
+ stack_feature_processed = self.pre_processor.preprocess(stack_lidar_np)
468
+ processed_data_dict['ego'].update({'teacher_processed_lidar':
469
+ stack_feature_processed})
470
+
471
+ if True: # extra_source is None:
472
+ # exclude all repetitive objects
473
+ unique_indices = \
474
+ [object_id_stack.index(x) for x in set(object_id_stack)]
475
+ object_stack = np.vstack(object_stack)
476
+ object_stack = object_stack[unique_indices]
477
+
478
+ # make sure the number of bounding boxes is the same across all frames (pad to max_num)
479
+ object_bbx_center = \
480
+ np.zeros((self.params['postprocess']['max_num'], 7))
481
+ mask = np.zeros(self.params['postprocess']['max_num'])
482
+ object_bbx_center[:object_stack.shape[0], :] = object_stack
483
+ mask[:object_stack.shape[0]] = 1
484
+
485
+ processed_data_dict['ego'].update(
486
+ {'object_bbx_center': object_bbx_center, # (100,7)
487
+ 'object_bbx_mask': mask, # (100,)
488
+ 'object_ids': [object_id_stack[i] for i in unique_indices],
489
+ }
490
+ )
491
+
492
+ # generate targets label
493
+ label_dict = {}
494
+ if tpe == 'all':
495
+ # unused label
496
+ if False:
497
+ label_dict = \
498
+ self.post_processor.generate_label(
499
+ gt_box_center=object_bbx_center,
500
+ anchors=self.anchor_box,
501
+ mask=mask)
502
+
503
+ processed_data_dict['ego'].update(
504
+ {
505
+ 'anchor_box': self.anchor_box,
506
+ 'label_dict': label_dict,
507
+ 'cav_num': cav_num,
508
+ 'pairwise_t_matrix': pairwise_t_matrix,
509
+ 'lidar_poses_clean': lidar_poses_clean,
510
+ 'lidar_poses': lidar_poses})
511
+
512
+ if tpe == 'all':
513
+ if self.load_lidar_file:
514
+ merged_feature_dict = merge_features_to_dict(processed_features)
515
+ processed_data_dict['ego'].update({'processed_lidar': merged_feature_dict})
516
+ if self.load_camera_file:
517
+ merged_image_inputs_dict = merge_features_to_dict(agents_image_inputs, merge='stack')
518
+ processed_data_dict['ego'].update({'image_inputs': merged_image_inputs_dict})
519
+
520
+ if self.visualize:
521
+ processed_data_dict['ego'].update({'origin_lidar':
522
+ # projected_lidar_stack})
523
+ np.vstack(
524
+ projected_lidar_stack)})
525
+ processed_data_dict['ego'].update({'lidar_len': [len(projected_lidar_stack[i]) for i in range(len(projected_lidar_stack))]})
526
+
527
+
528
+ processed_data_dict['ego'].update({'sample_idx': idx,
529
+ 'cav_id_list': cav_id_list})
530
+
531
+ img_front_list = []
532
+ img_left_list = []
533
+ img_right_list = []
534
+ BEV_list = []
535
+
536
+ if self.visualize:
537
+ for car_id in base_data_dict:
538
+ if not base_data_dict[car_id]['ego']:
539
+ continue
540
+ if 'rgb_front' in base_data_dict[car_id] and 'rgb_left' in base_data_dict[car_id] and 'rgb_right' in base_data_dict[car_id] and 'BEV' in base_data_dict[car_id] :
541
+ img_front_list.append(base_data_dict[car_id]['rgb_front'])
542
+ img_left_list.append(base_data_dict[car_id]['rgb_left'])
543
+ img_right_list.append(base_data_dict[car_id]['rgb_right'])
544
+ BEV_list.append(base_data_dict[car_id]['BEV'])
545
+ processed_data_dict['ego'].update({'img_front': img_front_list,
546
+ 'img_left': img_left_list,
547
+ 'img_right': img_right_list,
548
+ 'BEV': BEV_list})
549
+ processed_data_dict['ego'].update({'scene_dict': base_data_dict['car_0']['scene_dict'],
550
+ 'frame_id': base_data_dict['car_0']['frame_id'],
551
+ })
552
+
553
+
554
+ # TODO: LSS debug
555
+ processed_data_dict['ego'].update({"det_data": base_data_dict['car_0']['det_data']})
556
+ detmap_pose_list = []
557
+ for car_id in base_data_dict:
558
+ detmap_pose_list.append(base_data_dict[car_id]['detmap_pose'])
559
+ detmap_pose_list = torch.from_numpy(np.array(detmap_pose_list))
560
+ processed_data_dict['ego'].update({"detmap_pose": detmap_pose_list})
561
+ ##
562
+
563
+ return processed_data_dict
564
+
565
+
566
+ def collate_batch_train(self, batch, online_eval_only=False):
567
+ # Intermediate fusion is different from the other two
568
+ output_dict = {'ego': {}}
569
+
570
+ object_bbx_center = []
571
+ object_bbx_mask = []
572
+ object_ids = []
573
+ processed_lidar_list = []
574
+ image_inputs_list = []
575
+ # used to record different scenario
576
+ record_len = []
577
+ label_dict_list = []
578
+ lidar_pose_list = []
579
+ origin_lidar = []
580
+ lidar_len = []
581
+ lidar_pose_clean_list = []
582
+
583
+ # heterogeneous
584
+ lidar_agent_list = []
585
+
586
+ # pairwise transformation matrix
587
+ pairwise_t_matrix_list = []
588
+
589
+ # disconet
590
+ teacher_processed_lidar_list = []
591
+
592
+ # image
593
+ img_front = []
594
+ img_left = []
595
+ img_right = []
596
+ BEV = []
597
+
598
+ dict_list = []
599
+
600
+ # TODO: LSS debug
601
+ det_data = []
602
+ detmap_pose = []
603
+
604
+ ### 2022.10.10 single gt ####
605
+ if self.supervise_single:
606
+ pos_equal_one_single = []
607
+ neg_equal_one_single = []
608
+ targets_single = []
609
+ object_bbx_center_single = []
610
+ object_bbx_mask_single = []
611
+
612
+ for i in range(len(batch)):
613
+ ego_dict = batch[i]['ego']
614
+ det_data.append(torch.from_numpy(ego_dict['det_data']).unsqueeze(0))
615
+ detmap_pose.append(ego_dict['detmap_pose'])
616
+ if not online_eval_only:
617
+ object_bbx_center.append(ego_dict['object_bbx_center'])
618
+ object_bbx_mask.append(ego_dict['object_bbx_mask'])
619
+ object_ids.append(ego_dict['object_ids'])
620
+ else:
621
+ object_ids.append(None)
622
+ lidar_pose_list.append(ego_dict['lidar_poses']) # ego_dict['lidar_pose'] is np.ndarray [N,6]
623
+ lidar_pose_clean_list.append(ego_dict['lidar_poses_clean'])
624
+ if self.load_lidar_file:
625
+ processed_lidar_list.append(ego_dict['processed_lidar'])
626
+ if self.load_camera_file:
627
+ image_inputs_list.append(ego_dict['image_inputs']) # different cav_num, ego_dict['image_inputs'] is dict.
628
+
629
+ record_len.append(ego_dict['cav_num'])
630
+ label_dict_list.append(ego_dict['label_dict'])
631
+ pairwise_t_matrix_list.append(ego_dict['pairwise_t_matrix'])
632
+
633
+ dict_list.append([ego_dict['scene_dict'], ego_dict['frame_id']])
634
+
635
+ if self.visualize:
636
+ origin_lidar.append(ego_dict['origin_lidar'])
637
+ lidar_len.append(ego_dict['lidar_len'])
638
+ if len(ego_dict['img_front']) > 0 and len(ego_dict['img_right']) > 0 and len(ego_dict['img_left']) > 0 and len(ego_dict['BEV']) > 0:
639
+ img_front.append(ego_dict['img_front'][0])
640
+ img_left.append(ego_dict['img_left'][0])
641
+ img_right.append(ego_dict['img_right'][0])
642
+ BEV.append(ego_dict['BEV'][0])
643
+
644
+
645
+ if self.kd_flag:
646
+ teacher_processed_lidar_list.append(ego_dict['teacher_processed_lidar'])
647
+
648
+ ### 2022.10.10 single gt ####
649
+ if self.supervise_single and not online_eval_only:
650
+ # unused label
651
+ if False:
652
+ pos_equal_one_single.append(ego_dict['single_label_dict_torch']['pos_equal_one'])
653
+ neg_equal_one_single.append(ego_dict['single_label_dict_torch']['neg_equal_one'])
654
+ targets_single.append(ego_dict['single_label_dict_torch']['targets'])
655
+ object_bbx_center_single.append(ego_dict['single_object_bbx_center_torch'])
656
+ object_bbx_mask_single.append(ego_dict['single_object_bbx_mask_torch'])
657
+
658
+ # heterogeneous
659
+ if self.heterogeneous:
660
+ lidar_agent_list.append(ego_dict['lidar_agent'])
661
+
662
+ # convert to numpy, (B, max_num, 7)
663
+ if not online_eval_only:
664
+ object_bbx_center = torch.from_numpy(np.array(object_bbx_center))
665
+ object_bbx_mask = torch.from_numpy(np.array(object_bbx_mask))
666
+ else:
667
+ object_bbx_center = None
668
+ object_bbx_mask = None
669
+
670
+ if self.load_lidar_file:
671
+ merged_feature_dict = merge_features_to_dict(processed_lidar_list)
672
+
673
+ if self.heterogeneous:
674
+ lidar_agent = np.concatenate(lidar_agent_list)
675
+ lidar_agent_idx = lidar_agent.nonzero()[0].tolist()
676
+ for k, v in merged_feature_dict.items(): # 'voxel_features' 'voxel_num_points' 'voxel_coords'
677
+ merged_feature_dict[k] = [v[index] for index in lidar_agent_idx]
678
+
679
+ if not self.heterogeneous or (self.heterogeneous and sum(lidar_agent) != 0):
680
+ processed_lidar_torch_dict = \
681
+ self.pre_processor.collate_batch(merged_feature_dict)
682
+ output_dict['ego'].update({'processed_lidar': processed_lidar_torch_dict})
683
+
684
+ if self.load_camera_file:
685
+ merged_image_inputs_dict = merge_features_to_dict(image_inputs_list, merge='cat')
686
+
687
+ if self.heterogeneous:
688
+ lidar_agent = np.concatenate(lidar_agent_list)
689
+ camera_agent = 1 - lidar_agent
690
+ camera_agent_idx = camera_agent.nonzero()[0].tolist()
691
+ if sum(camera_agent) != 0:
692
+ for k, v in merged_image_inputs_dict.items(): # 'imgs' 'rots' 'trans' ...
693
+ merged_image_inputs_dict[k] = torch.stack([v[index] for index in camera_agent_idx])
694
+
695
+ if not self.heterogeneous or (self.heterogeneous and sum(camera_agent) != 0):
696
+ output_dict['ego'].update({'image_inputs': merged_image_inputs_dict})
697
+
698
+ record_len = torch.from_numpy(np.array(record_len, dtype=int))
699
+ lidar_pose = torch.from_numpy(np.concatenate(lidar_pose_list, axis=0))
700
+ lidar_pose_clean = torch.from_numpy(np.concatenate(lidar_pose_clean_list, axis=0))
701
+ # unused label
702
+ label_torch_dict = {}
703
+ if False:
704
+ label_torch_dict = \
705
+ self.post_processor.collate_batch(label_dict_list)
706
+
707
+ # for centerpoint
708
+ label_torch_dict.update({'object_bbx_center': object_bbx_center,
709
+ 'object_bbx_mask': object_bbx_mask})
710
+
711
+ # (B, max_cav)
712
+ pairwise_t_matrix = torch.from_numpy(np.array(pairwise_t_matrix_list))
713
+
714
+ # add pairwise_t_matrix to label dict
715
+ label_torch_dict['pairwise_t_matrix'] = pairwise_t_matrix
716
+ label_torch_dict['record_len'] = record_len
717
+
718
+
719
+ # object id is only used during inference, where batch size is 1.
720
+ # so here we only get the first element.
721
+ output_dict['ego'].update({'object_bbx_center': object_bbx_center,
722
+ 'object_bbx_mask': object_bbx_mask,
723
+ 'record_len': record_len,
724
+ 'label_dict': label_torch_dict,
725
+ 'object_ids': object_ids[0],
726
+ 'pairwise_t_matrix': pairwise_t_matrix,
727
+ 'lidar_pose_clean': lidar_pose_clean,
728
+ 'lidar_pose': lidar_pose,
729
+ 'anchor_box': self.anchor_box_torch})
730
+
731
+
732
+ output_dict['ego'].update({'dict_list': dict_list})
733
+
734
+ if self.visualize:
735
+ origin_lidar = torch.from_numpy(np.array(origin_lidar))
736
+ output_dict['ego'].update({'origin_lidar': origin_lidar})
737
+ lidar_len = np.array(lidar_len)
738
+ output_dict['ego'].update({'lidar_len': lidar_len})
739
+ output_dict['ego'].update({'img_front': img_front})
740
+ output_dict['ego'].update({'img_right': img_right})
741
+ output_dict['ego'].update({'img_left': img_left})
742
+ output_dict['ego'].update({'BEV': BEV})
743
+
744
+ if self.kd_flag:
745
+ teacher_processed_lidar_torch_dict = \
746
+ self.pre_processor.collate_batch(teacher_processed_lidar_list)
747
+ output_dict['ego'].update({'teacher_processed_lidar':teacher_processed_lidar_torch_dict})
748
+
749
+
750
+ if self.supervise_single and not online_eval_only:
751
+ output_dict['ego'].update({
752
+ "label_dict_single":{
753
+ # for centerpoint
754
+ "object_bbx_center_single": torch.cat(object_bbx_center_single, dim=0),
755
+ "object_bbx_mask_single": torch.cat(object_bbx_mask_single, dim=0)
756
+ },
757
+ "object_bbx_center_single": torch.cat(object_bbx_center_single, dim=0),
758
+ "object_bbx_mask_single": torch.cat(object_bbx_mask_single, dim=0)
759
+ })
760
+
761
+ if self.heterogeneous:
762
+ output_dict['ego'].update({
763
+ "lidar_agent_record": torch.from_numpy(np.concatenate(lidar_agent_list)) # [0,1,1,0,1...]
764
+ })
765
+
766
+ # TODO: LSS debug
767
+ det_data = torch.cat(det_data, dim=0)
768
+ detmap_pose = torch.cat(detmap_pose, dim=0)
769
+ output_dict['ego'].update({'detmap_pose': detmap_pose})
770
+
771
+ output_dict['ego']['label_dict'].update({
772
+ 'det_data': det_data})
773
+ return output_dict
774
+
775
+ def collate_batch_test(self, batch, online_eval_only=False):
776
+
777
+ self.online_eval_only = online_eval_only
778
+ assert len(batch) <= 1, "Batch size 1 is required during testing!"
779
+ output_dict = self.collate_batch_train(batch, online_eval_only)
780
+ if output_dict is None:
781
+ return None
782
+
783
+ # check if anchor box in the batch
784
+ if batch[0]['ego']['anchor_box'] is not None:
785
+ output_dict['ego'].update({'anchor_box':
786
+ self.anchor_box_torch})
787
+
788
+ # save the transformation matrix (4, 4) to ego vehicle
789
+ # transformation is only used in post process (no use.)
790
+ # we all predict boxes in ego coord.
791
+ transformation_matrix_torch = \
792
+ torch.from_numpy(np.identity(4)).float()
793
+ transformation_matrix_clean_torch = \
794
+ torch.from_numpy(np.identity(4)).float()
795
+
796
+ output_dict['ego'].update({'transformation_matrix':
797
+ transformation_matrix_torch,
798
+ 'transformation_matrix_clean':
799
+ transformation_matrix_clean_torch,})
800
+
801
+ output_dict['ego'].update({
802
+ "sample_idx": batch[0]['ego']['sample_idx'],
803
+ "cav_id_list": batch[0]['ego']['cav_id_list']
804
+ })
805
+
806
+ return output_dict
807
+
808
+
809
+ def post_process(self, data_dict, output_dict):
810
+ """
811
+ Process the outputs of the model to 2D/3D bounding box.
812
+
813
+ Parameters
814
+ ----------
815
+ data_dict : dict
816
+ The dictionary containing the origin input data of model.
817
+
818
+ output_dict :dict
819
+ The dictionary containing the output of the model.
820
+
821
+ Returns
822
+ -------
823
+ pred_box_tensor : torch.Tensor
824
+ The tensor of prediction bounding box after NMS.
825
+ gt_box_tensor : torch.Tensor
826
+ The tensor of gt bounding box.
827
+ """
828
+ pred_box_tensor, pred_score = \
829
+ self.post_processor.post_process(data_dict, output_dict)
830
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict)
831
+
832
+ return pred_box_tensor, pred_score, gt_box_tensor
833
+
834
+ def post_process_multiclass(self, data_dict, output_dict, online_eval_only=False):
835
+ """
836
+ Process the outputs of the model to 2D/3D bounding box.
837
+
838
+ Parameters
839
+ ----------
840
+ data_dict : dict
841
+ The dictionary containing the origin input data of model.
842
+
843
+ output_dict :dict
844
+ The dictionary containing the output of the model.
845
+
846
+ Returns
847
+ -------
848
+ pred_box_tensor : torch.Tensor
849
+ The tensor of prediction bounding box after NMS.
850
+ gt_box_tensor : torch.Tensor
851
+ The tensor of gt bounding box.
852
+ """
853
+
854
+ if not online_eval_only:
855
+ online_eval_only = self.online_eval_only
856
+
857
+ num_class = output_dict['ego']['cls_preds'].shape[1]
858
+
859
+
860
+ pred_box_tensor_list = []
861
+ pred_score_list = []
862
+ gt_box_tensor_list = []
863
+
864
+ num_list = [0,1,3]
865
+
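+ # assumption: these indices map the detection classes back to the type ids
+ # stored in object_ids; the exact mapping is defined by the dataset config.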
866
+ for i in range(num_class):
867
+ data_dict_single = copy.deepcopy(data_dict)
868
+ output_dict_single = copy.deepcopy(output_dict)
869
+ if not online_eval_only:
870
+ data_dict_single['ego']['object_bbx_center'] = data_dict['ego']['object_bbx_center'][:,i,:,:]
871
+ data_dict_single['ego']['object_bbx_mask'] = data_dict['ego']['object_bbx_mask'][:,i,:]
872
+ data_dict_single['ego']['object_ids'] = data_dict['ego']['object_ids'][num_list[i]]
873
+
874
+ output_dict_single['ego']['cls_preds'] = output_dict['ego']['cls_preds'][:,i:i+1,:,:]
875
+ output_dict_single['ego']['reg_preds'] = output_dict['ego']['reg_preds_multiclass'][:,i,:,:]
876
+
877
+ pred_box_tensor, pred_score = \
878
+ self.post_processor.post_process(data_dict_single, output_dict_single)
879
+ if not online_eval_only:
880
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict_single)
881
+ else:
882
+ gt_box_tensor = None
883
+
884
+ pred_box_tensor_list.append(pred_box_tensor)
885
+ pred_score_list.append(pred_score)
886
+ gt_box_tensor_list.append(gt_box_tensor)
887
+
888
+ return pred_box_tensor_list, pred_score_list, gt_box_tensor_list
889
+
890
+ return IntermediatemulticlassFusionDataset
891
+
892
+
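+ # Minimal usage sketch (not part of the original file; the base dataset class
+ # and params object are assumptions taken from the usual OpenCOOD setup):
+ #     MulticlassDataset = getIntermediatemulticlassFusionDataset(SomeBaseDataset)
+ #     dataset = MulticlassDataset(params, visualize=False, train=True)
+ #     batch = dataset.collate_batch_test([dataset[0]], online_eval_only=False)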
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/late_fusion_dataset.py ADDED
@@ -0,0 +1,564 @@
1
+ # late fusion dataset
2
+ import random
3
+ import math
4
+ from collections import OrderedDict
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import copy
9
+ from icecream import ic
10
+ from PIL import Image
11
+ import pickle as pkl
12
+ from opencood.utils import box_utils as box_utils
13
+ from opencood.data_utils.pre_processor import build_preprocessor
14
+ from opencood.data_utils.post_processor import build_postprocessor
15
+ from opencood.utils.camera_utils import (
16
+ sample_augmentation,
17
+ img_transform,
18
+ normalize_img,
19
+ img_to_tensor,
20
+ )
21
+ from opencood.data_utils.augmentor.data_augmentor import DataAugmentor
22
+ from opencood.utils.transformation_utils import x1_to_x2
23
+ from opencood.utils.pose_utils import add_noise_data_dict
24
+ from opencood.utils.pcd_utils import (
25
+ mask_points_by_range,
26
+ mask_ego_points,
27
+ shuffle_points,
28
+ downsample_lidar_minimum,
29
+ )
30
+
31
+
32
+ def getLateFusionDataset(cls):
33
+ """
34
+ cls: the Basedataset.
35
+ """
36
+ class LateFusionDataset(cls):
37
+ def __init__(self, params, visualize, train=True):
38
+ super().__init__(params, visualize, train)
39
+ self.anchor_box = self.post_processor.generate_anchor_box()
40
+ self.anchor_box_torch = torch.from_numpy(self.anchor_box)
41
+
42
+ self.heterogeneous = False
43
+ if 'heter' in params:
44
+ self.heterogeneous = True
45
+
46
+ def __getitem__(self, idx):
47
+ base_data_dict = self.retrieve_base_data(idx)
48
+ if self.train:
49
+ reformat_data_dict = self.get_item_train(base_data_dict)
50
+ else:
51
+ reformat_data_dict = self.get_item_test(base_data_dict, idx)
52
+
53
+ return reformat_data_dict
54
+
55
+ def get_item_train(self, base_data_dict):
56
+ processed_data_dict = OrderedDict()
57
+ base_data_dict = add_noise_data_dict(
58
+ base_data_dict, self.params["noise_setting"]
59
+ )
60
+ # during training, we return a random cav's data
61
+ # only one vehicle is in processed_data_dict
62
+ if not self.visualize:
63
+ selected_cav_id, selected_cav_base = random.choice(
64
+ list(base_data_dict.items())
65
+ )
66
+ else:
67
+ selected_cav_id, selected_cav_base = list(base_data_dict.items())[0]
68
+
69
+ selected_cav_processed = self.get_item_single_car(selected_cav_base)
70
+ processed_data_dict.update({"ego": selected_cav_processed})
71
+
72
+ return processed_data_dict
73
+
74
+
75
+ def get_item_test(self, base_data_dict, idx):
76
+ """
77
+ processed_data_dict.keys() = ['ego', "650", "659", ...]
78
+ """
79
+ base_data_dict = add_noise_data_dict(base_data_dict,self.params['noise_setting'])
80
+
81
+ processed_data_dict = OrderedDict()
82
+ ego_id = -1
83
+ ego_lidar_pose = []
84
+ cav_id_list = []
85
+ lidar_pose_list = []
86
+
87
+ # first find the ego vehicle's lidar pose
88
+ for cav_id, cav_content in base_data_dict.items():
89
+ if cav_content['ego']:
90
+ ego_id = cav_id
91
+ ego_lidar_pose = cav_content['params']['lidar_pose']
92
+ ego_lidar_pose_clean = cav_content['params']['lidar_pose_clean']
93
+ break
94
+
95
+ assert ego_id != -1
96
+ assert len(ego_lidar_pose) > 0
97
+
98
+ # loop over all CAVs to process information
99
+ for cav_id, selected_cav_base in base_data_dict.items():
100
+ distance = math.sqrt(
101
+     (selected_cav_base['params']['lidar_pose'][0] - ego_lidar_pose[0]) ** 2 +
102
+     (selected_cav_base['params']['lidar_pose'][1] - ego_lidar_pose[1]) ** 2)
106
+ if distance > self.params['comm_range']:
107
+ continue
108
+ cav_id_list.append(cav_id)
109
+ lidar_pose_list.append(selected_cav_base['params']['lidar_pose'])
110
+
111
+ cav_id_list_newname = []
112
+ for cav_id in cav_id_list:
113
+ selected_cav_base = base_data_dict[cav_id]
114
+ # find the transformation matrix from current cav to ego.
115
+ cav_lidar_pose = selected_cav_base['params']['lidar_pose']
116
+ transformation_matrix = x1_to_x2(cav_lidar_pose, ego_lidar_pose)
117
+ cav_lidar_pose_clean = selected_cav_base['params']['lidar_pose_clean']
118
+ transformation_matrix_clean = x1_to_x2(cav_lidar_pose_clean, ego_lidar_pose_clean)
119
+
120
+ selected_cav_processed = \
121
+ self.get_item_single_car(selected_cav_base)
122
+ selected_cav_processed.update({'transformation_matrix': transformation_matrix,
123
+ 'transformation_matrix_clean': transformation_matrix_clean})
124
+ update_cav = "ego" if cav_id == ego_id else cav_id
125
+ processed_data_dict.update({update_cav: selected_cav_processed})
126
+ cav_id_list_newname.append(update_cav)
127
+
128
+ # heterogeneous
129
+ if self.heterogeneous:
130
+ processed_data_dict['ego']['idx'] = idx
131
+ processed_data_dict['ego']['cav_list'] = cav_id_list_newname
132
+
133
+ return processed_data_dict
134
+
135
+
136
+ def get_item_single_car(self, selected_cav_base):
137
+ """
138
+ Process a single CAV's information for the train/test pipeline.
139
+
140
+
141
+ Parameters
142
+ ----------
143
+ selected_cav_base : dict
144
+ The dictionary contains a single CAV's raw information.
145
+ including 'params', 'camera_data'
146
+
147
+ Returns
148
+ -------
149
+ selected_cav_processed : dict
150
+ The dictionary contains the cav's processed information.
151
+ """
152
+ selected_cav_processed = {}
153
+
154
+ # label
155
+ object_bbx_center, object_bbx_mask, object_ids = self.generate_object_center_single(
156
+ [selected_cav_base], selected_cav_base["params"]["lidar_pose_clean"]
157
+ )
158
+
159
+ # lidar
160
+ if self.load_lidar_file or self.visualize:
161
+ lidar_np = selected_cav_base['lidar_np']
162
+ lidar_np = shuffle_points(lidar_np)
163
+ lidar_np = mask_points_by_range(lidar_np,
164
+ self.params['preprocess'][
165
+ 'cav_lidar_range'])
166
+ # remove points that hit ego vehicle
167
+ lidar_np = mask_ego_points(lidar_np)
168
+
169
+ # data augmentation; this matters a lot for single-agent training because of the lack of data diversity.
170
+ # it only works for the lidar modality during training.
171
+ if not self.heterogeneous:
172
+ lidar_np, object_bbx_center, object_bbx_mask = \
173
+ self.augment(lidar_np, object_bbx_center, object_bbx_mask)
174
+
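+ # the augmentor typically applies random flipping / rotation / scaling jointly
+ # to the points and boxes; the exact transforms depend on the data_augment config.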
175
+ lidar_dict = self.pre_processor.preprocess(lidar_np)
176
+ selected_cav_processed.update({'processed_lidar': lidar_dict})
177
+
178
+
179
+
180
+
181
+ if self.visualize:
182
+ selected_cav_processed.update({'origin_lidar': lidar_np})
183
+
184
+ # camera
185
+ if self.load_camera_file:
186
+ # adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/data.py
187
+ camera_data_list = selected_cav_base["camera_data"]
188
+
189
+ params = selected_cav_base["params"]
190
+ imgs = []
191
+ rots = []
192
+ trans = []
193
+ intrins = []
194
+ extrinsics = [] # cam_to_lidar
195
+ post_rots = []
196
+ post_trans = []
197
+
198
+ for idx, img in enumerate(camera_data_list):
199
+ camera_to_lidar, camera_intrinsic = self.get_ext_int(params, idx)
200
+
201
+ intrin = torch.from_numpy(camera_intrinsic)
202
+ rot = torch.from_numpy(
203
+ camera_to_lidar[:3, :3]
204
+ ) # R_wc, we consider world-coord is the lidar-coord
205
+ tran = torch.from_numpy(camera_to_lidar[:3, 3]) # T_wc
206
+
207
+ post_rot = torch.eye(2)
208
+ post_tran = torch.zeros(2)
209
+
210
+ img_src = [img]
211
+
212
+ # depth
213
+ if self.load_depth_file:
214
+ depth_img = selected_cav_base["depth_data"][idx]
215
+ img_src.append(depth_img)
216
+ else:
217
+ depth_img = None
218
+
219
+ # data augmentation
220
+ resize, resize_dims, crop, flip, rotate = sample_augmentation(
221
+ self.data_aug_conf, self.train
222
+ )
223
+ img_src, post_rot2, post_tran2 = img_transform(
224
+ img_src,
225
+ post_rot,
226
+ post_tran,
227
+ resize=resize,
228
+ resize_dims=resize_dims,
229
+ crop=crop,
230
+ flip=flip,
231
+ rotate=rotate,
232
+ )
233
+ # for convenience, make augmentation matrices 3x3
234
+ post_tran = torch.zeros(3)
235
+ post_rot = torch.eye(3)
236
+ post_tran[:2] = post_tran2
237
+ post_rot[:2, :2] = post_rot2
238
+
239
+ img_src[0] = normalize_img(img_src[0])
240
+ if self.load_depth_file:
241
+ img_src[1] = img_to_tensor(img_src[1]) * 255
242
+
243
+ imgs.append(torch.cat(img_src, dim=0))
244
+ intrins.append(intrin)
245
+ extrinsics.append(torch.from_numpy(camera_to_lidar))
246
+ rots.append(rot)
247
+ trans.append(tran)
248
+ post_rots.append(post_rot)
249
+ post_trans.append(post_tran)
250
+
251
+ selected_cav_processed.update(
252
+ {
253
+ "image_inputs":
254
+ {
255
+ "imgs": torch.stack(imgs), # [N, 3or4, H, W]
256
+ "intrins": torch.stack(intrins),
257
+ "extrinsics": torch.stack(extrinsics),
258
+ "rots": torch.stack(rots),
259
+ "trans": torch.stack(trans),
260
+ "post_rots": torch.stack(post_rots),
261
+ "post_trans": torch.stack(post_trans),
262
+ }
263
+ }
264
+ )
265
+
266
+
267
+ selected_cav_processed.update(
268
+ {
269
+ "object_bbx_center": object_bbx_center,
270
+ "object_bbx_mask": object_bbx_mask,
271
+ "object_ids": object_ids,
272
+ }
273
+ )
274
+
275
+ # generate targets label
276
+ label_dict = self.post_processor.generate_label(
277
+ gt_box_center=object_bbx_center, anchors=self.anchor_box, mask=object_bbx_mask
278
+ )
279
+ selected_cav_processed.update({"label_dict": label_dict})
280
+
281
+ return selected_cav_processed
282
+
283
+
284
+ def collate_batch_train(self, batch):
285
+ """
286
+ Customized collate function for pytorch dataloader during training
287
+ for early and late fusion dataset.
288
+
289
+ Parameters
290
+ ----------
291
+ batch : dict
292
+
293
+ Returns
294
+ -------
295
+ batch : dict
296
+ Reformatted batch.
297
+ """
298
+ # during training, we only care about ego.
299
+ output_dict = {'ego': {}}
300
+
301
+ object_bbx_center = []
302
+ object_bbx_mask = []
303
+ processed_lidar_list = []
304
+ label_dict_list = []
305
+ origin_lidar = []
306
+
307
+ for i in range(len(batch)):
308
+ ego_dict = batch[i]['ego']
309
+ object_bbx_center.append(ego_dict['object_bbx_center'])
310
+ object_bbx_mask.append(ego_dict['object_bbx_mask'])
311
+ label_dict_list.append(ego_dict['label_dict'])
312
+
313
+ if self.visualize:
314
+ origin_lidar.append(ego_dict['origin_lidar'])
315
+
316
+ # convert to numpy, (B, max_num, 7)
317
+ object_bbx_center = torch.from_numpy(np.array(object_bbx_center))
318
+ object_bbx_mask = torch.from_numpy(np.array(object_bbx_mask))
319
+ label_torch_dict = \
320
+ self.post_processor.collate_batch(label_dict_list)
321
+
322
+ # for centerpoint
323
+ label_torch_dict.update({'object_bbx_center': object_bbx_center,
324
+ 'object_bbx_mask': object_bbx_mask})
325
+
326
+ output_dict['ego'].update({'object_bbx_center': object_bbx_center,
327
+ 'object_bbx_mask': object_bbx_mask,
328
+ 'anchor_box': torch.from_numpy(self.anchor_box),
329
+ 'label_dict': label_torch_dict})
330
+ if self.visualize:
331
+ origin_lidar = \
332
+ np.array(downsample_lidar_minimum(pcd_np_list=origin_lidar))
333
+ origin_lidar = torch.from_numpy(origin_lidar)
334
+ output_dict['ego'].update({'origin_lidar': origin_lidar})
335
+
336
+ if self.load_lidar_file:
337
+ for i in range(len(batch)):
338
+ processed_lidar_list.append(batch[i]['ego']['processed_lidar'])
339
+ processed_lidar_torch_dict = \
340
+ self.pre_processor.collate_batch(processed_lidar_list)
341
+ output_dict['ego'].update({'processed_lidar': processed_lidar_torch_dict})
342
+
343
+ if self.load_camera_file:
344
+ # collate ego camera information
345
+ imgs_batch = []
346
+ rots_batch = []
347
+ trans_batch = []
348
+ intrins_batch = []
349
+ extrinsics_batch = []
350
+ post_trans_batch = []
351
+ post_rots_batch = []
352
+ for i in range(len(batch)):
353
+ ego_dict = batch[i]["ego"]["image_inputs"]
354
+ imgs_batch.append(ego_dict["imgs"])
355
+ rots_batch.append(ego_dict["rots"])
356
+ trans_batch.append(ego_dict["trans"])
357
+ intrins_batch.append(ego_dict["intrins"])
358
+ extrinsics_batch.append(ego_dict["extrinsics"])
359
+ post_trans_batch.append(ego_dict["post_trans"])
360
+ post_rots_batch.append(ego_dict["post_rots"])
361
+
362
+ output_dict["ego"].update({
363
+ "image_inputs":
364
+ {
365
+ "imgs": torch.stack(imgs_batch), # [B, N, C, H, W]
366
+ "rots": torch.stack(rots_batch),
367
+ "trans": torch.stack(trans_batch),
368
+ "intrins": torch.stack(intrins_batch),
369
+ "post_trans": torch.stack(post_trans_batch),
370
+ "post_rots": torch.stack(post_rots_batch),
371
+ }
372
+ }
373
+ )
374
+
375
+
376
+ return output_dict
377
+
378
+ def collate_batch_test(self, batch):
379
+ """
380
+ Customized collate function for pytorch dataloader during testing
381
+ for late fusion dataset.
382
+
383
+ Parameters
384
+ ----------
385
+ batch : dict
386
+
387
+ Returns
388
+ -------
389
+ batch : dict
390
+ Reformatted batch.
391
+ """
392
+ # currently, we only support batch size of 1 during testing
393
+ assert len(batch) <= 1, "Batch size 1 is required during testing!"
394
+ batch = batch[0]
395
+
396
+ output_dict = {}
397
+
398
+ # heterogeneous
399
+ if self.heterogeneous:
400
+ idx = batch['ego']['idx']
401
+ cav_list = batch['ego']['cav_list'] # ['ego', '650' ..]
402
+ cav_num = len(batch)
403
+ lidar_agent, camera_agent = self.selector.select_agent(idx)
404
+ lidar_agent = lidar_agent[:cav_num] # [1,0,0,1,0]
405
+ lidar_agent_idx = lidar_agent.nonzero()[0].tolist()
406
+ lidar_agent_cav_id = [cav_list[index] for index in lidar_agent_idx] # ['ego', ...]
407
+
408
+
409
+ # for late fusion, we also need to stack the lidar for better
410
+ # visualization
411
+ if self.visualize:
412
+ projected_lidar_list = []
413
+ origin_lidar = []
414
+
415
+ for cav_id, cav_content in batch.items():
416
+ output_dict.update({cav_id: {}})
417
+ # shape: (1, max_num, 7)
418
+ object_bbx_center = \
419
+ torch.from_numpy(np.array([cav_content['object_bbx_center']]))
420
+ object_bbx_mask = \
421
+ torch.from_numpy(np.array([cav_content['object_bbx_mask']]))
422
+ object_ids = cav_content['object_ids']
423
+
424
+ # the anchor box is the same for all bounding boxes usually, thus
425
+ # we don't need the batch dimension.
426
+ output_dict[cav_id].update(
427
+ {"anchor_box": self.anchor_box_torch}
428
+ )
429
+
430
+ transformation_matrix = cav_content['transformation_matrix']
431
+ if self.visualize:
432
+ origin_lidar = [cav_content['origin_lidar']]
433
+ if (self.params['only_vis_ego'] is False) or (cav_id=='ego'):
434
+ projected_lidar = copy.deepcopy(cav_content['origin_lidar'])
435
+ projected_lidar[:, :3] = \
436
+ box_utils.project_points_by_matrix_torch(
437
+ projected_lidar[:, :3],
438
+ transformation_matrix)
439
+ projected_lidar_list.append(projected_lidar)
440
+
441
+ if self.load_lidar_file:
442
+ # processed lidar dictionary
443
+ processed_lidar_torch_dict = \
444
+ self.pre_processor.collate_batch(
445
+ [cav_content['processed_lidar']])
446
+ output_dict[cav_id].update({'processed_lidar': processed_lidar_torch_dict})
447
+
448
+ if self.load_camera_file:
449
+ imgs_batch = [cav_content["image_inputs"]["imgs"]]
450
+ rots_batch = [cav_content["image_inputs"]["rots"]]
451
+ trans_batch = [cav_content["image_inputs"]["trans"]]
452
+ intrins_batch = [cav_content["image_inputs"]["intrins"]]
453
+ extrinsics_batch = [cav_content["image_inputs"]["extrinsics"]]
454
+ post_trans_batch = [cav_content["image_inputs"]["post_trans"]]
455
+ post_rots_batch = [cav_content["image_inputs"]["post_rots"]]
456
+
457
+ output_dict[cav_id].update({
458
+ "image_inputs":
459
+ {
460
+ "imgs": torch.stack(imgs_batch),
461
+ "rots": torch.stack(rots_batch),
462
+ "trans": torch.stack(trans_batch),
463
+ "intrins": torch.stack(intrins_batch),
464
+ "extrinsics": torch.stack(extrinsics_batch),
465
+ "post_trans": torch.stack(post_trans_batch),
466
+ "post_rots": torch.stack(post_rots_batch),
467
+ }
468
+ }
469
+ )
470
+
471
+ # heterogeneous
472
+ if self.heterogeneous:
473
+ if cav_id in lidar_agent_cav_id:
474
+ output_dict[cav_id].pop('image_inputs')
475
+ else:
476
+ output_dict[cav_id].pop('processed_lidar')
477
+
478
+ # label dictionary
479
+ label_torch_dict = \
480
+ self.post_processor.collate_batch([cav_content['label_dict']])
481
+
482
+ # for centerpoint
483
+ label_torch_dict.update({'object_bbx_center': object_bbx_center,
484
+ 'object_bbx_mask': object_bbx_mask})
485
+
486
+ # save the transformation matrix (4, 4) to ego vehicle
487
+ transformation_matrix_torch = \
488
+ torch.from_numpy(
489
+ np.array(cav_content['transformation_matrix'])).float()
490
+
491
+ # late fusion training, no noise
492
+ transformation_matrix_clean_torch = \
493
+ torch.from_numpy(
494
+ np.array(cav_content['transformation_matrix_clean'])).float()
495
+
496
+ output_dict[cav_id].update({'object_bbx_center': object_bbx_center,
497
+ 'object_bbx_mask': object_bbx_mask,
498
+ 'label_dict': label_torch_dict,
499
+ 'object_ids': object_ids,
500
+ 'transformation_matrix': transformation_matrix_torch,
501
+ 'transformation_matrix_clean': transformation_matrix_clean_torch})
502
+
503
+ if self.visualize:
504
+ origin_lidar = \
505
+ np.array(
506
+ downsample_lidar_minimum(pcd_np_list=origin_lidar))
507
+ origin_lidar = torch.from_numpy(origin_lidar)
508
+ output_dict[cav_id].update({'origin_lidar': origin_lidar})
509
+
510
+ if self.visualize:
511
+ projected_lidar_stack = [torch.from_numpy(
512
+ np.vstack(projected_lidar_list))]
513
+ output_dict['ego'].update({'origin_lidar': projected_lidar_stack})
514
+ # output_dict['ego'].update({'projected_lidar_list': projected_lidar_list})
515
+
516
+ return output_dict
517
+
518
+
519
+ def post_process(self, data_dict, output_dict):
520
+ """
521
+ Process the outputs of the model to 2D/3D bounding box.
522
+
523
+ Parameters
524
+ ----------
525
+ data_dict : dict
526
+ The dictionary containing the origin input data of model.
527
+
528
+ output_dict :dict
529
+ The dictionary containing the output of the model.
530
+
531
+ Returns
532
+ -------
533
+ pred_box_tensor : torch.Tensor
534
+ The tensor of prediction bounding box after NMS.
535
+ gt_box_tensor : torch.Tensor
536
+ The tensor of gt bounding box.
537
+ """
538
+ pred_box_tensor, pred_score = self.post_processor.post_process(
539
+ data_dict, output_dict
540
+ )
541
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict)
542
+
543
+ return pred_box_tensor, pred_score, gt_box_tensor
544
+
545
+ def post_process_no_fusion(self, data_dict, output_dict_ego):
546
+ data_dict_ego = OrderedDict()
547
+ data_dict_ego["ego"] = data_dict["ego"]
548
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict)
549
+
550
+ pred_box_tensor, pred_score = self.post_processor.post_process(
551
+ data_dict_ego, output_dict_ego
552
+ )
553
+ return pred_box_tensor, pred_score, gt_box_tensor
554
+
555
+ def post_process_no_fusion_uncertainty(self, data_dict, output_dict_ego):
556
+ data_dict_ego = OrderedDict()
557
+ data_dict_ego['ego'] = data_dict['ego']
558
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict)
559
+
560
+ pred_box_tensor, pred_score, uncertainty = \
561
+ self.post_processor.post_process(data_dict_ego, output_dict_ego, return_uncertainty=True)
562
+ return pred_box_tensor, pred_score, gt_box_tensor, uncertainty
563
+
564
+ return LateFusionDataset
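Note: getLateFusionDataset(cls) is a class factory; it builds the concrete dataset type at runtime by subclassing whatever base dataset it receives. Below is a minimal, self-contained sketch of that pattern; the _DummyBase class and the commented-out instantiation are illustrative assumptions, not this repository's actual API.

# Sketch of the class-factory pattern used by getLateFusionDataset (assumed names).
class _DummyBase:                      # stand-in for the project's real base dataset
    def __init__(self, params, visualize, train=True):
        self.params, self.visualize, self.train = params, visualize, train

LateFusionDataset = getLateFusionDataset(_DummyBase)   # concrete class is created here
# dataset = LateFusionDataset(params, visualize=False, train=True)  # params: parsed yaml config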
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/late_heter_fusion_dataset.py ADDED
@@ -0,0 +1,565 @@
1
+ # late fusion dataset
2
+ import random
3
+ import math
4
+ from collections import OrderedDict
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import copy
9
+ from icecream import ic
10
+ from PIL import Image
11
+ import pickle as pkl
12
+ from opencood.utils import box_utils as box_utils
13
+ from opencood.data_utils.pre_processor import build_preprocessor
14
+ from opencood.data_utils.post_processor import build_postprocessor
15
+ from opencood.utils.camera_utils import (
16
+ sample_augmentation,
17
+ img_transform,
18
+ normalize_img,
19
+ img_to_tensor,
20
+ )
21
+ from opencood.data_utils.augmentor.data_augmentor import DataAugmentor
22
+ from opencood.utils.transformation_utils import x1_to_x2
23
+ from opencood.utils.pose_utils import add_noise_data_dict
24
+ from opencood.utils.pcd_utils import (
25
+ mask_points_by_range,
26
+ mask_ego_points,
27
+ shuffle_points,
28
+ downsample_lidar_minimum,
29
+ )
30
+ from opencood.utils.common_utils import read_json
31
+ from opencood.utils.common_utils import merge_features_to_dict
32
+ from opencood.utils.heter_utils import Adaptor
33
+
34
+ def getLateheterFusionDataset(cls):
35
+ """
36
+ cls: the BaseDataset.
37
+ """
38
+ class LateheterFusionDataset(cls):
39
+ def __init__(self, params, visualize, train=True):
40
+ super().__init__(params, visualize, train)
41
+ self.anchor_box = self.post_processor.generate_anchor_box()
42
+ self.anchor_box_torch = torch.from_numpy(self.anchor_box)
43
+
44
+ self.heterogeneous = True
45
+ self.modality_assignment = read_json(params['heter']['assignment_path'])
46
+ self.ego_modality = params['heter']['ego_modality'] # "m1" or "m1&m2" or "m3"
47
+
48
+ self.modality_name_list = list(params['heter']['modality_setting'].keys())
49
+ self.sensor_type_dict = OrderedDict()
50
+
51
+ lidar_channels_dict = params['heter'].get('lidar_channels_dict', OrderedDict())
52
+ mapping_dict = params['heter']['mapping_dict']
53
+
54
+ self.adaptor = Adaptor(self.ego_modality,
55
+ self.modality_name_list,
56
+ self.modality_assignment,
57
+ lidar_channels_dict,
58
+ mapping_dict,
59
+ None,
60
+ train)
61
+
62
+ for modality_name, modal_setting in params['heter']['modality_setting'].items():
63
+ self.sensor_type_dict[modality_name] = modal_setting['sensor_type']
64
+ if modal_setting['sensor_type'] == 'lidar':
65
+ setattr(self, f"pre_processor_{modality_name}", build_preprocessor(modal_setting['preprocess'], train))
66
+
67
+ elif modal_setting['sensor_type'] == 'camera':
68
+ setattr(self, f"data_aug_conf_{modality_name}", modal_setting['data_aug_conf'])
69
+
70
+ else:
71
+ raise ValueError(f"Sensor type {modal_setting['sensor_type']} is not supported")
72
+
73
+ self.reinitialize()
74
+
75
+ def __getitem__(self, idx):
76
+ base_data_dict = self.retrieve_base_data(idx)
77
+ if self.train:
78
+ reformat_data_dict = self.get_item_train(base_data_dict)
79
+ else:
80
+ reformat_data_dict = self.get_item_test(base_data_dict, idx)
81
+ return reformat_data_dict
82
+
83
+ def get_item_train(self, base_data_dict):
84
+ processed_data_dict = OrderedDict()
85
+ base_data_dict = add_noise_data_dict(
86
+ base_data_dict, self.params["noise_setting"]
87
+ )
88
+ # during training, we return a random cav's data
89
+ # only one vehicle is in processed_data_dict
90
+ if not self.visualize:
91
+ options = []
92
+ for cav_id, cav_content in base_data_dict.items():
93
+ if cav_content['modality_name'] in self.ego_modality:
94
+ options.append(cav_id)
95
+ selected_cav_base = base_data_dict[random.choice(options)]
96
+ else:
97
+ selected_cav_id, selected_cav_base = list(base_data_dict.items())[0]
98
+
99
+ selected_cav_processed = self.get_item_single_car(selected_cav_base)
100
+ processed_data_dict.update({"ego": selected_cav_processed})
101
+
102
+ return processed_data_dict
103
+
104
+
105
+ def get_item_test(self, base_data_dict, idx):
106
+ """
107
+ processed_data_dict.keys() = ['ego', "650", "659", ...]
108
+ """
109
+ base_data_dict = add_noise_data_dict(base_data_dict,self.params['noise_setting'])
110
+
111
+ processed_data_dict = OrderedDict()
112
+ ego_id = -1
113
+ ego_lidar_pose = []
114
+ cav_id_list = []
115
+ lidar_pose_list = []
116
+
117
+ # first find the ego vehicle's lidar pose
118
+ for cav_id, cav_content in base_data_dict.items():
119
+ if cav_content['ego']:
120
+ ego_id = cav_id
121
+ ego_lidar_pose = cav_content['params']['lidar_pose']
122
+ ego_lidar_pose_clean = cav_content['params']['lidar_pose_clean']
123
+ break
124
+
125
+ assert ego_id != -1
126
+ assert len(ego_lidar_pose) > 0
127
+
128
+ # loop over all CAVs to process information
129
+ for cav_id, selected_cav_base in base_data_dict.items():
130
+ distance = \
131
+ math.sqrt((selected_cav_base['params']['lidar_pose'][0] -
132
+ ego_lidar_pose[0]) ** 2 + (
133
+ selected_cav_base['params'][
134
+ 'lidar_pose'][1] - ego_lidar_pose[
135
+ 1]) ** 2)
136
+ if distance > self.params['comm_range']:
137
+ continue
138
+
139
+ if self.adaptor.unmatched_modality(selected_cav_base['modality_name']):
140
+ continue
141
+
142
+ cav_id_list.append(cav_id)
143
+ lidar_pose_list.append(selected_cav_base['params']['lidar_pose'])
144
+
145
+ cav_id_list_newname = []
146
+ for cav_id in cav_id_list:
147
+ selected_cav_base = base_data_dict[cav_id]
148
+ # find the transformation matrix from current cav to ego.
149
+ cav_lidar_pose = selected_cav_base['params']['lidar_pose']
150
+ transformation_matrix = x1_to_x2(cav_lidar_pose, ego_lidar_pose)
151
+ cav_lidar_pose_clean = selected_cav_base['params']['lidar_pose_clean']
152
+ transformation_matrix_clean = x1_to_x2(cav_lidar_pose_clean, ego_lidar_pose_clean)
153
+
154
+ # In the test phase, lidar labels are always used for a fair comparison. (needs discussion)
155
+ self.label_type = 'lidar' # DAIRV2X
156
+ self.generate_object_center = self.generate_object_center_lidar # OPV2V, V2XSET
157
+
158
+ selected_cav_processed = \
159
+ self.get_item_single_car(selected_cav_base)
160
+ selected_cav_processed.update({'transformation_matrix': transformation_matrix,
161
+ 'transformation_matrix_clean': transformation_matrix_clean})
162
+ update_cav = "ego" if cav_id == ego_id else cav_id
163
+ processed_data_dict.update({update_cav: selected_cav_processed})
164
+ cav_id_list_newname.append(update_cav)
165
+
166
+
167
+ return processed_data_dict
168
+
169
+
170
+ def get_item_single_car(self, selected_cav_base):
171
+ """
172
+ Process a single CAV's information for the train/test pipeline.
173
+
174
+
175
+ Parameters
176
+ ----------
177
+ selected_cav_base : dict
178
+ The dictionary containing a single CAV's raw information,
179
+ including 'params' and 'camera_data'.
180
+
181
+ Returns
182
+ -------
183
+ selected_cav_processed : dict
184
+ The dictionary contains the cav's processed information.
185
+ """
186
+ selected_cav_processed = {}
187
+ modality_name = selected_cav_base['modality_name']
188
+ sensor_type = self.sensor_type_dict[modality_name]
189
+
190
+ # label
191
+ object_bbx_center, object_bbx_mask, object_ids = self.generate_object_center_single(
192
+ [selected_cav_base], selected_cav_base["params"]["lidar_pose_clean"]
193
+ )
194
+
195
+ # lidar
196
+ if sensor_type == "lidar" or self.visualize:
197
+ lidar_np = selected_cav_base['lidar_np']
198
+ lidar_np = shuffle_points(lidar_np)
199
+ lidar_np = mask_points_by_range(lidar_np,
200
+ self.params['preprocess'][
201
+ 'cav_lidar_range'])
202
+ # remove points that hit ego vehicle
203
+ lidar_np = mask_ego_points(lidar_np)
204
+
205
+ # data augmentation; this seems very important for single-agent training because of the lack of data diversity.
206
+ # it is only applied to the lidar modality during training.
207
+ lidar_np, object_bbx_center, object_bbx_mask = \
208
+ self.augment(lidar_np, object_bbx_center, object_bbx_mask)
209
+ if sensor_type == "lidar":
210
+ processed_lidar = eval(f"self.pre_processor_{modality_name}").preprocess(lidar_np)
211
+ selected_cav_processed.update({f'processed_features_{modality_name}': processed_lidar})
212
+
213
+
214
+ if self.visualize:
215
+ selected_cav_processed.update({'origin_lidar': lidar_np})
216
+
217
+ # camera
218
+ if sensor_type == "camera":
219
+ # adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/data.py
220
+ camera_data_list = selected_cav_base["camera_data"]
221
+
222
+ params = selected_cav_base["params"]
223
+ imgs = []
224
+ rots = []
225
+ trans = []
226
+ intrins = []
227
+ extrinsics = [] # cam_to_lidar
228
+ post_rots = []
229
+ post_trans = []
230
+
231
+ for idx, img in enumerate(camera_data_list):
232
+ camera_to_lidar, camera_intrinsic = self.get_ext_int(params, idx)
233
+
234
+ intrin = torch.from_numpy(camera_intrinsic)
235
+ rot = torch.from_numpy(
236
+ camera_to_lidar[:3, :3]
237
+ ) # R_wc, we consider world-coord is the lidar-coord
238
+ tran = torch.from_numpy(camera_to_lidar[:3, 3]) # T_wc
239
+
240
+ post_rot = torch.eye(2)
241
+ post_tran = torch.zeros(2)
242
+
243
+ img_src = [img]
244
+
245
+ # depth
246
+ if self.load_depth_file:
247
+ depth_img = selected_cav_base["depth_data"][idx]
248
+ img_src.append(depth_img)
249
+ else:
250
+ depth_img = None
251
+
252
+ # data augmentation
253
+ resize, resize_dims, crop, flip, rotate = sample_augmentation(
254
+ eval(f"self.data_aug_conf_{modality_name}"), self.train
255
+ )
256
+ img_src, post_rot2, post_tran2 = img_transform(
257
+ img_src,
258
+ post_rot,
259
+ post_tran,
260
+ resize=resize,
261
+ resize_dims=resize_dims,
262
+ crop=crop,
263
+ flip=flip,
264
+ rotate=rotate,
265
+ )
266
+ # for convenience, make augmentation matrices 3x3
267
+ post_tran = torch.zeros(3)
268
+ post_rot = torch.eye(3)
269
+ post_tran[:2] = post_tran2
270
+ post_rot[:2, :2] = post_rot2
271
+
272
+ img_src[0] = normalize_img(img_src[0])
273
+ if self.load_depth_file:
274
+ img_src[1] = img_to_tensor(img_src[1]) * 255
275
+
276
+ imgs.append(torch.cat(img_src, dim=0))
277
+ intrins.append(intrin)
278
+ extrinsics.append(torch.from_numpy(camera_to_lidar))
279
+ rots.append(rot)
280
+ trans.append(tran)
281
+ post_rots.append(post_rot)
282
+ post_trans.append(post_tran)
283
+
284
+ selected_cav_processed.update(
285
+ {
286
+ f"image_inputs_{modality_name}":
287
+ {
288
+ "imgs": torch.stack(imgs), # [N, 3or4, H, W]
289
+ "intrins": torch.stack(intrins),
290
+ "extrinsics": torch.stack(extrinsics),
291
+ "rots": torch.stack(rots),
292
+ "trans": torch.stack(trans),
293
+ "post_rots": torch.stack(post_rots),
294
+ "post_trans": torch.stack(post_trans),
295
+ }
296
+ }
297
+ )
298
+
299
+
300
+ selected_cav_processed.update(
301
+ {
302
+ "object_bbx_center": object_bbx_center,
303
+ "object_bbx_mask": object_bbx_mask,
304
+ "object_ids": object_ids,
305
+ "modality_name": modality_name
306
+ }
307
+ )
308
+
309
+ # generate targets label
310
+ label_dict = self.post_processor.generate_label(
311
+ gt_box_center=object_bbx_center, anchors=self.anchor_box, mask=object_bbx_mask
312
+ )
313
+ selected_cav_processed.update({"label_dict": label_dict})
314
+
315
+ return selected_cav_processed
316
+
317
+
318
+ def collate_batch_train(self, batch):
319
+ """
320
+ Customized collate function for pytorch dataloader during training
321
+ for early and late fusion dataset.
322
+
323
+ Parameters
324
+ ----------
325
+ batch : dict
326
+
327
+ Returns
328
+ -------
329
+ batch : dict
330
+ Reformatted batch.
331
+ """
332
+ # during training, we only care about ego.
333
+ output_dict = {'ego': {}}
334
+
335
+ object_bbx_center = []
336
+ object_bbx_mask = []
337
+ label_dict_list = []
338
+ origin_lidar = []
339
+ inputs_list_m1 = []
340
+ inputs_list_m2 = []
341
+ inputs_list_m3 = []
342
+ inputs_list_m4 = []
343
+ for i in range(len(batch)):
344
+ ego_dict = batch[i]['ego']
345
+ object_bbx_center.append(ego_dict['object_bbx_center'])
346
+ object_bbx_mask.append(ego_dict['object_bbx_mask'])
347
+ label_dict_list.append(ego_dict['label_dict'])
348
+
349
+ if self.visualize:
350
+ origin_lidar.append(ego_dict['origin_lidar'])
351
+
352
+ # convert to numpy, (B, max_num, 7)
353
+ object_bbx_center = torch.from_numpy(np.array(object_bbx_center))
354
+ object_bbx_mask = torch.from_numpy(np.array(object_bbx_mask))
355
+ label_torch_dict = \
356
+ self.post_processor.collate_batch(label_dict_list)
357
+
358
+ # for centerpoint
359
+ label_torch_dict.update({'object_bbx_center': object_bbx_center,
360
+ 'object_bbx_mask': object_bbx_mask})
361
+
362
+ output_dict['ego'].update({'object_bbx_center': object_bbx_center,
363
+ 'object_bbx_mask': object_bbx_mask,
364
+ 'anchor_box': torch.from_numpy(self.anchor_box),
365
+ 'label_dict': label_torch_dict})
366
+ if self.visualize:
367
+ origin_lidar = \
368
+ np.array(downsample_lidar_minimum(pcd_np_list=origin_lidar))
369
+ origin_lidar = torch.from_numpy(origin_lidar)
370
+ output_dict['ego'].update({'origin_lidar': origin_lidar})
371
+
372
+
373
+
374
+
375
+ for modality_name in self.modality_name_list:
376
+ sensor_type = self.sensor_type_dict[modality_name]
377
+ for i in range(len(batch)):
378
+ ego_dict = batch[i]['ego']
379
+ if f'processed_features_{modality_name}' in ego_dict:
380
+ eval(f"inputs_list_{modality_name}").append(ego_dict[f'processed_features_{modality_name}'])
381
+ elif f'image_inputs_{modality_name}' in ego_dict:
382
+ eval(f"inputs_list_{modality_name}").append(ego_dict[f'image_inputs_{modality_name}'])
383
+
384
+ if self.sensor_type_dict[modality_name] == "lidar":
385
+ processed_lidar_torch_dict = eval(f"self.pre_processor_{modality_name}").collate_batch(eval(f"inputs_list_{modality_name}"))
386
+ output_dict['ego'].update({f'inputs_{modality_name}': processed_lidar_torch_dict})
387
+ elif self.sensor_type_dict[modality_name] == "camera":
388
+ merged_image_inputs_dict = merge_features_to_dict(eval(f"inputs_list_{modality_name}"), merge='stack')
389
+ output_dict['ego'].update({f'inputs_{modality_name}': merged_image_inputs_dict})
390
+
391
+ return output_dict
392
+
393
+ def collate_batch_test(self, batch):
394
+ """
395
+ Customized collate function for pytorch dataloader during testing
396
+ for late fusion dataset.
397
+
398
+ Parameters
399
+ ----------
400
+ batch : dict
401
+
402
+ Returns
403
+ -------
404
+ batch : dict
405
+ Reformatted batch.
406
+ """
407
+ # currently, we only support batch size of 1 during testing
408
+ assert len(batch) <= 1, "Batch size 1 is required during testing!"
409
+ batch = batch[0]
410
+
411
+ output_dict = {}
412
+
413
+ # for late fusion, we also need to stack the lidar for better
414
+ # visualization
415
+ if self.visualize:
416
+ projected_lidar_list = []
417
+ origin_lidar = []
418
+
419
+ for cav_id, cav_content in batch.items():
420
+ modality_name = cav_content['modality_name']
421
+ sensor_type = self.sensor_type_dict[modality_name]
422
+
423
+ output_dict.update({cav_id: {}})
424
+ # shape: (1, max_num, 7)
425
+ object_bbx_center = \
426
+ torch.from_numpy(np.array([cav_content['object_bbx_center']]))
427
+ object_bbx_mask = \
428
+ torch.from_numpy(np.array([cav_content['object_bbx_mask']]))
429
+ object_ids = cav_content['object_ids']
430
+
431
+ # the anchor box is the same for all bounding boxes usually, thus
432
+ # we don't need the batch dimension.
433
+ output_dict[cav_id].update(
434
+ {"anchor_box": self.anchor_box_torch}
435
+ )
436
+
437
+ transformation_matrix = cav_content['transformation_matrix']
438
+ if self.visualize:
439
+ origin_lidar = [cav_content['origin_lidar']]
440
+ if (self.params.get('only_vis_ego', True) is False) or (cav_id=='ego'):
441
+ projected_lidar = copy.deepcopy(cav_content['origin_lidar'])
442
+ projected_lidar[:, :3] = \
443
+ box_utils.project_points_by_matrix_torch(
444
+ projected_lidar[:, :3],
445
+ transformation_matrix)
446
+ projected_lidar_list.append(projected_lidar)
447
+
448
+ if sensor_type == "lidar":
449
+ # processed lidar dictionary
450
+ processed_lidar_torch_dict = \
451
+ eval(f"self.pre_processor_{modality_name}").collate_batch([cav_content[f'processed_features_{modality_name}']])
452
+ output_dict[cav_id].update({f'inputs_{modality_name}': processed_lidar_torch_dict})
453
+
454
+ if sensor_type == 'camera':
455
+ imgs_batch = [cav_content[f"image_inputs_{modality_name}"]["imgs"]]
456
+ rots_batch = [cav_content[f"image_inputs_{modality_name}"]["rots"]]
457
+ trans_batch = [cav_content[f"image_inputs_{modality_name}"]["trans"]]
458
+ intrins_batch = [cav_content[f"image_inputs_{modality_name}"]["intrins"]]
459
+ extrinsics_batch = [cav_content[f"image_inputs_{modality_name}"]["extrinsics"]]
460
+ post_trans_batch = [cav_content[f"image_inputs_{modality_name}"]["post_trans"]]
461
+ post_rots_batch = [cav_content[f"image_inputs_{modality_name}"]["post_rots"]]
462
+
463
+ output_dict[cav_id].update({
464
+ f"inputs_{modality_name}":
465
+ {
466
+ "imgs": torch.stack(imgs_batch),
467
+ "rots": torch.stack(rots_batch),
468
+ "trans": torch.stack(trans_batch),
469
+ "intrins": torch.stack(intrins_batch),
470
+ "extrinsics": torch.stack(extrinsics_batch),
471
+ "post_trans": torch.stack(post_trans_batch),
472
+ "post_rots": torch.stack(post_rots_batch),
473
+ }
474
+ }
475
+ )
476
+
477
+
478
+ # label dictionary
479
+ label_torch_dict = \
480
+ self.post_processor.collate_batch([cav_content['label_dict']])
481
+
482
+ # for centerpoint
483
+ label_torch_dict.update({'object_bbx_center': object_bbx_center,
484
+ 'object_bbx_mask': object_bbx_mask})
485
+
486
+ # save the transformation matrix (4, 4) to ego vehicle
487
+ transformation_matrix_torch = \
488
+ torch.from_numpy(
489
+ np.array(cav_content['transformation_matrix'])).float()
490
+
491
+ # late fusion training, no noise
492
+ transformation_matrix_clean_torch = \
493
+ torch.from_numpy(
494
+ np.array(cav_content['transformation_matrix_clean'])).float()
495
+
496
+ output_dict[cav_id].update({'object_bbx_center': object_bbx_center,
497
+ 'object_bbx_mask': object_bbx_mask,
498
+ 'label_dict': label_torch_dict,
499
+ 'object_ids': object_ids,
500
+ 'transformation_matrix': transformation_matrix_torch,
501
+ 'transformation_matrix_clean': transformation_matrix_clean_torch,
502
+ 'modality_name': modality_name})
503
+
504
+ if self.visualize:
505
+ origin_lidar = \
506
+ np.array(
507
+ downsample_lidar_minimum(pcd_np_list=origin_lidar))
508
+ origin_lidar = torch.from_numpy(origin_lidar)
509
+ output_dict[cav_id].update({'origin_lidar': origin_lidar})
510
+
511
+ if self.visualize:
512
+ projected_lidar_stack = [torch.from_numpy(
513
+ np.vstack(projected_lidar_list))]
514
+ output_dict['ego'].update({'origin_lidar': projected_lidar_stack})
515
+ # output_dict['ego'].update({'projected_lidar_list': projected_lidar_list})
516
+
517
+ return output_dict
518
+
519
+
520
+ def post_process(self, data_dict, output_dict):
521
+ """
522
+ Process the outputs of the model to 2D/3D bounding box.
523
+
524
+ Parameters
525
+ ----------
526
+ data_dict : dict
527
+ The dictionary containing the origin input data of model.
528
+
529
+ output_dict :dict
530
+ The dictionary containing the output of the model.
531
+
532
+ Returns
533
+ -------
534
+ pred_box_tensor : torch.Tensor
535
+ The tensor of prediction bounding box after NMS.
536
+ gt_box_tensor : torch.Tensor
537
+ The tensor of gt bounding box.
538
+ """
539
+ pred_box_tensor, pred_score = self.post_processor.post_process(
540
+ data_dict, output_dict
541
+ )
542
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict)
543
+
544
+ return pred_box_tensor, pred_score, gt_box_tensor
545
+
546
+ def post_process_no_fusion(self, data_dict, output_dict_ego):
547
+ data_dict_ego = OrderedDict()
548
+ data_dict_ego["ego"] = data_dict["ego"]
549
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict)
550
+
551
+ pred_box_tensor, pred_score = self.post_processor.post_process(
552
+ data_dict_ego, output_dict_ego
553
+ )
554
+ return pred_box_tensor, pred_score, gt_box_tensor
555
+
556
+ def post_process_no_fusion_uncertainty(self, data_dict, output_dict_ego):
557
+ data_dict_ego = OrderedDict()
558
+ data_dict_ego['ego'] = data_dict['ego']
559
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict)
560
+
561
+ pred_box_tensor, pred_score, uncertainty = \
562
+ self.post_processor.post_process(data_dict_ego, output_dict_ego, return_uncertainty=True)
563
+ return pred_box_tensor, pred_score, gt_box_tensor, uncertainty
564
+
565
+ return LateheterFusionDataset
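The heterogeneous variant above stores one pre-processor and one data_aug_conf per modality via setattr() and later fetches them with eval(f"self.pre_processor_{modality_name}"). A getattr-based accessor, sketched below, retrieves the same attributes without eval; the helper name is an assumption added here for illustration only.

# Hypothetical eval-free accessor for the per-modality attributes created in __init__ above.
def get_modality_preprocessor(self, modality_name):
    # attribute names follow the f"pre_processor_{modality_name}" convention used by setattr()
    return getattr(self, f"pre_processor_{modality_name}")

# usage inside get_item_single_car:
#   processed_lidar = self.get_modality_preprocessor(modality_name).preprocess(lidar_np)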
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/late_multi_fusion_dataset.py ADDED
@@ -0,0 +1,631 @@
1
+ # late fusion dataset
2
+ import random
3
+ import math
4
+ from collections import OrderedDict
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import copy
9
+ from icecream import ic
10
+ from PIL import Image
11
+ import pickle as pkl
12
+ from opencood.utils import box_utils as box_utils
13
+ from opencood.data_utils.pre_processor import build_preprocessor
14
+ from opencood.data_utils.post_processor import build_postprocessor
15
+ from opencood.utils.camera_utils import (
16
+ sample_augmentation,
17
+ img_transform,
18
+ normalize_img,
19
+ img_to_tensor,
20
+ )
21
+ from opencood.data_utils.augmentor.data_augmentor import DataAugmentor
22
+ from opencood.utils.transformation_utils import x1_to_x2
23
+ from opencood.utils.pose_utils import add_noise_data_dict
24
+ from opencood.utils.pcd_utils import (
25
+ mask_points_by_range,
26
+ mask_ego_points,
27
+ shuffle_points,
28
+ downsample_lidar_minimum,
29
+ )
30
+
31
+
32
+
33
+ def getLateclassFusionDataset(cls):
34
+ """
35
+ cls: the BaseDataset (or another parent dataset class) that provides basic interfaces such as:
36
+ - retrieve_base_data()
37
+ - generate_object_center_single()
38
+ - self.post_processor
39
+ - self.pre_processor
40
+ - self.selector (if the heterogeneous configuration is used)
41
+ and so on.
42
+ """
43
+ class LateclassFusionDataset(cls):
44
+ def __init__(self, params, visualize, train=True):
45
+ super().__init__(params, visualize, train)
46
+ self.anchor_box = self.post_processor.generate_anchor_box()
47
+ self.anchor_box_torch = torch.from_numpy(self.anchor_box)
48
+
49
+ # whether to enable heterogeneous learning (e.g. some agents use lidar, others use camera)
50
+ self.heterogeneous = False
51
+ if "heter" in params:
52
+ self.heterogeneous = True
53
+
54
+ # whether multi-class detection is enabled
55
+ self.multiclass = params["model"]["args"].get("multi_class", False)
56
+
57
+ # if needed, the list of class IDs for multi-class detection can be given here,
58
+ # e.g. [0, 1, 3] corresponding to car / pedestrian / cyclist
59
+ self.class_list = params.get("class_list", [0, 1, 3])
60
+ # if the project distinguishes classes via [ 'all', 0, 1, 3 ], adjust accordingly
61
+
62
+ # used for visualization
63
+ self.visualize = visualize
64
+ self.train = train
65
+
66
+ def __getitem__(self, idx):
67
+ """
68
+ Training: randomly select one CAV for late-fusion supervision (consistent with LateFusionDataset);
69
+ Testing/validation: keep the information of all CAVs within range.
70
+ """
71
+ base_data_dict = self.retrieve_base_data(idx)
72
+ if self.train:
73
+ reformat_data_dict = self.get_item_train(base_data_dict)
74
+ else:
75
+ reformat_data_dict = self.get_item_test(base_data_dict, idx)
76
+ return reformat_data_dict
77
+
78
+ def get_item_train(self, base_data_dict):
79
+ """
80
+ Training-phase logic: usually only one CAV (with labels) is sampled
81
+ to reduce memory overhead and stay close to single-vehicle training.
82
+ """
83
+ from collections import OrderedDict
84
+ processed_data_dict = OrderedDict()
85
+
86
+ # data perturbation / noise (if configured)
87
+ base_data_dict = self.add_noise_data_if_needed(base_data_dict)
88
+
89
+ # randomly sample a single CAV
90
+ if not self.visualize:
91
+ selected_cav_id, selected_cav_base = random.choice(
92
+ list(base_data_dict.items())
93
+ )
94
+ else:
95
+ # for visualization, the ego vehicle is usually selected
96
+ selected_cav_id, selected_cav_base = list(base_data_dict.items())[0]
97
+
98
+ # process a single vehicle (including multi-class bounding boxes)
99
+ cav_processed = self.get_item_single_car(selected_cav_base)
100
+ processed_data_dict["ego"] = cav_processed
101
+ return processed_data_dict
102
+
103
+ def get_item_test(self, base_data_dict, idx):
104
+ """
105
+ Testing/validation: keep all CAVs within comm_range; each one needs its own late-fusion label.
106
+ """
107
+ from collections import OrderedDict
108
+ import math
109
+
110
+ base_data_dict = self.add_noise_data_if_needed(base_data_dict)
111
+
112
+ processed_data_dict = OrderedDict()
113
+ ego_id, ego_pose = -1, None
114
+ # first locate the ego vehicle
115
+ for cav_id, cav_content in base_data_dict.items():
116
+ if cav_content["ego"]:
117
+ ego_id = cav_id
118
+ ego_pose = cav_content["params"]["lidar_pose"]
119
+ ego_pose_clean = cav_content["params"]["lidar_pose_clean"]
120
+ break
121
+ assert ego_id != -1
122
+
123
+ cav_id_list = []
124
+ for cav_id, cav_content in base_data_dict.items():
125
+ distance = math.sqrt(
126
+ (cav_content["params"]["lidar_pose"][0] - ego_pose[0]) ** 2
127
+ + (cav_content["params"]["lidar_pose"][1] - ego_pose[1]) ** 2
128
+ )
129
+ if distance <= self.params["comm_range"]:
130
+ cav_id_list.append(cav_id)
131
+
132
+ cav_id_list_newname = []
133
+ for cav_id in cav_id_list:
134
+ selected_cav_base = base_data_dict[cav_id]
135
+ transformation_matrix = self.x1_to_x2(
136
+ selected_cav_base["params"]["lidar_pose"], ego_pose
137
+ )
138
+ transformation_matrix_clean = self.x1_to_x2(
139
+ selected_cav_base["params"]["lidar_pose_clean"], ego_pose_clean
140
+ )
141
+ cav_processed = self.get_item_single_car(selected_cav_base)
142
+ cav_processed.update(
143
+ {
144
+ "transformation_matrix": transformation_matrix,
145
+ "transformation_matrix_clean": transformation_matrix_clean,
146
+ }
147
+ )
148
+ # rename the ego vehicle itself to "ego"; keep cav_id for the others
149
+ update_cav_key = "ego" if cav_id == ego_id else cav_id
150
+ processed_data_dict[update_cav_key] = cav_processed
151
+ cav_id_list_newname.append(update_cav_key)
152
+
153
+ # extra information for the heterogeneous setting
154
+ if self.heterogeneous:
155
+ processed_data_dict["ego"]["idx"] = idx
156
+ processed_data_dict["ego"]["cav_list"] = cav_id_list_newname
157
+
158
+ return processed_data_dict
159
+
160
+ def get_item_single_car(self, cav_base):
161
+ """
162
+ Process a single vehicle's information and generate its multi-class labels, lidar data, camera data, etc.
163
+ """
164
+ selected_cav_processed = {}
165
+
166
+ # 1) generate multi-class or single-class object boxes
167
+ # for multi-class, store each class's boxes separately or stack them into [num_class, max_box, 7] at once
168
+ if self.multiclass:
169
+ # example: parse the three classes in class_list = [0, 1, 3] separately
170
+ # the simplest way is to call generate_object_center_single once per class with cav_base["params"]["lidar_pose_clean"]
171
+ # and then stack the results
172
+ all_box_list, all_mask_list, all_ids_list = [], [], []
173
+ for cls_id in self.class_list:
174
+ box_c, mask_c, ids_c = self.generate_object_center_single(
175
+ [cav_base],
176
+ cav_base["params"]["lidar_pose_clean"],
177
+ class_type=cls_id, # generate_object_center_single can filter boxes by class_type
178
+ )
179
+ all_box_list.append(box_c)
180
+ all_mask_list.append(mask_c)
181
+ all_ids_list.append(ids_c)
182
+
183
+ # stack into [num_class, max_box, 7] / [num_class, max_box]
184
+ # note that the max_box returned by generate_object_center_single may differ per class,
185
+ # so zero-pad or slice to the same size (see the existing Late/IntermediateFusion implementations).
186
+ object_bbx_center, object_bbx_mask = self.stack_multiclass_label(
187
+ all_box_list, all_mask_list
188
+ )
189
+ # object_ids can be kept as one list per class, or stored as a single [num_class, ...] structure
190
+ object_ids = all_ids_list # could also be given special handling
191
+ else:
192
+ # single-class case: a single call is enough
193
+ object_bbx_center, object_bbx_mask, object_ids = (
194
+ self.generate_object_center_single(
195
+ [cav_base], cav_base["params"]["lidar_pose_clean"]
196
+ )
197
+ )
198
+
199
+ # 2) lidar (or camera) processing
200
+ # if lidar is needed, voxelize it via self.pre_processor
201
+ if self.load_lidar_file or self.visualize:
202
+ lidar_np = cav_base["lidar_np"]
203
+ # basic processing such as shuffle_points, mask_points_by_range, mask_ego_points
204
+ lidar_np = self.basic_lidar_preprocess(lidar_np)
205
+ # data augmentation (if needed)
206
+ lidar_np, object_bbx_center, object_bbx_mask = self.augment_if_needed(
207
+ lidar_np, object_bbx_center, object_bbx_mask
208
+ )
209
+ # the actual processing, e.g. voxelization / BEV projection
210
+ processed_lidar = self.pre_processor.preprocess(lidar_np)
211
+ selected_cav_processed["processed_lidar"] = processed_lidar
212
+
213
+ if self.visualize:
214
+ selected_cav_processed["origin_lidar"] = lidar_np
215
+
216
+ # 3) camera processing
217
+ if self.load_camera_file:
218
+ # same logic as in LateFusionDataset
219
+ camera_inputs = self.process_camera_data(cav_base)
220
+ selected_cav_processed["image_inputs"] = camera_inputs
221
+
222
+ # 4) store the multi-class boxes
223
+ selected_cav_processed.update(
224
+ {
225
+ "object_bbx_center": object_bbx_center,
226
+ "object_bbx_mask": object_bbx_mask,
227
+ "object_ids": object_ids,
228
+ }
229
+ )
230
+
231
+ # 5) generate labels; multi-class detection also needs multi-class labels
232
+ if self.multiclass:
233
+ # wrap post_processor.generate_label(...) yourself to support multi-class,
234
+ # or call it once per class
235
+ label_dict = self.post_processor.generate_label_multiclass(
236
+ object_bbx_center, # [num_class, max_box, 7]
237
+ self.anchor_box,
238
+ object_bbx_mask, # [num_class, max_box]
239
+ )
240
+ else:
241
+ label_dict = self.post_processor.generate_label(
242
+ object_bbx_center, anchors=self.anchor_box, mask=object_bbx_mask
243
+ )
244
+
245
+ selected_cav_processed["label_dict"] = label_dict
246
+ return selected_cav_processed
247
+
248
+ ############################
249
+ # collate_batch related logic #
250
+ ############################
251
+ def collate_batch_train(self, batch):
252
+ """
253
+ Training-set collate:
254
+ since only one CAV is randomly sampled during training, the batch can be concatenated directly.
255
+ For true multi-CAV late-fusion supervision during training, follow the test collate logic instead.
256
+ """
257
+ import torch
258
+ from collections import OrderedDict
259
+ output_dict = {"ego": {}}
260
+
261
+ object_bbx_center_list = []
262
+ object_bbx_mask_list = []
263
+ label_dict_list = []
264
+ origin_lidar_list = []
265
+
266
+ processed_lidar_list = []
267
+
268
+ for item in batch:
269
+ ego_data = item["ego"]
270
+ object_bbx_center_list.append(ego_data["object_bbx_center"])
271
+ object_bbx_mask_list.append(ego_data["object_bbx_mask"])
272
+ label_dict_list.append(ego_data["label_dict"])
273
+
274
+ if self.visualize and "origin_lidar" in ego_data:
275
+ origin_lidar_list.append(ego_data["origin_lidar"])
276
+
277
+ if "processed_lidar" in ego_data:
278
+ processed_lidar_list.append(ego_data["processed_lidar"])
279
+
280
+ # convert to tensors
281
+ object_bbx_center_torch = self.list_to_tensor(object_bbx_center_list)
282
+ object_bbx_mask_torch = self.list_to_tensor(object_bbx_mask_list)
283
+
284
+ # collate the multi-class (or single-class) labels
285
+ label_torch_dict = self.post_processor.collate_batch(label_dict_list)
286
+ # for CenterPoint, object_bbx_center_torch etc. must also be merged into label_torch_dict
287
+ label_torch_dict.update(
288
+ {
289
+ "object_bbx_center": object_bbx_center_torch,
290
+ "object_bbx_mask": object_bbx_mask_torch,
291
+ }
292
+ )
293
+
294
+ output_dict["ego"].update(
295
+ {
296
+ "object_bbx_center": object_bbx_center_torch,
297
+ "object_bbx_mask": object_bbx_mask_torch,
298
+ "anchor_box": torch.from_numpy(self.anchor_box),
299
+ "label_dict": label_torch_dict,
300
+ }
301
+ )
302
+
303
+ # lidar
304
+ if len(processed_lidar_list) > 0:
305
+ processed_lidar_torch_dict = self.pre_processor.collate_batch(
306
+ processed_lidar_list
307
+ )
308
+ output_dict["ego"]["processed_lidar"] = processed_lidar_torch_dict
309
+
310
+ # camera
311
+ if self.load_camera_file:
312
+ # as in LateFusionDataset: stack the camera information of the batch along each dimension
313
+ camera_inputs = self.collate_camera_inputs_train(batch)
314
+ output_dict["ego"]["image_inputs"] = camera_inputs
315
+
316
+ # visualization
317
+ if self.visualize and len(origin_lidar_list) > 0:
318
+ # downsample here if needed
319
+ origin_lidar_torch = self.list_to_tensor(origin_lidar_list)
320
+ output_dict["ego"]["origin_lidar"] = origin_lidar_torch
321
+
322
+ return output_dict
323
+
324
+ def collate_batch_test(self, batch):
325
+ """
326
+ Test-set (or validation-set) collate:
327
+ generally only batch_size=1 is supported (especially with multiple CAVs);
328
+ each CAV is then taken out separately for late-fusion processing.
329
+ """
330
+ assert len(batch) == 1, "Test time batch_size must be 1 for late fusion!"
331
+ batch = batch[0]
332
+
333
+ output_dict = {}
334
+ # heterogeneous
335
+ if self.heterogeneous and "idx" in batch["ego"]:
336
+ idx = batch["ego"]["idx"]
337
+ cav_list = batch["ego"]["cav_list"]
338
+ # choose which CAVs use lidar / camera
339
+ # lidar_agent, camera_agent = self.selector.select_agent(idx)
340
+ # ...
341
+
342
+ # collect and collate
343
+ if self.visualize:
344
+ import copy
345
+ projected_lidar_list = []
346
+
347
+ for cav_id, cav_content in batch.items():
348
+ output_dict[cav_id] = {}
349
+ # reshape object_bbx_center/mask to [1, ...]
350
+ object_bbx_center = self.unsqueeze_to_batch(cav_content["object_bbx_center"])
351
+ object_bbx_mask = self.unsqueeze_to_batch(cav_content["object_bbx_mask"])
352
+
353
+ label_dict = self.post_processor.collate_batch([cav_content["label_dict"]])
354
+ # for CenterPoint, object_bbx_center/mask must be put back into label_dict
355
+ label_dict.update(
356
+ {
357
+ "object_bbx_center": object_bbx_center,
358
+ "object_bbx_mask": object_bbx_mask,
359
+ }
360
+ )
361
+
362
+ # lidar
363
+ if "processed_lidar" in cav_content:
364
+ # processed_lidar of a single CAV only
365
+ processed_lidar_torch = self.pre_processor.collate_batch(
366
+ [cav_content["processed_lidar"]]
367
+ )
368
+ output_dict[cav_id]["processed_lidar"] = processed_lidar_torch
369
+
370
+ # camera
371
+ if self.load_camera_file and "image_inputs" in cav_content:
372
+ # likewise, collate a single CAV only
373
+ cam_torch = self.collate_camera_inputs_test(cav_content)
374
+ output_dict[cav_id]["image_inputs"] = cam_torch
375
+
376
+ # for the heterogeneous setting, keep/drop inputs depending on cav_id
377
+ # if self.heterogeneous:
378
+ # pass
379
+
380
+ # store the transformation matrices
381
+ output_dict[cav_id]["transformation_matrix"] = torch.from_numpy(
382
+ cav_content["transformation_matrix"]
383
+ ).float()
384
+ output_dict[cav_id]["transformation_matrix_clean"] = torch.from_numpy(
385
+ cav_content["transformation_matrix_clean"]
386
+ ).float()
387
+
388
+ # labels and other information
389
+ output_dict[cav_id].update(
390
+ {
391
+ "object_bbx_center": object_bbx_center,
392
+ "object_bbx_mask": object_bbx_mask,
393
+ "label_dict": label_dict,
394
+ "anchor_box": self.anchor_box_torch,
395
+ "object_ids": cav_content["object_ids"],
396
+ }
397
+ )
398
+
399
+ if self.visualize and "origin_lidar" in cav_content:
400
+ output_dict[cav_id]["origin_lidar"] = torch.from_numpy(
401
+ cav_content["origin_lidar"]
402
+ )
403
+
404
+ # if the point clouds of multiple CAVs need to be stacked onto the ego frame for visualization, do it here
405
+ return output_dict
406
+
407
+ ######################################
408
+ # multi-class post-processing example #
409
+ ######################################
410
+ def post_process(self, data_dict, output_dict):
411
+ """
412
+ For multi-class detection, call self.post_process_multiclass;
413
+ otherwise identical to ordinary late fusion.
414
+ """
415
+ if self.multiclass:
416
+ # returns [list of pred_box], [list of score], [list of gt_box], one element per class
417
+ return self.post_process_multiclass(data_dict, output_dict)
418
+ else:
419
+ pred_box, pred_score = self.post_processor.post_process(data_dict, output_dict)
420
+ gt_box = self.post_processor.generate_gt_bbx(data_dict)
421
+ return pred_box, pred_score, gt_box
422
+
423
+ def post_process_multiclass(self, data_dict, output_dict):
424
+ """
425
+ Multi-class post-processing: run NMS (or similar) once per class, then return the results together.
426
+ """
427
+ import copy
428
+
429
+ # num_class = len(self.class_list)
430
+ pred_box_tensor_list = []
431
+ pred_score_list = []
432
+ gt_box_tensor_list = []
433
+
434
+ # post-process each class independently
435
+ for i, cls_id in enumerate(self.class_list):
436
+ # 1) copy out the data belonging to this class only
437
+ data_dict_single, output_dict_single = self.split_single_class(
438
+ data_dict, output_dict, class_index=i
439
+ )
440
+ # 2) run the post-processing
441
+ pred_box_tensor, pred_score = self.post_processor.post_process(
442
+ data_dict_single, output_dict_single
443
+ )
444
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict_single)
445
+
446
+ pred_box_tensor_list.append(pred_box_tensor)
447
+ pred_score_list.append(pred_score)
448
+ gt_box_tensor_list.append(gt_box_tensor)
449
+
450
+ return pred_box_tensor_list, pred_score_list, gt_box_tensor_list
451
+
452
+ ############################################
453
+ # reusable / simplified helper methods below (adapt to the project as needed) #
454
+ ############################################
455
+ def add_noise_data_if_needed(self, base_data_dict):
456
+ """
457
+ Decide whether to apply noise perturbation according to self.params["noise_setting"].
458
+ This directly reuses the existing add_noise_data_dict or add_noise_data_dict_asymmetric.
459
+ """
460
+ from opencood.utils.pose_utils import add_noise_data_dict
461
+ # replace with the asymmetric-noise variant if needed
462
+ return add_noise_data_dict(base_data_dict, self.params["noise_setting"])
463
+
464
+ def basic_lidar_preprocess(self, lidar_np):
465
+ """
466
+ Common point-cloud preprocessing such as range cropping, shuffling and removing ego-vehicle points.
467
+ """
468
+ from opencood.utils.pcd_utils import (
469
+ shuffle_points,
470
+ mask_points_by_range,
471
+ mask_ego_points,
472
+ )
473
+ lidar_np = shuffle_points(lidar_np)
474
+ lidar_np = mask_points_by_range(lidar_np, self.params["preprocess"]["cav_lidar_range"])
475
+ lidar_np = mask_ego_points(lidar_np)
476
+ return lidar_np
477
+
478
+ def augment_if_needed(self, lidar_np, object_bbx_center, object_bbx_mask):
479
+ """
480
+ If self.train and the heterogeneous mode is off, augment the point cloud / labels.
481
+ """
482
+ if self.train and not self.heterogeneous:
483
+ lidar_np, object_bbx_center, object_bbx_mask = self.augment(
484
+ lidar_np, object_bbx_center, object_bbx_mask
485
+ )
486
+ return lidar_np, object_bbx_center, object_bbx_mask
487
+
488
+ def process_camera_data(self, cav_base):
489
+ """
490
+ Augment the camera images according to the parameters (resize, crop, flip, etc.) and return them as a dict.
491
+ See the LateFusionDataset / LSS processing pipeline for reference.
492
+ """
493
+ # simplified placeholder; see get_item_single_car in the original LateFusionDataset for the full camera pipeline
494
+ camera_data_list = cav_base["camera_data"]
495
+ # ... apply augmentation and transforms ...
496
+ camera_inputs = {"imgs": None, "rots": None} # remaining keys elided in this placeholder
497
+ return camera_inputs
498
+
499
+ def collate_camera_inputs_train(self, batch):
500
+ """
501
+ Stack the images of the training batch along each dimension, e.g. [B, N, C, H, W].
502
+ """
503
+ # omitted; see collate_batch_train of LateFusionDataset
504
+ return {}
505
+
506
+ def collate_camera_inputs_test(self, cav_content):
507
+ """
508
+ Only a single CAV is collated at test time.
509
+ """
510
+ # see collate_batch_test of LateFusionDataset
511
+ return {}
512
+
513
+ def stack_multiclass_label(self, box_list, mask_list):
514
+ """
515
+ The input is a list whose elements are (max_box, 7)/(max_box,) arrays;
516
+ they are stacked into [num_class, max_box, 7] / [num_class, max_box].
517
+ If max_box differs between classes, find the maximum first and pad the rest.
518
+ """
519
+ import numpy as np
520
+ num_class = len(box_list)
521
+ max_box_counts = [b.shape[0] for b in box_list]
522
+ M = max(max_box_counts) if max_box_counts else 0
523
+
524
+ # combine the per-class arrays
525
+ box_array = []
526
+ mask_array = []
527
+ for i in range(num_class):
528
+ cur_box = box_list[i]
529
+ cur_mask = mask_list[i]
530
+ pad_size = M - cur_box.shape[0]
531
+ if pad_size > 0:
532
+ # pad with zeros
533
+ cur_box = np.concatenate(
534
+ [cur_box, np.zeros((pad_size, 7), dtype=cur_box.dtype)], axis=0
535
+ )
536
+ cur_mask = np.concatenate(
537
+ [cur_mask, np.zeros(pad_size, dtype=cur_mask.dtype)], axis=0
538
+ )
539
+ box_array.append(cur_box[None, ...]) # [1, M, 7]
540
+ mask_array.append(cur_mask[None, ...]) # [1, M]
541
+
542
+ if len(box_array) == 0:
543
+ # no objects at all
544
+ return np.zeros((0, 0, 7)), np.zeros((0, 0))
545
+
546
+ box_array = np.concatenate(box_array, axis=0) # [num_class, M, 7]
547
+ mask_array = np.concatenate(mask_array, axis=0) # [num_class, M]
548
+ return box_array, mask_array
549
+
550
+ def split_single_class(self, data_dict, output_dict, class_index):
551
+ """
552
+ Used by post_process_multiclass:
553
+ split the multi-class object_bbx_center/mask in data_dict/output_dict
554
+ into the sub-data of class class_index so that NMS can be run separately.
555
+ """
556
+ import copy
557
+ data_dict_single = {"ego": {}}
558
+ output_dict_single = {}
559
+
560
+ # iterate over all CAVs (late fusion)
561
+ for cav_id in data_dict.keys():
562
+ cav_content = data_dict[cav_id]
563
+ cav_output = output_dict[cav_id]
564
+
565
+ # object_bbx_center is assumed to be [num_class, M, 7] and the mask [num_class, M];
566
+ # take the slice belonging to class_index
567
+ single_box_center = cav_content["object_bbx_center"][class_index, ...]
568
+ single_mask = cav_content["object_bbx_mask"][class_index, ...]
569
+ # if object_ids is stored as one list per class, index it with class_index;
570
+ # if everything is merged together, extra bookkeeping is needed
571
+ if isinstance(cav_content["object_ids"], list):
572
+ single_ids = cav_content["object_ids"][class_index]
573
+ else:
574
+ single_ids = cav_content["object_ids"] # or handle according to how they are actually stored
575
+
576
+ # likewise, take the class_index slice of the network outputs cls_preds / reg_preds_multiclass
577
+ # (the exact shapes depend on the network's forward output)
578
+ cls_preds_single = cav_output["cls_preds"][
579
+ :, class_index : class_index + 1, :, :
580
+ ] # e.g. [B,1,H,W]
581
+ reg_preds_single = cav_output["reg_preds_multiclass"][
582
+ :, class_index, :, :
583
+ ] # [B,H,W,Nreg]
584
+
585
+ # build the new data_dict_single / output_dict_single
586
+ data_dict_single[cav_id] = copy.deepcopy(cav_content)
587
+ data_dict_single[cav_id]["object_bbx_center"] = single_box_center[None, ...] # keep a batch dimension
588
+ data_dict_single[cav_id]["object_bbx_mask"] = single_mask[None, ...]
589
+ data_dict_single[cav_id]["object_ids"] = single_ids
590
+
591
+ output_dict_single[cav_id] = copy.deepcopy(cav_output)
592
+ output_dict_single[cav_id]["cls_preds"] = cls_preds_single
593
+ output_dict_single[cav_id]["reg_preds"] = reg_preds_single
594
+
595
+ return data_dict_single, output_dict_single
596
+
597
+ ###################################################
598
+ # utility functions (same as the original LateFusionDataset / intermediate classes) #
599
+ ###################################################
600
+ def x1_to_x2(self, lidar_pose1, lidar_pose2):
601
+ """
602
+ Pose transformation matrix, identical to opencood.utils.transformation_utils.x1_to_x2.
603
+ """
604
+ return x1_to_x2(lidar_pose1, lidar_pose2)
605
+
606
+ def list_to_tensor(self, data_list):
607
+ """
608
+ Simple helper that turns a list of np.array into a torch.Tensor for batch concatenation.
609
+ """
610
+ import numpy as np
611
+ import torch
612
+ if len(data_list) == 0:
613
+ return None
614
+ arr = np.stack(data_list, axis=0)
615
+ return torch.from_numpy(arr)
616
+
617
+ def unsqueeze_to_batch(self, arr):
618
+ """
619
+ If arr is an np.ndarray, reshape it to [1, ...] and convert it to torch.
620
+ """
621
+ import numpy as np
622
+ import torch
623
+ if isinstance(arr, np.ndarray):
624
+ arr = arr[None, ...] # add a leading batch dimension
625
+ arr = torch.from_numpy(arr)
626
+ elif isinstance(arr, torch.Tensor) and arr.dim() == 2:
627
+ # [M,7] -> [1,M,7]
628
+ arr = arr.unsqueeze(0)
629
+ return arr
630
+
631
+ return LateMultiFusionDataset
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/datasets/late_multiclass_fusion_dataset.py ADDED
@@ -0,0 +1,1233 @@
1
+ # late fusion dataset
2
+ import random
3
+ import math
4
+ from collections import OrderedDict
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import copy
9
+ from icecream import ic
10
+ from PIL import Image
11
+ import pickle as pkl
12
+ from opencood.utils import box_utils as box_utils
13
+ from opencood.data_utils.pre_processor import build_preprocessor
14
+ from opencood.data_utils.post_processor import build_postprocessor
15
+ from opencood.utils.camera_utils import (
16
+ sample_augmentation,
17
+ img_transform,
18
+ normalize_img,
19
+ img_to_tensor,
20
+ )
21
+ from opencood.data_utils.augmentor.data_augmentor import DataAugmentor
22
+ from opencood.utils.transformation_utils import x1_to_x2, x_to_world, get_pairwise_transformation
23
+ from opencood.utils.pose_utils import add_noise_data_dict, add_noise_data_dict_asymmetric
24
+ from opencood.utils.pcd_utils import (
25
+ mask_points_by_range,
26
+ mask_ego_points,
27
+ mask_ego_points_v2,
28
+ shuffle_points,
29
+ downsample_lidar_minimum,
30
+ )
31
+ from opencood.utils.common_utils import merge_features_to_dict
32
+
33
+ def getLatemulticlassFusionDataset(cls):
34
+ """
35
+ cls: the Basedataset.
36
+ """
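+ # Class-factory pattern: the returned class inherits from the supplied base dataset.
+ # Illustrative usage (base class name assumed, not defined in this file):
+ #   LateMulticlassDataset = getLatemulticlassFusionDataset(V2XVerseBaseDataset)
+ #   train_set = LateMulticlassDataset(params, visualize=False, train=True)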
37
+ class LatemulticlassFusionDataset(cls):
38
+ def __init__(self, params, visualize, train=True):
39
+ super().__init__(params, visualize, train)
40
+ self.anchor_box = self.post_processor.generate_anchor_box()
41
+ self.anchor_box_torch = torch.from_numpy(self.anchor_box)
42
+
43
+ self.heterogeneous = False
44
+ if 'heter' in params:
45
+ self.heterogeneous = True
46
+
47
+ self.multiclass = params['model']['args']['multi_class']
48
+
49
+ self.proj_first = False if 'proj_first' not in params['fusion']['args']\
50
+ else params['fusion']['args']['proj_first']
51
+
52
+ # self.proj_first = False
53
+ self.supervise_single = True if ('supervise_single' in params['model']['args'] and params['model']['args']['supervise_single']) \
54
+ else False
55
+ # self.supervise_single = False
56
+ self.online_eval_only = False
57
+
58
+
59
+ def __getitem__(self, idx, extra_source=None, data_dir=None):
60
+
61
+ if data_dir is not None:
62
+ extra_source=1
63
+
64
+ object_bbx_center_list = []
65
+ object_bbx_mask_list = []
66
+ object_id_dict = {}
67
+
68
+ object_bbx_center_list_single = []
69
+ object_bbx_mask_list_single = []
70
+
71
+ gt_object_bbx_center_list = []
72
+ gt_object_bbx_mask_list = []
73
+ gt_object_id_dict = {}
74
+
75
+ gt_object_bbx_center_list_single = []
76
+ gt_object_bbx_mask_list_single = []
77
+
78
+ output_dict = {}
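+ # Only tpe == 'all' builds the shared lidar/camera inputs; the integer ids
+ # 0/1/3 are the per-class label branches, whose concrete semantics come from
+ # the label generator of the underlying base dataset (class mapping assumed there).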
79
+ for tpe in ['all', 0, 1, 3]:
80
+ output_single_class = self.__getitem_single_class__(idx, tpe, extra_source, data_dir)
81
+ output_dict[tpe] = output_single_class
82
+ if tpe == 'all' and extra_source is None:
83
+ continue
84
+ elif tpe == 'all' and extra_source is not None:
85
+ break
86
+ object_bbx_center_list.append(output_single_class['ego']['object_bbx_center'])
87
+ object_bbx_mask_list.append(output_single_class['ego']['object_bbx_mask'])
88
+ object_id_dict[tpe] = output_single_class['ego']['object_ids']
89
+
90
+ gt_object_bbx_center_list.append(output_single_class['ego']['gt_object_bbx_center'])
91
+ gt_object_bbx_mask_list.append(output_single_class['ego']['gt_object_bbx_mask'])
92
+ gt_object_id_dict[tpe] = output_single_class['ego']['gt_object_ids']
93
+
94
+ if self.multiclass and extra_source is None:
95
+ output_dict['all']['ego']['object_bbx_center'] = np.stack(object_bbx_center_list, axis=0)
96
+ output_dict['all']['ego']['object_bbx_mask'] = np.stack(object_bbx_mask_list, axis=0)
97
+ output_dict['all']['ego']['object_ids'] = object_id_dict
98
+
99
+ output_dict['all']['ego']['gt_object_bbx_center'] = np.stack(gt_object_bbx_center_list, axis=0)
100
+ output_dict['all']['ego']['gt_object_bbx_mask'] = np.stack(gt_object_bbx_mask_list, axis=0)
101
+ output_dict['all']['ego']['gt_object_ids'] = gt_object_id_dict
102
+
103
+
104
+ return output_dict['all']
105
+
106
+ def __getitem_single_class__(self, idx, tpe=None, extra_source=None, data_dir=None):
107
+
108
+ if extra_source is None and data_dir is None:
109
+ base_data_dict = self.retrieve_base_data(idx, tpe) ## {id:{'ego':True/False, 'params': {'lidar_pose','speed','vehicles','ego_pos',...}, 'lidar_np': array (N,4)}}
110
+ elif data_dir is not None:
111
+ base_data_dict = self.retrieve_base_data(idx=None, tpe=tpe, data_dir=data_dir)
112
+ elif extra_source is not None:
113
+ base_data_dict = self.retrieve_base_data(idx=None, tpe=tpe, extra_source=extra_source)
114
+
115
+ # base_data_dict = add_noise_data_dict(base_data_dict,self.params['noise_setting'])
116
+ base_data_dict = add_noise_data_dict_asymmetric(base_data_dict,self.params['noise_setting'])
117
+ processed_data_dict = OrderedDict()
118
+ processed_data_dict['ego'] = {}
119
+ ego_id = -1
120
+ ego_lidar_pose = []
121
+ ego_cav_base = None
122
+ cav_id_list = []
123
+ lidar_pose_list = []
124
+ too_far = []
125
+ # first find the ego vehicle's lidar pose
126
+ for cav_id, cav_content in base_data_dict.items():
127
+
128
+ if cav_content['ego']:
129
+ ego_id = cav_id
130
+ ego_lidar_pose = cav_content['params']['lidar_pose']
131
+ ego_lidar_pose_clean = cav_content['params']['lidar_pose_clean']
132
+ ego_cav_base = cav_content
133
+ break
134
+
135
+ assert ego_id != -1
136
+ assert len(ego_lidar_pose) > 0
137
+
138
+ agents_image_inputs = []
139
+ processed_features = []
140
+ object_stack = []
141
+ object_mask_stack = []
142
+ object_id_stack = []
143
+
144
+ gt_object_stack = []
145
+ gt_object_mask_stack = []
146
+ gt_object_id_stack = []
147
+
148
+ single_label_list = []
149
+ single_object_bbx_center_list = []
150
+ single_object_bbx_mask_list = []
151
+ too_far = []
152
+ lidar_pose_list = []
153
+ lidar_pose_clean_list = []
154
+ cav_id_list = []
155
+ projected_lidar_clean_list = [] # disconet
156
+
157
+ if self.visualize:
158
+ projected_lidar_stack = []
159
+
160
+ # loop over all CAVs to process information
161
+ for cav_id, selected_cav_base in base_data_dict.items():
162
+ distance = \
163
+ math.sqrt((selected_cav_base['params']['lidar_pose'][0] -
164
+ ego_lidar_pose[0]) ** 2 + (
165
+ selected_cav_base['params'][
166
+ 'lidar_pose'][1] - ego_lidar_pose[
167
+ 1]) ** 2)
168
+ if distance > self.params['comm_range']:
169
+ too_far.append(cav_id)
170
+ continue
171
+ cav_id_list.append(cav_id)
172
+ lidar_pose_list.append(selected_cav_base['params']['lidar_pose'])
173
+ lidar_pose_clean_list.append(selected_cav_base['params']['lidar_pose_clean'])
174
+
175
+ for cav_id in too_far:
176
+ base_data_dict.pop(cav_id)
177
+
178
+ pairwise_t_matrix = \
179
+ get_pairwise_transformation(base_data_dict,
180
+ self.max_cav,
181
+ self.proj_first)
182
+ cav_num = len(cav_id_list)
183
+ cav_id_list_newname = []
184
+
185
+ lidar_poses = np.array(lidar_pose_list).reshape(-1, 6) # [N_cav, 6]
186
+ lidar_poses_clean = np.array(lidar_pose_clean_list).reshape(-1, 6) # [N_cav, 6]
187
+
188
+ for cav_id in cav_id_list:
189
+ selected_cav_base = base_data_dict[cav_id]
190
+ # find the transformation matrix from current cav to ego.
191
+ cav_lidar_pose = selected_cav_base['params']['lidar_pose']
192
+ transformation_matrix = x1_to_x2(cav_lidar_pose, ego_lidar_pose)
193
+ cav_lidar_pose_clean = selected_cav_base['params']['lidar_pose_clean']
194
+ transformation_matrix_clean = x1_to_x2(cav_lidar_pose_clean, ego_lidar_pose_clean)
195
+
196
+ selected_cav_processed = \
197
+ self.get_item_single_car(selected_cav_base,
198
+ ego_cav_base,
199
+ tpe,
200
+ extra_source!=None)
201
+ selected_cav_processed.update({'transformation_matrix': transformation_matrix,
202
+ 'transformation_matrix_clean': transformation_matrix_clean})
203
+ if extra_source is None:
204
+ object_stack.append(selected_cav_processed['object_bbx_center'])
205
+ object_mask_stack.append(selected_cav_processed['object_bbx_mask'])
206
+ object_id_stack += selected_cav_processed['object_ids']
207
+
208
+
209
+ gt_object_stack.append(selected_cav_processed['gt_object_bbx_center'])
210
+ gt_object_mask_stack.append(selected_cav_processed['gt_object_bbx_mask'])
211
+ gt_object_id_stack += selected_cav_processed['gt_object_ids']
212
+
213
+ if tpe == 'all':
214
+
215
+ if self.load_lidar_file:
216
+ processed_features.append(
217
+ selected_cav_processed['processed_lidar'])
218
+
219
+ if self.load_camera_file:
220
+ agents_image_inputs.append(
221
+ selected_cav_processed['image_inputs'])
222
+
223
+ if self.visualize:
224
+ projected_lidar_stack.append(
225
+ selected_cav_processed['projected_lidar'])
226
+
227
+
228
+ if self.supervise_single and extra_source is None :
229
+ single_label_list.append(selected_cav_processed['single_label_dict'])
230
+ single_object_bbx_center_list.append(selected_cav_processed['single_object_bbx_center'])
231
+ single_object_bbx_mask_list.append(selected_cav_processed['single_object_bbx_mask'])
232
+
233
+ update_cav = "ego" if cav_id == ego_id else cav_id
234
+ processed_data_dict.update({update_cav: selected_cav_processed})
235
+ cav_id_list_newname.append(update_cav)
236
+
237
+ if self.supervise_single and extra_source is None:
238
+ single_label_dicts = {}
239
+ if tpe == 'all':
240
+ # unused label
241
+ if False:
242
+ single_label_dicts = self.post_processor.collate_batch(single_label_list)
243
+ single_object_bbx_center = torch.from_numpy(np.array(single_object_bbx_center_list))
244
+ single_object_bbx_mask = torch.from_numpy(np.array(single_object_bbx_mask_list))
245
+ processed_data_dict['ego'].update({
246
+ "single_label_dict_torch": single_label_dicts,
247
+ "single_object_bbx_center_torch": single_object_bbx_center,
248
+ "single_object_bbx_mask_torch": single_object_bbx_mask,
249
+ })
250
+
251
+ # heterogeneous
252
+ if self.heterogeneous:
253
+ processed_data_dict['ego']['idx'] = idx
254
+ processed_data_dict['ego']['cav_list'] = cav_id_list_newname
255
+
256
+ if extra_source is None:
257
+ unique_indices = \
258
+ [object_id_stack.index(x) for x in set(object_id_stack)]
259
+ object_stack = np.vstack(object_stack)
260
+ object_mask_stack = np.concatenate(object_mask_stack)
261
+ object_stack = object_stack[unique_indices]
262
+ object_mask_stack = object_mask_stack[unique_indices]
263
+
264
+ # make sure bounding boxes across all frames have the same number
265
+ object_bbx_center = \
266
+ np.zeros((self.params['postprocess']['max_num'], 7))
267
+ mask = np.zeros(self.params['postprocess']['max_num'])
268
+ object_bbx_center[:object_stack.shape[0], :] = object_stack
269
+ mask[:object_mask_stack.shape[0]] = object_mask_stack
270
+ # mask[:object_mask_stack.shape[0]] = 1
271
+
272
+ gt_unique_indices = \
273
+ [gt_object_id_stack.index(x) for x in set(gt_object_id_stack)]
274
+ gt_object_stack = np.vstack(gt_object_stack)
275
+ gt_object_mask_stack = np.concatenate(gt_object_mask_stack)
276
+ gt_object_stack = gt_object_stack[gt_unique_indices]
277
+ gt_object_mask_stack = gt_object_mask_stack[gt_unique_indices]
278
+
279
+ # make sure bounding boxes across all frames have the same number
280
+ gt_object_bbx_center = \
281
+ np.zeros((self.params['postprocess']['max_num'], 7))
282
+ gt_mask = np.zeros(self.params['postprocess']['max_num'])
283
+ gt_object_bbx_center[:gt_object_stack.shape[0], :] = gt_object_stack
284
+ gt_mask[:gt_object_mask_stack.shape[0]] = gt_object_mask_stack
285
+ # gt_mask[:gt_object_mask_stack.shape[0]] = 1
286
+
287
+ processed_data_dict['ego'].update(
288
+ {'object_bbx_center': object_bbx_center, # (100,7)
289
+ 'object_bbx_mask': mask, # (100,)
290
+ 'object_ids': [object_id_stack[i] for i in unique_indices],
291
+ }
292
+ )
293
+
294
+ # generate targets label
295
+ label_dict = {}
296
+ # if tpe == 'all':
297
+ # unused label
298
+ if extra_source is None:
299
+ label_dict = \
300
+ self.post_processor.generate_label(
301
+ gt_box_center=object_bbx_center,
302
+ anchors=self.anchor_box,
303
+ mask=mask)
304
+ gt_label_dict = \
305
+ self.post_processor.generate_label(
306
+ gt_box_center=gt_object_bbx_center,
307
+ anchors=self.anchor_box,
308
+ mask=gt_mask)
309
+
310
+
311
+ processed_data_dict['ego'].update(
312
+ {'gt_object_bbx_center': gt_object_bbx_center, # (100,7)
313
+ 'gt_object_bbx_mask': gt_mask, # (100,)
314
+ 'gt_object_ids': [gt_object_id_stack[i] for i in gt_unique_indices],
315
+ 'gt_label_dict': gt_label_dict})
316
+
317
+ processed_data_dict['ego'].update(
318
+ {
319
+ 'anchor_box': self.anchor_box,
320
+ 'label_dict': label_dict,
321
+ 'cav_num': cav_num,
322
+ 'pairwise_t_matrix': pairwise_t_matrix,
323
+ 'lidar_poses_clean': lidar_poses_clean,
324
+ 'lidar_poses': lidar_poses})
325
+
326
+ if tpe == 'all':
327
+ if self.load_lidar_file:
328
+ merged_feature_dict = merge_features_to_dict(processed_features)
329
+ processed_data_dict['ego'].update({'processed_lidar': merged_feature_dict})
330
+
331
+ if self.load_camera_file:
332
+ merged_image_inputs_dict = merge_features_to_dict(agents_image_inputs, merge='stack')
333
+ processed_data_dict['ego'].update({'image_inputs': merged_image_inputs_dict})
334
+
335
+ if self.visualize:
336
+ processed_data_dict['ego'].update({'origin_lidar':
337
+ # projected_lidar_stack})
338
+ np.vstack(
339
+ projected_lidar_stack)})
340
+ processed_data_dict['ego'].update({'lidar_len': [len(projected_lidar_stack[i]) for i in range(len(projected_lidar_stack))]})
341
+
342
+
343
+ processed_data_dict['ego'].update({'sample_idx': idx,
344
+ 'cav_id_list': cav_id_list})
345
+
346
+ img_front_list = []
347
+ img_left_list = []
348
+ img_right_list = []
349
+ BEV_list = []
350
+
351
+ if self.visualize:
352
+ for car_id in base_data_dict:
353
+ if not base_data_dict[car_id]['ego'] == True:
354
+ continue
355
+ if 'rgb_front' in base_data_dict[car_id] and 'rgb_left' in base_data_dict[car_id] and 'rgb_right' in base_data_dict[car_id] and 'BEV' in base_data_dict[car_id] :
356
+ img_front_list.append(base_data_dict[car_id]['rgb_front'])
357
+ img_left_list.append(base_data_dict[car_id]['rgb_left'])
358
+ img_right_list.append(base_data_dict[car_id]['rgb_right'])
359
+ BEV_list.append(base_data_dict[car_id]['BEV'])
360
+ processed_data_dict['ego'].update({'img_front': img_front_list,
361
+ 'img_left': img_left_list,
362
+ 'img_right': img_right_list,
363
+ 'BEV': BEV_list})
364
+ processed_data_dict['ego'].update({'scene_dict': base_data_dict['car_0']['scene_dict'],
365
+ 'frame_id': base_data_dict['car_0']['frame_id'],
366
+ })
367
+
368
+ return processed_data_dict
369
+
370
+ def get_item_single_car(self, selected_cav_base, ego_cav_base, tpe, online_eval=False):
371
+ """
372
+ Process a single CAV's information for the train/test pipeline.
373
+
374
+
375
+ Parameters
376
+ ----------
377
+ selected_cav_base : dict
378
+ The dictionary containing a single CAV's raw information,
379
+ including 'params' and 'camera_data'.
380
+
381
+ Returns
382
+ -------
383
+ selected_cav_processed : dict
384
+ The dictionary contains the cav's processed information.
385
+ """
386
+ selected_cav_processed = {}
387
+
388
+ if not online_eval:
389
+ # label
390
+ object_bbx_center, object_bbx_mask, object_ids = self.generate_object_center_single(
391
+ [selected_cav_base], selected_cav_base["params"]["lidar_pose_clean"]
392
+ )
393
+
394
+ ego_pose, ego_pose_clean = ego_cav_base['params']['lidar_pose'], ego_cav_base['params']['lidar_pose_clean']
395
+
396
+
397
+ # calculate the transformation matrix
398
+ transformation_matrix = \
399
+ x1_to_x2(selected_cav_base['params']['lidar_pose'],
400
+ ego_pose) # T_ego_cav
401
+ transformation_matrix_clean = \
402
+ x1_to_x2(selected_cav_base['params']['lidar_pose_clean'],
403
+ ego_pose_clean)
404
+
405
+ # lidar
406
+ if tpe == 'all':
407
+ if self.load_lidar_file or self.visualize:
408
+ lidar_np = selected_cav_base['lidar_np']
409
+ lidar_np = shuffle_points(lidar_np)
410
+ lidar_np = mask_points_by_range(lidar_np,
411
+ self.params['preprocess'][
412
+ 'cav_lidar_range'])
413
+ # remove points that hit ego vehicle
414
+ lidar_np = mask_ego_points_v2(lidar_np)
415
+
416
+ # data augmentation, seems very important for single agent training, because lack of data diversity.
417
+ # only work for lidar modality in training.
418
+ if not self.heterogeneous and not online_eval:
419
+ lidar_np, object_bbx_center, object_bbx_mask = \
420
+ self.augment(lidar_np, object_bbx_center, object_bbx_mask)
421
+
422
+ projected_lidar = \
423
+ box_utils.project_points_by_matrix_torch(lidar_np[:, :3], transformation_matrix)
424
+
425
+ if self.proj_first:
426
+ lidar_np[:, :3] = projected_lidar
427
+
428
+ if self.visualize:
429
+ # filter lidar
430
+ selected_cav_processed.update({'projected_lidar': projected_lidar})
431
+
432
+ lidar_dict = self.pre_processor.preprocess(lidar_np)
433
+ selected_cav_processed.update({'processed_lidar': lidar_dict})
434
+
435
+ if self.visualize:
436
+ selected_cav_processed.update({'origin_lidar': lidar_np})
437
+
438
+ if not online_eval:
439
+ object_bbx_center, object_bbx_mask, object_ids = self.generate_object_center(
440
+ [selected_cav_base], selected_cav_base['params']['lidar_pose']
441
+ )
442
+
443
+ gt_object_bbx_center, gt_object_bbx_mask, gt_object_ids = self.generate_object_center(
444
+ [selected_cav_base], selected_cav_base['params']['lidar_pose']
445
+ )
446
+
447
+ label_dict = self.post_processor.generate_label(
448
+ gt_box_center=object_bbx_center, anchors=self.anchor_box, mask=object_bbx_mask
449
+ )
450
+
451
+ gt_label_dict = self.post_processor.generate_label(
452
+ gt_box_center=gt_object_bbx_center, anchors=self.anchor_box, mask=gt_object_bbx_mask
453
+ )
454
+
455
+ selected_cav_processed.update({
456
+ "single_label_dict": label_dict,
457
+ "single_object_bbx_center": object_bbx_center,
458
+ "single_object_bbx_mask": object_bbx_mask})
459
+
460
+ # camera
461
+ if tpe == 'all':
462
+ if self.load_camera_file:
463
+ # adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/data.py
464
+ camera_data_list = selected_cav_base["camera_data"]
465
+
466
+ params = selected_cav_base["params"]
467
+ imgs = []
468
+ rots = []
469
+ trans = []
470
+ intrins = []
471
+ extrinsics = [] # cam_to_lidar
472
+ post_rots = []
473
+ post_trans = []
474
+
475
+ for idx, img in enumerate(camera_data_list):
476
+ camera_to_lidar, camera_intrinsic = self.get_ext_int(params, idx)
477
+
478
+ intrin = torch.from_numpy(camera_intrinsic)
479
+ rot = torch.from_numpy(
480
+ camera_to_lidar[:3, :3]
481
+ ) # R_wc, we consider world-coord is the lidar-coord
482
+ tran = torch.from_numpy(camera_to_lidar[:3, 3]) # T_wc
483
+
484
+ post_rot = torch.eye(2)
485
+ post_tran = torch.zeros(2)
486
+
487
+ img_src = [img]
488
+
489
+ # depth
490
+ if self.load_depth_file:
491
+ depth_img = selected_cav_base["depth_data"][idx]
492
+ img_src.append(depth_img)
493
+ else:
494
+ depth_img = None
495
+
496
+ # data augmentation
497
+ resize, resize_dims, crop, flip, rotate = sample_augmentation(
498
+ self.data_aug_conf, self.train
499
+ )
500
+ img_src, post_rot2, post_tran2 = img_transform(
501
+ img_src,
502
+ post_rot,
503
+ post_tran,
504
+ resize=resize,
505
+ resize_dims=resize_dims,
506
+ crop=crop,
507
+ flip=flip,
508
+ rotate=rotate,
509
+ )
510
+ # for convenience, make augmentation matrices 3x3
511
+ post_tran = torch.zeros(3)
512
+ post_rot = torch.eye(3)
513
+ post_tran[:2] = post_tran2
514
+ post_rot[:2, :2] = post_rot2
515
+
516
+ img_src[0] = normalize_img(img_src[0])
517
+ if self.load_depth_file:
518
+ img_src[1] = img_to_tensor(img_src[1]) * 255
519
+
520
+ imgs.append(torch.cat(img_src, dim=0))
521
+ intrins.append(intrin)
522
+ extrinsics.append(torch.from_numpy(camera_to_lidar))
523
+ rots.append(rot)
524
+ trans.append(tran)
525
+ post_rots.append(post_rot)
526
+ post_trans.append(post_tran)
527
+
528
+ selected_cav_processed.update(
529
+ {
530
+ "image_inputs":
531
+ {
532
+ "imgs": torch.stack(imgs), # [N, 3or4, H, W]
533
+ "intrins": torch.stack(intrins),
534
+ "extrinsics": torch.stack(extrinsics),
535
+ "rots": torch.stack(rots),
536
+ "trans": torch.stack(trans),
537
+ "post_rots": torch.stack(post_rots),
538
+ "post_trans": torch.stack(post_trans),
539
+ }
540
+ }
541
+ )
542
+
543
+ selected_cav_processed.update({"anchor_box": self.anchor_box})
544
+
545
+ if not online_eval:
546
+ object_bbx_center, object_bbx_mask, object_ids = self.generate_object_center([selected_cav_base],
547
+ ego_pose_clean)
548
+
549
+ gt_object_bbx_center, gt_object_bbx_mask, gt_object_ids = self.generate_object_center([selected_cav_base],
550
+ ego_pose_clean)
551
+ selected_cav_processed.update(
552
+ {
553
+ "object_bbx_center": object_bbx_center,
554
+ "object_bbx_mask": object_bbx_mask,
555
+ "object_ids": object_ids,
556
+ }
557
+ )
558
+
559
+ selected_cav_processed.update(
560
+ {
561
+ "gt_object_bbx_center": gt_object_bbx_center[gt_object_bbx_mask == 1],
562
+ "gt_object_bbx_mask": gt_object_bbx_mask,
563
+ "gt_object_ids": gt_object_ids
564
+ }
565
+ )
566
+
567
+ # generate targets label
568
+ label_dict = self.post_processor.generate_label(
569
+ gt_box_center=object_bbx_center, anchors=self.anchor_box, mask=object_bbx_mask
570
+ )
571
+ selected_cav_processed.update({"label_dict": label_dict})
572
+
573
+ selected_cav_processed.update(
574
+ {
575
+ 'transformation_matrix': transformation_matrix,
576
+ 'transformation_matrix_clean': transformation_matrix_clean
577
+ }
578
+ )
579
+
580
+ return selected_cav_processed
581
+
582
+
583
+ def collate_batch_train(self, batch, online_eval_only=False):
584
+ """
585
+ Customized collate function for pytorch dataloader during training
586
+ for early and late fusion dataset.
587
+
588
+ Parameters
589
+ ----------
590
+ batch : dict
591
+
592
+ Returns
593
+ -------
594
+ batch : dict
595
+ Reformatted batch.
596
+ """
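+ # Minimal wiring sketch (argument values assumed, not from this file):
+ #   loader = torch.utils.data.DataLoader(train_set, batch_size=4,
+ #                                        collate_fn=train_set.collate_batch_train)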
597
+ # during training, we only care about ego.
598
+ output_dict = {'ego': {}}
599
+
600
+ object_bbx_center = []
601
+ object_bbx_mask = []
602
+ processed_lidar_list = []
603
+ label_dict_list = []
604
+ origin_lidar = []
605
+
606
+ gt_object_bbx_center = []
607
+ gt_object_bbx_mask = []
608
+ gt_object_ids = []
609
+ gt_label_dict_list = []
610
+ record_len = []
611
+
612
+ object_ids = []
613
+ image_inputs_list = []
614
+ # used to record different scenario
615
+ record_len = []
616
+ label_dict_list = []
617
+ lidar_pose_list = []
618
+ origin_lidar = []
619
+ lidar_len = []
620
+ lidar_pose_clean_list = []
621
+
622
+ # heterogeneous
623
+ lidar_agent_list = []
624
+
625
+ # pairwise transformation matrix
626
+ pairwise_t_matrix_list = []
627
+
628
+ # disconet
629
+ teacher_processed_lidar_list = []
630
+
631
+ # image
632
+ img_front = []
633
+ img_left = []
634
+ img_right = []
635
+ BEV = []
636
+
637
+ dict_list = []
638
+
639
+ if self.supervise_single:
640
+ pos_equal_one_single = []
641
+ neg_equal_one_single = []
642
+ targets_single = []
643
+ object_bbx_center_single = []
644
+ object_bbx_mask_single = []
645
+
646
+ for i in range(len(batch)):
647
+ ego_dict = batch[i]['ego']
648
+
649
+ if not online_eval_only:
650
+ object_bbx_center.append(ego_dict['object_bbx_center'])
651
+ object_bbx_mask.append(ego_dict['object_bbx_mask'])
652
+ object_ids.append(ego_dict['object_ids'])
653
+
654
+ gt_object_bbx_center.append(ego_dict['gt_object_bbx_center'])
655
+ gt_object_bbx_mask.append(ego_dict['gt_object_bbx_mask'])
656
+
657
+ gt_object_ids.append(ego_dict['gt_object_ids'])
658
+
659
+ label_dict_list.append(ego_dict['label_dict'])
660
+
661
+ gt_label_dict_list.append(ego_dict['gt_label_dict'])
662
+
663
+ else:
664
+ object_ids.append(None)
665
+ gt_object_ids.append(None)
666
+
667
+ lidar_pose_list.append(ego_dict['lidar_poses']) # ego_dict['lidar_pose'] is np.ndarray [N,6]
668
+ lidar_pose_clean_list.append(ego_dict['lidar_poses_clean'])
669
+
670
+ if self.load_lidar_file:
671
+ processed_lidar_list.append(ego_dict['processed_lidar'])
672
+ if self.load_camera_file:
673
+ image_inputs_list.append(ego_dict['image_inputs']) # different cav_num, ego_dict['image_inputs'] is dict.
674
+
675
+ record_len.append(ego_dict['cav_num'])
676
+ pairwise_t_matrix_list.append(ego_dict['pairwise_t_matrix'])
677
+
678
+ dict_list.append([ego_dict['scene_dict'], ego_dict['frame_id']])
679
+
680
+ if self.visualize:
681
+ origin_lidar.append(ego_dict['origin_lidar'])
682
+ # lidar_len.append(ego_dict['lidar_len'])
683
+ if len(ego_dict['img_front']) > 0 and len(ego_dict['img_right']) > 0 and len(ego_dict['img_left']) > 0 and len(ego_dict['BEV']) > 0:
684
+ img_front.append(ego_dict['img_front'][0])
685
+ img_left.append(ego_dict['img_left'][0])
686
+ img_right.append(ego_dict['img_right'][0])
687
+ BEV.append(ego_dict['BEV'][0])
688
+
689
+ if self.supervise_single and not online_eval_only:
690
+ # unused label
691
+ if False:
692
+ pos_equal_one_single.append(ego_dict['single_label_dict_torch']['pos_equal_one'])
693
+ neg_equal_one_single.append(ego_dict['single_label_dict_torch']['neg_equal_one'])
694
+ targets_single.append(ego_dict['single_label_dict_torch']['targets'])
695
+ object_bbx_center_single.append(ego_dict['single_object_bbx_center_torch'])
696
+ object_bbx_mask_single.append(ego_dict['single_object_bbx_mask_torch'])
697
+
698
+ # heterogeneous
699
+ if self.heterogeneous:
700
+ lidar_agent_list.append(ego_dict['lidar_agent'])
701
+
702
+ # convert to numpy, (B, max_num, 7)
703
+ if not online_eval_only:
704
+ object_bbx_center = torch.from_numpy(np.array(object_bbx_center))
705
+ object_bbx_mask = torch.from_numpy(np.array(object_bbx_mask))
706
+ gt_object_bbx_center = torch.from_numpy(np.array(gt_object_bbx_center))
707
+ gt_object_bbx_mask = torch.from_numpy(np.array(gt_object_bbx_mask))
708
+ else:
709
+ object_bbx_center = None
710
+ object_bbx_mask = None
711
+ gt_object_bbx_center = None
712
+ gt_object_bbx_mask = None
713
+
714
+
715
+ # unused label
716
+ label_torch_dict = {}
717
+ if False:
718
+ label_torch_dict = \
719
+ self.post_processor.collate_batch(label_dict_list)
720
+
721
+ record_len = torch.from_numpy(np.array(record_len, dtype=int))
723
+ pairwise_t_matrix = torch.from_numpy(np.array(pairwise_t_matrix_list))
724
+ label_torch_dict['record_len'] = record_len
725
+ label_torch_dict['pairwise_t_matrix'] = pairwise_t_matrix
726
+ # for centerpoint
727
+ if not online_eval_only:
728
+ label_torch_dict.update({'object_bbx_center': object_bbx_center,
729
+ 'object_bbx_mask': object_bbx_mask})
730
+ output_dict['ego'].update({'object_bbx_center': object_bbx_center,
731
+ 'object_bbx_mask': object_bbx_mask,})
732
+ output_dict['ego'].update({
733
+ 'anchor_box': torch.from_numpy(self.anchor_box),
734
+ 'label_dict': label_torch_dict,
735
+ 'record_len': record_len,
736
+ 'pairwise_t_matrix': pairwise_t_matrix})
737
+ if self.visualize:
738
+ origin_lidar = \
739
+ np.array(downsample_lidar_minimum(pcd_np_list=origin_lidar))
740
+ origin_lidar = torch.from_numpy(origin_lidar)
741
+ output_dict['ego'].update({'origin_lidar': origin_lidar})
742
+
743
+ if self.load_lidar_file:
744
+ merged_feature_dict = merge_features_to_dict(processed_lidar_list)
745
+ if self.heterogeneous:
746
+ lidar_agent = np.concatenate(lidar_agent_list)
747
+ lidar_agent_idx = lidar_agent.nonzero()[0].tolist()
748
+ for k, v in merged_feature_dict.items(): # 'voxel_features' 'voxel_num_points' 'voxel_coords'
749
+ merged_feature_dict[k] = [v[index] for index in lidar_agent_idx]
750
+
751
+ if not self.heterogeneous or (self.heterogeneous and sum(lidar_agent) != 0):
752
+ processed_lidar_torch_dict = \
753
+ self.pre_processor.collate_batch(merged_feature_dict)
754
+ output_dict['ego'].update({'processed_lidar': processed_lidar_torch_dict})
755
+
756
+ if self.load_camera_file:
757
+ # collate ego camera information
758
+ imgs_batch = []
759
+ rots_batch = []
760
+ trans_batch = []
761
+ intrins_batch = []
762
+ extrinsics_batch = []
763
+ post_trans_batch = []
764
+ post_rots_batch = []
765
+ for i in range(len(batch)):
766
+ ego_dict = batch[i]["ego"]["image_inputs"]
767
+ imgs_batch.append(ego_dict["imgs"])
768
+ rots_batch.append(ego_dict["rots"])
769
+ trans_batch.append(ego_dict["trans"])
770
+ intrins_batch.append(ego_dict["intrins"])
771
+ extrinsics_batch.append(ego_dict["extrinsics"])
772
+ post_trans_batch.append(ego_dict["post_trans"])
773
+ post_rots_batch.append(ego_dict["post_rots"])
774
+
775
+ output_dict["ego"].update({
776
+ "image_inputs":
777
+ {
778
+ "imgs": torch.stack(imgs_batch), # [B, N, C, H, W]
779
+ "rots": torch.stack(rots_batch),
780
+ "trans": torch.stack(trans_batch),
781
+ "intrins": torch.stack(intrins_batch),
782
+ "post_trans": torch.stack(post_trans_batch),
783
+ "post_rots": torch.stack(post_rots_batch),
784
+ }
785
+ }
786
+ )
787
+
788
+ merged_image_inputs_dict = merge_features_to_dict(image_inputs_list, merge='cat')
789
+
790
+ if self.heterogeneous:
791
+ lidar_agent = np.concatenate(lidar_agent_list)
792
+ camera_agent = 1 - lidar_agent
793
+ camera_agent_idx = camera_agent.nonzero()[0].tolist()
794
+ if sum(camera_agent) != 0:
795
+ for k, v in merged_image_inputs_dict.items(): # 'imgs' 'rots' 'trans' ...
796
+ merged_image_inputs_dict[k] = torch.stack([v[index] for index in camera_agent_idx])
797
+
798
+ if not self.heterogeneous or (self.heterogeneous and sum(camera_agent) != 0):
799
+ output_dict['ego'].update({'image_inputs': merged_image_inputs_dict})
800
+
801
+ record_len = torch.from_numpy(np.array(record_len, dtype=int))
802
+ pairwise_t_matrix = torch.from_numpy(np.array(pairwise_t_matrix_list))
803
+ label_torch_dict['record_len'] = record_len
804
+ label_torch_dict['pairwise_t_matrix'] = pairwise_t_matrix
805
+ lidar_pose = torch.from_numpy(np.concatenate(lidar_pose_list, axis=0))
806
+ lidar_pose_clean = torch.from_numpy(np.concatenate(lidar_pose_clean_list, axis=0))
807
+
808
+ if not online_eval_only:
809
+ label_torch_dict = \
810
+ self.post_processor.collate_batch(label_dict_list)
811
+
812
+ gt_label_torch_dict = \
813
+ self.post_processor.collate_batch(gt_label_dict_list)
814
+
815
+ # for centerpoint
816
+ label_torch_dict.update({'object_bbx_center': object_bbx_center,
817
+ 'object_bbx_mask': object_bbx_mask})
818
+
819
+ gt_label_torch_dict.update({'gt_object_bbx_center': gt_object_bbx_center,
820
+ 'gt_object_bbx_mask': gt_object_bbx_mask})
821
+ else:
822
+ gt_label_torch_dict = {}
823
+
824
+ gt_label_torch_dict['pairwise_t_matrix'] = pairwise_t_matrix
825
+ gt_label_torch_dict['record_len'] = record_len
826
+
827
+ # object id is only used during inference, where batch size is 1.
828
+ # so here we only get the first element.
829
+ output_dict['ego'].update({'object_bbx_center': object_bbx_center,
830
+ 'object_bbx_mask': object_bbx_mask,
831
+ 'record_len': record_len,
832
+ 'label_dict': label_torch_dict,
833
+ 'object_ids': object_ids[0],
834
+ 'pairwise_t_matrix': pairwise_t_matrix,
835
+ 'lidar_pose_clean': lidar_pose_clean,
836
+ 'lidar_pose': lidar_pose,
837
+ 'anchor_box': self.anchor_box_torch})
838
+
839
+ output_dict['ego'].update({'gt_object_bbx_center': gt_object_bbx_center,
840
+ 'gt_object_bbx_mask': gt_object_bbx_mask,
841
+ 'gt_label_dict': gt_label_torch_dict,
842
+ 'gt_object_ids': gt_object_ids[0]})
843
+
844
+ output_dict['ego'].update({'dict_list': dict_list})
845
+ output_dict['ego'].update({'record_len': record_len,
846
+ 'pairwise_t_matrix': pairwise_t_matrix
847
+ })
848
+
849
+ if self.visualize:
850
+ origin_lidar = torch.from_numpy(np.array(origin_lidar))
851
+ output_dict['ego'].update({'origin_lidar': origin_lidar})
852
+ output_dict['ego'].update({'img_front': img_front})
853
+ output_dict['ego'].update({'img_right': img_right})
854
+ output_dict['ego'].update({'img_left': img_left})
855
+ output_dict['ego'].update({'BEV': BEV})
856
+
857
+ if self.supervise_single and not online_eval_only:
858
+ output_dict['ego'].update({
859
+ "label_dict_single":{
860
+ # "pos_equal_one": torch.cat(pos_equal_one_single, dim=0),
861
+ # "neg_equal_one": torch.cat(neg_equal_one_single, dim=0),
862
+ # "targets": torch.cat(targets_single, dim=0),
863
+ # for centerpoint
864
+ "object_bbx_center_single": torch.cat(object_bbx_center_single, dim=0),
865
+ "object_bbx_mask_single": torch.cat(object_bbx_mask_single, dim=0)
866
+ },
867
+ "object_bbx_center_single": torch.cat(object_bbx_center_single, dim=0),
868
+ "object_bbx_mask_single": torch.cat(object_bbx_mask_single, dim=0)
869
+ })
870
+
871
+ if self.heterogeneous:
872
+ output_dict['ego'].update({
873
+ "lidar_agent_record": torch.from_numpy(np.concatenate(lidar_agent_list)) # [0,1,1,0,1...]
874
+ })
875
+
876
+
877
+ return output_dict
878
+
879
+ def collate_batch_test(self, batch, online_eval_only=False):
880
+ """
881
+ Customized collate function for pytorch dataloader during testing
882
+ for late fusion dataset.
883
+
884
+ Parameters
885
+ ----------
886
+ batch : dict
887
+
888
+ Returns
889
+ -------
890
+ batch : dict
891
+ Reformatted batch.
892
+ """
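+ # Besides 'ego', one entry per CAV is kept in the output so that late fusion
+ # can decode each agent's detections separately before merging.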
893
+ # currently, we only support batch size of 1 during testing
894
+ assert len(batch) <= 1, "Batch size 1 is required during testing!"
895
+
896
+ self.online_eval_only = online_eval_only
897
+
898
+ output_dict = self.collate_batch_train(batch, online_eval_only)
899
+ if output_dict is None:
900
+ return None
901
+
902
+ batch = batch[0]
903
+
904
+ if batch['ego']['anchor_box'] is not None:
905
+ output_dict['ego'].update({'anchor_box':
906
+ self.anchor_box_torch})
907
+
908
+ record_len = torch.from_numpy(np.array([batch['ego']['cav_num']]))
909
+ pairwise_t_matrix = torch.from_numpy(np.array([batch['ego']['pairwise_t_matrix']]))
910
+
911
+ output_dict['ego'].update({'record_len': record_len,
912
+ 'pairwise_t_matrix': pairwise_t_matrix
913
+ })
914
+
915
+ # heterogeneous
916
+ if self.heterogeneous:
917
+ idx = batch['ego']['idx']
918
+ cav_list = batch['ego']['cav_list'] # ['ego', '650' ..]
919
+ cav_num = len(batch)
920
+ lidar_agent, camera_agent = self.selector.select_agent(idx)
921
+ lidar_agent = lidar_agent[:cav_num] # [1,0,0,1,0]
922
+ lidar_agent_idx = lidar_agent.nonzero()[0].tolist()
923
+ lidar_agent_cav_id = [cav_list[index] for index in lidar_agent_idx] # ['ego', ...]
924
+
925
+
926
+ # for late fusion, we also need to stack the lidar for better
927
+ # visualization
928
+ if self.visualize:
929
+ projected_lidar_list = []
930
+ origin_lidar = []
931
+
932
+ for cav_id, cav_content in batch.items():
933
+ if cav_id != 'ego':
934
+ output_dict.update({cav_id: {}})
935
+ # output_dict.update({cav_id: {}})
936
+
937
+ if not online_eval_only:
938
+ object_bbx_center = \
939
+ torch.from_numpy(np.array([cav_content['object_bbx_center']]))
940
+ object_bbx_mask = \
941
+ torch.from_numpy(np.array([cav_content['object_bbx_mask']]))
942
+ object_ids = cav_content['object_ids']
943
+
944
+ # the anchor box is the same for all bounding boxes usually, thus
945
+ # we don't need the batch dimension.
946
+ output_dict[cav_id].update(
947
+ {"anchor_box": self.anchor_box_torch}
948
+ )
949
+
950
+ transformation_matrix = cav_content['transformation_matrix']
951
+
952
+ if self.visualize:
953
+ origin_lidar = [cav_content['origin_lidar']]
954
+ if (self.params['only_vis_ego'] is False) or (cav_id=='ego'):
955
+ projected_lidar = copy.deepcopy(cav_content['origin_lidar'])
956
+ projected_lidar[:, :3] = \
957
+ box_utils.project_points_by_matrix_torch(
958
+ projected_lidar[:, :3],
959
+ transformation_matrix)
960
+ projected_lidar_list.append(projected_lidar)
961
+
962
+
963
+ if self.load_lidar_file:
964
+ # processed lidar dictionary
965
+ #if 'processed_features' in cav_content.keys():
966
+
967
+ merged_feature_dict = merge_features_to_dict([cav_content['processed_lidar']])
968
+ processed_lidar_torch_dict = \
969
+ self.pre_processor.collate_batch(merged_feature_dict)
970
+ output_dict[cav_id].update({'processed_lidar': processed_lidar_torch_dict})
971
+
972
+ if self.load_camera_file:
973
+ imgs_batch = [cav_content["image_inputs"]["imgs"]]
974
+ rots_batch = [cav_content["image_inputs"]["rots"]]
975
+ trans_batch = [cav_content["image_inputs"]["trans"]]
976
+ intrins_batch = [cav_content["image_inputs"]["intrins"]]
977
+ extrinsics_batch = [cav_content["image_inputs"]["extrinsics"]]
978
+ post_trans_batch = [cav_content["image_inputs"]["post_trans"]]
979
+ post_rots_batch = [cav_content["image_inputs"]["post_rots"]]
980
+
981
+ output_dict[cav_id].update({
982
+ "image_inputs":
983
+ {
984
+ "imgs": torch.stack(imgs_batch),
985
+ "rots": torch.stack(rots_batch),
986
+ "trans": torch.stack(trans_batch),
987
+ "intrins": torch.stack(intrins_batch),
988
+ "extrinsics": torch.stack(extrinsics_batch),
989
+ "post_trans": torch.stack(post_trans_batch),
990
+ "post_rots": torch.stack(post_rots_batch),
991
+ }
992
+ }
993
+ )
994
+
995
+ # heterogeneous
996
+ if self.heterogeneous:
997
+ if cav_id in lidar_agent_cav_id:
998
+ output_dict[cav_id].pop('image_inputs')
999
+ else:
1000
+ output_dict[cav_id].pop('processed_lidar')
1001
+
1002
+ if not online_eval_only:
1003
+ # label dictionary
1004
+ label_torch_dict = \
1005
+ self.post_processor.collate_batch([cav_content['label_dict']])
1006
+
1007
+ # for centerpoint
1008
+ label_torch_dict.update({'object_bbx_center': object_bbx_center,
1009
+ 'object_bbx_mask': object_bbx_mask})
1010
+
1011
+ # save the transformation matrix (4, 4) to ego vehicle
1012
+ transformation_matrix_torch = \
1013
+ torch.from_numpy(
1014
+ np.array(cav_content['transformation_matrix'])).float()
1015
+
1016
+ # late fusion training, no noise
1017
+ transformation_matrix_clean_torch = \
1018
+ torch.from_numpy(
1019
+ np.array(cav_content['transformation_matrix_clean'])).float()
1020
+
1021
+ if not online_eval_only:
1022
+ output_dict[cav_id].update({'object_bbx_center': object_bbx_center,
1023
+ 'object_bbx_mask': object_bbx_mask,
1024
+ 'label_dict': label_torch_dict,
1025
+ # 'record_len': record_len,
1026
+ 'object_ids': object_ids,})
1027
+ output_dict[cav_id].update({
1028
+ 'transformation_matrix': transformation_matrix_torch,
1029
+ 'transformation_matrix_clean': transformation_matrix_clean_torch})
1030
+
1031
+
1032
+ if 'cav_num' in cav_content.keys():
1033
+ record_len = torch.from_numpy(np.array([cav_content['cav_num']]))
1034
+ output_dict[cav_id].update({'record_len': record_len})
1035
+
1036
+ if 'pairwise_t_matrix' in cav_content.keys():
1037
+ pairwise_t_matrix = torch.from_numpy(np.array([cav_content['pairwise_t_matrix']]))
1038
+ output_dict[cav_id].update({'pairwise_t_matrix': pairwise_t_matrix})
1039
+
1040
+
1041
+
1042
+ if self.visualize:
1043
+ origin_lidar = \
1044
+ np.array(
1045
+ downsample_lidar_minimum(pcd_np_list=origin_lidar))
1046
+ origin_lidar = torch.from_numpy(origin_lidar)
1047
+ output_dict[cav_id].update({'origin_lidar': origin_lidar})
1048
+
1049
+ if self.visualize:
1050
+ projected_lidar_stack = [torch.from_numpy(
1051
+ np.vstack(projected_lidar_list))]
1052
+ output_dict['ego'].update({'origin_lidar': projected_lidar_stack})
1053
+
1054
+ output_dict['ego'].update({
1055
+ "sample_idx": batch['ego']['sample_idx'],
1056
+ "cav_id_list": batch['ego']['cav_id_list']
1057
+ })
1058
+ batch_record_len = output_dict['ego']['record_len']
1059
+
1060
+ for cav_id in output_dict.keys():
1061
+ if 'record_len' in output_dict[cav_id].keys():
1062
+ continue
1063
+ output_dict[cav_id].update({'record_len': batch_record_len})
1064
+
1065
+
1066
+ return output_dict
1067
+
1068
+
1069
+ def post_process(self, data_dict, output_dict):
1070
+ """
1071
+ Process the outputs of the model to 2D/3D bounding box.
1072
+
1073
+ Parameters
1074
+ ----------
1075
+ data_dict : dict
1076
+ The dictionary containing the origin input data of model.
1077
+
1078
+ output_dict :dict
1079
+ The dictionary containing the output of the model.
1080
+
1081
+ Returns
1082
+ -------
1083
+ pred_box_tensor : torch.Tensor
1084
+ The tensor of prediction bounding box after NMS.
1085
+ gt_box_tensor : torch.Tensor
1086
+ The tensor of gt bounding box.
1087
+ """
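+ # Single-class path: delegate NMS to the post processor and build GT boxes
+ # from the collated (clean) data; post_process_multiclass below is the
+ # per-class variant used when multi_class is enabled.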
1088
+ pred_box_tensor, pred_score = self.post_processor.post_process(
1089
+ data_dict, output_dict
1090
+ )
1091
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict)
1092
+
1093
+ return pred_box_tensor, pred_score, gt_box_tensor
1094
+
1095
+ def post_process_no_fusion(self, data_dict, output_dict_ego):
1096
+ data_dict_ego = OrderedDict()
1097
+ data_dict_ego["ego"] = data_dict["ego"]
1098
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict)
1099
+
1100
+ pred_box_tensor, pred_score = self.post_processor.post_process(
1101
+ data_dict_ego, output_dict_ego
1102
+ )
1103
+ return pred_box_tensor, pred_score, gt_box_tensor
1104
+
1105
+ def post_process_multiclass(self, data_dict, output_dict, online_eval_only=False):
1106
+ """
1107
+ Process the outputs of the model to 2D/3D bounding box.
1108
+
1109
+ Parameters
1110
+ ----------
1111
+ data_dict : dict
1112
+ The dictionary containing the origin input data of model.
1113
+
1114
+ output_dict :dict
1115
+ The dictionary containing the output of the model.
1116
+
1117
+ Returns
1118
+ -------
1119
+ pred_box_tensor : torch.Tensor
1120
+ The tensor of prediction bounding box after NMS.
1121
+ gt_box_tensor : torch.Tensor
1122
+ The tensor of gt bounding box.
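+ # Per-class decoding: for each class i, cls_preds[:, i:i+1] and the i-th slice
+ # of reg_preds_multiclass are treated as a single-class head and pushed through
+ # the same NMS path as post_process above.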
1123
+ """
1124
+
1125
+ if online_eval_only == False:
1126
+ online_eval_only = self.online_eval_only
1127
+
1128
+ num_class = output_dict['ego']['cls_preds'].shape[1]
1129
+ pred_box_tensor_list = []
1130
+ pred_score_list = []
1131
+ gt_box_tensor_list = []
1132
+
1133
+ num_list = [0,1,3]
1134
+
1135
+ for i in range(num_class):
1136
+ data_dict_single = copy.deepcopy(data_dict)
1137
+ gt_dict_single = {'ego': {}}
1138
+ gt_dict_single['ego'] = copy.deepcopy(data_dict['ego'])
1139
+ output_dict_single = copy.deepcopy(output_dict)
1140
+ if not online_eval_only:
1141
+ data_dict_single['ego']['object_bbx_center'] = data_dict['ego']['object_bbx_center'][:,i,:,:]
1142
+ data_dict_single['ego']['object_bbx_mask'] = data_dict['ego']['object_bbx_mask'][:,i,:]
1143
+ data_dict_single['ego']['object_ids'] = data_dict['ego']['object_ids'][num_list[i]]
1144
+ gt_dict_single['ego']['object_bbx_center'] = data_dict['ego']['gt_object_bbx_center'][:,i,:,:]
1145
+ gt_dict_single['ego']['object_bbx_mask'] = data_dict['ego']['gt_object_bbx_mask'][:,i,:]
1146
+ gt_dict_single['ego']['object_ids'] = data_dict['ego']['gt_object_ids'][num_list[i]]
1147
+
1148
+
1149
+ for cav in output_dict_single.keys():
1150
+ output_dict_single[cav]['cls_preds'] = output_dict[cav]['cls_preds'][:,i:i+1,:,:]
1151
+ output_dict_single[cav]['reg_preds'] = output_dict[cav]['reg_preds_multiclass'][:,i,:,:]
1152
+
1153
+ pred_box_tensor, pred_score = \
1154
+ self.post_processor.post_process(data_dict_single, output_dict_single)
1155
+
1156
+ if not online_eval_only:
1157
+ gt_box_tensor = self.post_processor.generate_gt_bbx(gt_dict_single)
1158
+ else:
1159
+ gt_box_tensor = None
1160
+
1161
+ pred_box_tensor_list.append(pred_box_tensor)
1162
+ pred_score_list.append(pred_score)
1163
+ gt_box_tensor_list.append(gt_box_tensor)
1164
+
1165
+ return pred_box_tensor_list, pred_score_list, gt_box_tensor_list
1166
+
1167
+ def post_process_multiclass_no_fusion(self, data_dict, output_dict_ego, online_eval_only=False):
1168
+ """
1169
+ Process the outputs of the model to 2D/3D bounding box.
1170
+
1171
+ Parameters
1172
+ ----------
1173
+ data_dict : dict
1174
+ The dictionary containing the origin input data of model.
1175
+
1176
+ output_dict :dict
1177
+ The dictionary containing the output of the model.
1178
+
1179
+ Returns
1180
+ -------
1181
+ pred_box_tensor : torch.Tensor
1182
+ The tensor of prediction bounding box after NMS.
1183
+ gt_box_tensor : torch.Tensor
1184
+ The tensor of gt bounding box.
1185
+ """
1186
+
1187
+ online_eval_only = self.online_eval_only
1188
+
1189
+ num_class = data_dict['ego']['object_bbx_center'].shape[1]
1190
+
1191
+
1192
+ pred_box_tensor_list = []
1193
+ pred_score_list = []
1194
+ gt_box_tensor_list = []
1195
+
1196
+ num_list = [0,1,3]
1197
+
1198
+ for i in range(num_class):
1199
+ data_dict_single = copy.deepcopy(data_dict)
1200
+ gt_dict_single = {'ego': {}}
1201
+ gt_dict_single['ego'] = copy.deepcopy(data_dict['ego'])
1202
+ output_dict_single = copy.deepcopy(output_dict_ego)
1203
+ data_dict_single['ego']['object_bbx_center'] = data_dict['ego']['object_bbx_center'][:,i,:,:]
1204
+ data_dict_single['ego']['object_bbx_mask'] = data_dict['ego']['object_bbx_mask'][:,i,:]
1205
+ data_dict_single['ego']['object_ids'] = data_dict['ego']['object_ids'][num_list[i]]
1206
+ gt_dict_single['ego']['object_bbx_center'] = data_dict['ego']['gt_object_bbx_center'][:,i,:,:]
1207
+ gt_dict_single['ego']['object_bbx_mask'] = data_dict['ego']['gt_object_bbx_mask'][:,i,:]
1208
+ gt_dict_single['ego']['object_ids'] = data_dict['ego']['gt_object_ids'][num_list[i]]
1209
+ output_dict_single['ego']['cls_preds'] = output_dict_ego['ego']['cls_preds'][:,i:i+1,:,:]
1210
+ output_dict_single['ego']['reg_preds'] = output_dict_ego['ego']['reg_preds_multiclass'][:,i,:,:]
1211
+ data_dict_single_ego = OrderedDict()
1212
+ data_dict_single_ego["ego"] = data_dict_single["ego"]
1213
+ pred_box_tensor, pred_score = \
1214
+ self.post_processor.post_process(data_dict_single_ego, output_dict_single)
1215
+ gt_box_tensor = self.post_processor.generate_gt_bbx(gt_dict_single)
1216
+
1217
+
1218
+ pred_box_tensor_list.append(pred_box_tensor)
1219
+ pred_score_list.append(pred_score)
1220
+ gt_box_tensor_list.append(gt_box_tensor)
1221
+
1222
+ return pred_box_tensor_list, pred_score_list, gt_box_tensor_list
1223
+
1224
+ def post_process_no_fusion_uncertainty(self, data_dict, output_dict_ego):
1225
+ data_dict_ego = OrderedDict()
1226
+ data_dict_ego['ego'] = data_dict['ego']
1227
+ gt_box_tensor = self.post_processor.generate_gt_bbx(data_dict)
1228
+
1229
+ pred_box_tensor, pred_score, uncertainty = \
1230
+ self.post_processor.post_process(data_dict_ego, output_dict_ego, return_uncertainty=True)
1231
+ return pred_box_tensor, pred_score, gt_box_tensor, uncertainty
1232
+
1233
+ return LatemulticlassFusionDataset
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Author: Runsheng Xu <[email protected]>
3
+ # License: TDG-Attribution-NonCommercial-NoDistrib
4
+
5
+ from opencood.data_utils.post_processor.voxel_postprocessor import VoxelPostprocessor
6
+ from opencood.data_utils.post_processor.bev_postprocessor import BevPostprocessor
7
+ from opencood.data_utils.post_processor.ciassd_postprocessor import CiassdPostprocessor
8
+ from opencood.data_utils.post_processor.fpvrcnn_postprocessor import FpvrcnnPostprocessor
9
+ from opencood.data_utils.post_processor.uncertainty_voxel_postprocessor import UncertaintyVoxelPostprocessor
10
+
11
+ __all__ = {
12
+ 'VoxelPostprocessor': VoxelPostprocessor,
13
+ 'BevPostprocessor': BevPostprocessor,
14
+ 'CiassdPostprocessor': CiassdPostprocessor,
15
+ 'FpvrcnnPostprocessor': FpvrcnnPostprocessor,
16
+ 'UncertaintyVoxelPostprocessor': UncertaintyVoxelPostprocessor,
17
+ }
18
+
19
+
20
+ def build_postprocessor(anchor_cfg, train):
21
+ process_method_name = anchor_cfg['core_method']
22
+ anchor_generator = __all__[process_method_name](
23
+ anchor_params=anchor_cfg,
24
+ train=train
25
+ )
26
+
27
+ return anchor_generator
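+ # Typical call site sketch (config section name assumed from the dataset code):
+ #   post_processor = build_postprocessor(params['postprocess'], train=True)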
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (979 Bytes). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__pycache__/base_postprocessor.cpython-37.pyc ADDED
Binary file (13.8 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__pycache__/bev_postprocessor.cpython-37.pyc ADDED
Binary file (11.6 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__pycache__/ciassd_postprocessor.cpython-37.pyc ADDED
Binary file (4.26 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__pycache__/fpvrcnn_postprocessor.cpython-37.pyc ADDED
Binary file (5.71 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__pycache__/uncertainty_voxel_postprocessor.cpython-37.pyc ADDED
Binary file (5.61 kB). View file
 
v2xverse_late_multiclass_2025_01_28_08_49_56/scripts/data_utils/post_processor/__pycache__/voxel_postprocessor.cpython-37.pyc ADDED
Binary file (10.6 kB). View file