heheyas committed on
Commit
cfb7702
·
1 Parent(s): f5c8d4d
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +50 -0
  2. configs/ae/video.yaml +35 -0
  3. configs/embedder/clip_image.yaml +8 -0
  4. configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml +104 -0
  5. configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml +105 -0
  6. configs/example_training/imagenet-f8_cond.yaml +185 -0
  7. configs/example_training/toy/cifar10_cond.yaml +98 -0
  8. configs/example_training/toy/mnist.yaml +79 -0
  9. configs/example_training/toy/mnist_cond.yaml +98 -0
  10. configs/example_training/toy/mnist_cond_discrete_eps.yaml +103 -0
  11. configs/example_training/toy/mnist_cond_l1_loss.yaml +99 -0
  12. configs/example_training/toy/mnist_cond_with_ema.yaml +100 -0
  13. configs/example_training/txt2img-clipl-legacy-ucg-training.yaml +182 -0
  14. configs/example_training/txt2img-clipl.yaml +184 -0
  15. configs/inference/sd_2_1.yaml +60 -0
  16. configs/inference/sd_2_1_768.yaml +60 -0
  17. configs/inference/sd_xl_base.yaml +93 -0
  18. configs/inference/sd_xl_refiner.yaml +86 -0
  19. configs/inference/svd.yaml +131 -0
  20. configs/inference/svd_image_decoder.yaml +114 -0
  21. configs/inference/svd_mv.yaml +202 -0
  22. mesh_recon/configs/neuralangelo-ortho-wmask.yaml +145 -0
  23. mesh_recon/configs/v3d.yaml +144 -0
  24. mesh_recon/configs/videonvs.yaml +144 -0
  25. mesh_recon/datasets/__init__.py +17 -0
  26. mesh_recon/datasets/blender.py +143 -0
  27. mesh_recon/datasets/colmap.py +332 -0
  28. mesh_recon/datasets/colmap_utils.py +295 -0
  29. mesh_recon/datasets/dtu.py +201 -0
  30. mesh_recon/datasets/fixed_poses/000_back_RT.txt +3 -0
  31. mesh_recon/datasets/fixed_poses/000_back_left_RT.txt +3 -0
  32. mesh_recon/datasets/fixed_poses/000_back_right_RT.txt +3 -0
  33. mesh_recon/datasets/fixed_poses/000_front_RT.txt +3 -0
  34. mesh_recon/datasets/fixed_poses/000_front_left_RT.txt +3 -0
  35. mesh_recon/datasets/fixed_poses/000_front_right_RT.txt +3 -0
  36. mesh_recon/datasets/fixed_poses/000_left_RT.txt +3 -0
  37. mesh_recon/datasets/fixed_poses/000_right_RT.txt +3 -0
  38. mesh_recon/datasets/fixed_poses/000_top_RT.txt +3 -0
  39. mesh_recon/datasets/ortho.py +287 -0
  40. mesh_recon/datasets/utils.py +0 -0
  41. mesh_recon/datasets/v3d.py +284 -0
  42. mesh_recon/datasets/videonvs.py +256 -0
  43. mesh_recon/datasets/videonvs_co3d.py +252 -0
  44. mesh_recon/launch.py +144 -0
  45. mesh_recon/mesh.py +845 -0
  46. mesh_recon/models/__init__.py +16 -0
  47. mesh_recon/models/base.py +32 -0
  48. mesh_recon/models/geometry.py +238 -0
  49. mesh_recon/models/nerf.py +161 -0
  50. mesh_recon/models/network_utils.py +215 -0
.gitignore ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # extensions
2
+ *.egg-info
3
+ *.py[cod]
4
+
5
+ # envs
6
+ .pt13
7
+ .pt2
8
+
9
+ # directories
10
+ /checkpoints
11
+ /dist
12
+ /outputs
13
+ /build
14
+ /src
15
+ logs/
16
+ ckpts/
17
+ tmp/
18
+ lightning_logs/
19
+ images/
20
+ images*/
21
+ kb_configs/
22
+ debug_lvis.log
23
+ *.log
24
+ .cache/
25
+ redirects/
26
+ submits/
27
+ extern/
28
+ assets/images
29
+ output/
30
+ assets/scene
31
+ assets/GSO
32
+ assets/SD
33
+ spirals
34
+ *.zip
35
+ paper/
36
+ spirals_co3d/
37
+ scene_spirals/
38
+ blenders/
39
+ colmap_results/
40
+ depth_spirals/
41
+ recon/SIBR_viewers/
42
+ recon/assets/
43
+ mesh_recon/exp
44
+ mesh_recon/runs
45
+ mesh_recon/renders
46
+ mesh_recon/refined
47
+ *.png
48
+ *.pdf
49
+ *.npz
50
+ *.npy
configs/ae/video.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ target: sgm.models.autoencoder.AutoencodingEngine
2
+ params:
3
+ loss_config:
4
+ target: torch.nn.Identity
5
+ regularizer_config:
6
+ target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
7
+ encoder_config:
8
+ target: sgm.modules.diffusionmodules.model.Encoder
9
+ params:
10
+ attn_type: vanilla
11
+ double_z: True
12
+ z_channels: 4
13
+ resolution: 256
14
+ in_channels: 3
15
+ out_ch: 3
16
+ ch: 128
17
+ ch_mult: [1, 2, 4, 4]
18
+ num_res_blocks: 2
19
+ attn_resolutions: []
20
+ dropout: 0.0
21
+ decoder_config:
22
+ target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
23
+ params:
24
+ attn_type: vanilla
25
+ double_z: True
26
+ z_channels: 4
27
+ resolution: 256
28
+ in_channels: 3
29
+ out_ch: 3
30
+ ch: 128
31
+ ch_mult: [1, 2, 4, 4]
32
+ num_res_blocks: 2
33
+ attn_resolutions: []
34
+ dropout: 0.0
35
+ video_kernel_size: [3, 1, 1]
configs/embedder/clip_image.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
2
+ params:
3
+ n_cond_frames: 1
4
+ n_copies: 1
5
+ open_clip_embedding_config:
6
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
7
+ params:
8
+ freeze: True
configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: sgm.models.autoencoder.AutoencodingEngine
4
+ params:
5
+ input_key: jpg
6
+ monitor: val/rec_loss
7
+
8
+ loss_config:
9
+ target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
10
+ params:
11
+ perceptual_weight: 0.25
12
+ disc_start: 20001
13
+ disc_weight: 0.5
14
+ learn_logvar: True
15
+
16
+ regularization_weights:
17
+ kl_loss: 1.0
18
+
19
+ regularizer_config:
20
+ target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
21
+
22
+ encoder_config:
23
+ target: sgm.modules.diffusionmodules.model.Encoder
24
+ params:
25
+ attn_type: none
26
+ double_z: True
27
+ z_channels: 4
28
+ resolution: 256
29
+ in_channels: 3
30
+ out_ch: 3
31
+ ch: 128
32
+ ch_mult: [1, 2, 4]
33
+ num_res_blocks: 4
34
+ attn_resolutions: []
35
+ dropout: 0.0
36
+
37
+ decoder_config:
38
+ target: sgm.modules.diffusionmodules.model.Decoder
39
+ params: ${model.params.encoder_config.params}
40
+
41
+ data:
42
+ target: sgm.data.dataset.StableDataModuleFromConfig
43
+ params:
44
+ train:
45
+ datapipeline:
46
+ urls:
47
+ - DATA-PATH
48
+ pipeline_config:
49
+ shardshuffle: 10000
50
+ sample_shuffle: 10000
51
+
52
+ decoders:
53
+ - pil
54
+
55
+ postprocessors:
56
+ - target: sdata.mappers.TorchVisionImageTransforms
57
+ params:
58
+ key: jpg
59
+ transforms:
60
+ - target: torchvision.transforms.Resize
61
+ params:
62
+ size: 256
63
+ interpolation: 3
64
+ - target: torchvision.transforms.ToTensor
65
+ - target: sdata.mappers.Rescaler
66
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
67
+ params:
68
+ h_key: height
69
+ w_key: width
70
+
71
+ loader:
72
+ batch_size: 8
73
+ num_workers: 4
74
+
75
+
76
+ lightning:
77
+ strategy:
78
+ target: pytorch_lightning.strategies.DDPStrategy
79
+ params:
80
+ find_unused_parameters: True
81
+
82
+ modelcheckpoint:
83
+ params:
84
+ every_n_train_steps: 5000
85
+
86
+ callbacks:
87
+ metrics_over_trainsteps_checkpoint:
88
+ params:
89
+ every_n_train_steps: 50000
90
+
91
+ image_logger:
92
+ target: main.ImageLogger
93
+ params:
94
+ enable_autocast: False
95
+ batch_frequency: 1000
96
+ max_images: 8
97
+ increase_log_steps: True
98
+
99
+ trainer:
100
+ devices: 0,
101
+ limit_val_batches: 50
102
+ benchmark: True
103
+ accumulate_grad_batches: 1
104
+ val_check_interval: 10000
configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: sgm.models.autoencoder.AutoencodingEngine
4
+ params:
5
+ input_key: jpg
6
+ monitor: val/loss/rec
7
+ disc_start_iter: 0
8
+
9
+ encoder_config:
10
+ target: sgm.modules.diffusionmodules.model.Encoder
11
+ params:
12
+ attn_type: vanilla-xformers
13
+ double_z: true
14
+ z_channels: 8
15
+ resolution: 256
16
+ in_channels: 3
17
+ out_ch: 3
18
+ ch: 128
19
+ ch_mult: [1, 2, 4, 4]
20
+ num_res_blocks: 2
21
+ attn_resolutions: []
22
+ dropout: 0.0
23
+
24
+ decoder_config:
25
+ target: sgm.modules.diffusionmodules.model.Decoder
26
+ params: ${model.params.encoder_config.params}
27
+
28
+ regularizer_config:
29
+ target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
30
+
31
+ loss_config:
32
+ target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
33
+ params:
34
+ perceptual_weight: 0.25
35
+ disc_start: 20001
36
+ disc_weight: 0.5
37
+ learn_logvar: True
38
+
39
+ regularization_weights:
40
+ kl_loss: 1.0
41
+
42
+ data:
43
+ target: sgm.data.dataset.StableDataModuleFromConfig
44
+ params:
45
+ train:
46
+ datapipeline:
47
+ urls:
48
+ - DATA-PATH
49
+ pipeline_config:
50
+ shardshuffle: 10000
51
+ sample_shuffle: 10000
52
+
53
+ decoders:
54
+ - pil
55
+
56
+ postprocessors:
57
+ - target: sdata.mappers.TorchVisionImageTransforms
58
+ params:
59
+ key: jpg
60
+ transforms:
61
+ - target: torchvision.transforms.Resize
62
+ params:
63
+ size: 256
64
+ interpolation: 3
65
+ - target: torchvision.transforms.ToTensor
66
+ - target: sdata.mappers.Rescaler
67
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
68
+ params:
69
+ h_key: height
70
+ w_key: width
71
+
72
+ loader:
73
+ batch_size: 8
74
+ num_workers: 4
75
+
76
+
77
+ lightning:
78
+ strategy:
79
+ target: pytorch_lightning.strategies.DDPStrategy
80
+ params:
81
+ find_unused_parameters: True
82
+
83
+ modelcheckpoint:
84
+ params:
85
+ every_n_train_steps: 5000
86
+
87
+ callbacks:
88
+ metrics_over_trainsteps_checkpoint:
89
+ params:
90
+ every_n_train_steps: 50000
91
+
92
+ image_logger:
93
+ target: main.ImageLogger
94
+ params:
95
+ enable_autocast: False
96
+ batch_frequency: 1000
97
+ max_images: 8
98
+ increase_log_steps: True
99
+
100
+ trainer:
101
+ devices: 0,
102
+ limit_val_batches: 50
103
+ benchmark: True
104
+ accumulate_grad_batches: 1
105
+ val_check_interval: 10000
configs/example_training/imagenet-f8_cond.yaml ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ scale_factor: 0.13025
6
+ disable_first_stage_autocast: True
7
+ log_keys:
8
+ - cls
9
+
10
+ scheduler_config:
11
+ target: sgm.lr_scheduler.LambdaLinearScheduler
12
+ params:
13
+ warm_up_steps: [10000]
14
+ cycle_lengths: [10000000000000]
15
+ f_start: [1.e-6]
16
+ f_max: [1.]
17
+ f_min: [1.]
18
+
19
+ denoiser_config:
20
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
21
+ params:
22
+ num_idx: 1000
23
+
24
+ scaling_config:
25
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
26
+ discretization_config:
27
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
28
+
29
+ network_config:
30
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ use_checkpoint: True
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 256
36
+ attention_resolutions: [1, 2, 4]
37
+ num_res_blocks: 2
38
+ channel_mult: [1, 2, 4]
39
+ num_head_channels: 64
40
+ num_classes: sequential
41
+ adm_in_channels: 1024
42
+ transformer_depth: 1
43
+ context_dim: 1024
44
+ spatial_transformer_attn_type: softmax-xformers
45
+
46
+ conditioner_config:
47
+ target: sgm.modules.GeneralConditioner
48
+ params:
49
+ emb_models:
50
+ - is_trainable: True
51
+ input_key: cls
52
+ ucg_rate: 0.2
53
+ target: sgm.modules.encoders.modules.ClassEmbedder
54
+ params:
55
+ add_sequence_dim: True
56
+ embed_dim: 1024
57
+ n_classes: 1000
58
+
59
+ - is_trainable: False
60
+ ucg_rate: 0.2
61
+ input_key: original_size_as_tuple
62
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
63
+ params:
64
+ outdim: 256
65
+
66
+ - is_trainable: False
67
+ input_key: crop_coords_top_left
68
+ ucg_rate: 0.2
69
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
70
+ params:
71
+ outdim: 256
72
+
73
+ first_stage_config:
74
+ target: sgm.models.autoencoder.AutoencoderKL
75
+ params:
76
+ ckpt_path: CKPT_PATH
77
+ embed_dim: 4
78
+ monitor: val/rec_loss
79
+ ddconfig:
80
+ attn_type: vanilla-xformers
81
+ double_z: true
82
+ z_channels: 4
83
+ resolution: 256
84
+ in_channels: 3
85
+ out_ch: 3
86
+ ch: 128
87
+ ch_mult: [1, 2, 4, 4]
88
+ num_res_blocks: 2
89
+ attn_resolutions: []
90
+ dropout: 0.0
91
+ lossconfig:
92
+ target: torch.nn.Identity
93
+
94
+ loss_fn_config:
95
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
96
+ params:
97
+ loss_weighting_config:
98
+ target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
99
+ sigma_sampler_config:
100
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
101
+ params:
102
+ num_idx: 1000
103
+
104
+ discretization_config:
105
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
106
+
107
+ sampler_config:
108
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
109
+ params:
110
+ num_steps: 50
111
+
112
+ discretization_config:
113
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
114
+
115
+ guider_config:
116
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
117
+ params:
118
+ scale: 5.0
119
+
120
+ data:
121
+ target: sgm.data.dataset.StableDataModuleFromConfig
122
+ params:
123
+ train:
124
+ datapipeline:
125
+ urls:
126
+ # USER: adapt this path to the root of your custom dataset
127
+ - DATA_PATH
128
+ pipeline_config:
129
+ shardshuffle: 10000
130
+ sample_shuffle: 10000 # USER: you might wanna adapt this depending on your available RAM
131
+
132
+ decoders:
133
+ - pil
134
+
135
+ postprocessors:
136
+ - target: sdata.mappers.TorchVisionImageTransforms
137
+ params:
138
+ key: jpg # USER: you might wanna adapt this for your custom dataset
139
+ transforms:
140
+ - target: torchvision.transforms.Resize
141
+ params:
142
+ size: 256
143
+ interpolation: 3
144
+ - target: torchvision.transforms.ToTensor
145
+ - target: sdata.mappers.Rescaler
146
+
147
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
148
+ params:
149
+ h_key: height # USER: you might wanna adapt this for your custom dataset
150
+ w_key: width # USER: you might wanna adapt this for your custom dataset
151
+
152
+ loader:
153
+ batch_size: 64
154
+ num_workers: 6
155
+
156
+ lightning:
157
+ modelcheckpoint:
158
+ params:
159
+ every_n_train_steps: 5000
160
+
161
+ callbacks:
162
+ metrics_over_trainsteps_checkpoint:
163
+ params:
164
+ every_n_train_steps: 25000
165
+
166
+ image_logger:
167
+ target: main.ImageLogger
168
+ params:
169
+ disabled: False
170
+ enable_autocast: False
171
+ batch_frequency: 1000
172
+ max_images: 8
173
+ increase_log_steps: True
174
+ log_first_step: False
175
+ log_images_kwargs:
176
+ use_ema_scope: False
177
+ N: 8
178
+ n_rows: 2
179
+
180
+ trainer:
181
+ devices: 0,
182
+ benchmark: True
183
+ num_sanity_val_steps: 0
184
+ accumulate_grad_batches: 1
185
+ max_epochs: 1000
configs/example_training/toy/cifar10_cond.yaml ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ denoiser_config:
6
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
7
+ params:
8
+ scaling_config:
9
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10
+ params:
11
+ sigma_data: 1.0
12
+
13
+ network_config:
14
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15
+ params:
16
+ in_channels: 3
17
+ out_channels: 3
18
+ model_channels: 32
19
+ attention_resolutions: []
20
+ num_res_blocks: 4
21
+ channel_mult: [1, 2, 2]
22
+ num_head_channels: 32
23
+ num_classes: sequential
24
+ adm_in_channels: 128
25
+
26
+ conditioner_config:
27
+ target: sgm.modules.GeneralConditioner
28
+ params:
29
+ emb_models:
30
+ - is_trainable: True
31
+ input_key: cls
32
+ ucg_rate: 0.2
33
+ target: sgm.modules.encoders.modules.ClassEmbedder
34
+ params:
35
+ embed_dim: 128
36
+ n_classes: 10
37
+
38
+ first_stage_config:
39
+ target: sgm.models.autoencoder.IdentityFirstStage
40
+
41
+ loss_fn_config:
42
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
43
+ params:
44
+ loss_weighting_config:
45
+ target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
46
+ params:
47
+ sigma_data: 1.0
48
+ sigma_sampler_config:
49
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
50
+
51
+ sampler_config:
52
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
53
+ params:
54
+ num_steps: 50
55
+
56
+ discretization_config:
57
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
58
+
59
+ guider_config:
60
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
61
+ params:
62
+ scale: 3.0
63
+
64
+ data:
65
+ target: sgm.data.cifar10.CIFAR10Loader
66
+ params:
67
+ batch_size: 512
68
+ num_workers: 1
69
+
70
+ lightning:
71
+ modelcheckpoint:
72
+ params:
73
+ every_n_train_steps: 5000
74
+
75
+ callbacks:
76
+ metrics_over_trainsteps_checkpoint:
77
+ params:
78
+ every_n_train_steps: 25000
79
+
80
+ image_logger:
81
+ target: main.ImageLogger
82
+ params:
83
+ disabled: False
84
+ batch_frequency: 1000
85
+ max_images: 64
86
+ increase_log_steps: True
87
+ log_first_step: False
88
+ log_images_kwargs:
89
+ use_ema_scope: False
90
+ N: 64
91
+ n_rows: 8
92
+
93
+ trainer:
94
+ devices: 0,
95
+ benchmark: True
96
+ num_sanity_val_steps: 0
97
+ accumulate_grad_batches: 1
98
+ max_epochs: 20
configs/example_training/toy/mnist.yaml ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ denoiser_config:
6
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
7
+ params:
8
+ scaling_config:
9
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10
+ params:
11
+ sigma_data: 1.0
12
+
13
+ network_config:
14
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15
+ params:
16
+ in_channels: 1
17
+ out_channels: 1
18
+ model_channels: 32
19
+ attention_resolutions: []
20
+ num_res_blocks: 4
21
+ channel_mult: [1, 2, 2]
22
+ num_head_channels: 32
23
+
24
+ first_stage_config:
25
+ target: sgm.models.autoencoder.IdentityFirstStage
26
+
27
+ loss_fn_config:
28
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
29
+ params:
30
+ loss_weighting_config:
31
+ target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
32
+ params:
33
+ sigma_data: 1.0
34
+ sigma_sampler_config:
35
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
36
+
37
+ sampler_config:
38
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
39
+ params:
40
+ num_steps: 50
41
+
42
+ discretization_config:
43
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
44
+
45
+ data:
46
+ target: sgm.data.mnist.MNISTLoader
47
+ params:
48
+ batch_size: 512
49
+ num_workers: 1
50
+
51
+ lightning:
52
+ modelcheckpoint:
53
+ params:
54
+ every_n_train_steps: 5000
55
+
56
+ callbacks:
57
+ metrics_over_trainsteps_checkpoint:
58
+ params:
59
+ every_n_train_steps: 25000
60
+
61
+ image_logger:
62
+ target: main.ImageLogger
63
+ params:
64
+ disabled: False
65
+ batch_frequency: 1000
66
+ max_images: 64
67
+ increase_log_steps: False
68
+ log_first_step: False
69
+ log_images_kwargs:
70
+ use_ema_scope: False
71
+ N: 64
72
+ n_rows: 8
73
+
74
+ trainer:
75
+ devices: 0,
76
+ benchmark: True
77
+ num_sanity_val_steps: 0
78
+ accumulate_grad_batches: 1
79
+ max_epochs: 10
configs/example_training/toy/mnist_cond.yaml ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ denoiser_config:
6
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
7
+ params:
8
+ scaling_config:
9
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10
+ params:
11
+ sigma_data: 1.0
12
+
13
+ network_config:
14
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15
+ params:
16
+ in_channels: 1
17
+ out_channels: 1
18
+ model_channels: 32
19
+ attention_resolutions: []
20
+ num_res_blocks: 4
21
+ channel_mult: [1, 2, 2]
22
+ num_head_channels: 32
23
+ num_classes: sequential
24
+ adm_in_channels: 128
25
+
26
+ conditioner_config:
27
+ target: sgm.modules.GeneralConditioner
28
+ params:
29
+ emb_models:
30
+ - is_trainable: True
31
+ input_key: cls
32
+ ucg_rate: 0.2
33
+ target: sgm.modules.encoders.modules.ClassEmbedder
34
+ params:
35
+ embed_dim: 128
36
+ n_classes: 10
37
+
38
+ first_stage_config:
39
+ target: sgm.models.autoencoder.IdentityFirstStage
40
+
41
+ loss_fn_config:
42
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
43
+ params:
44
+ loss_weighting_config:
45
+ target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
46
+ params:
47
+ sigma_data: 1.0
48
+ sigma_sampler_config:
49
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
50
+
51
+ sampler_config:
52
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
53
+ params:
54
+ num_steps: 50
55
+
56
+ discretization_config:
57
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
58
+
59
+ guider_config:
60
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
61
+ params:
62
+ scale: 3.0
63
+
64
+ data:
65
+ target: sgm.data.mnist.MNISTLoader
66
+ params:
67
+ batch_size: 512
68
+ num_workers: 1
69
+
70
+ lightning:
71
+ modelcheckpoint:
72
+ params:
73
+ every_n_train_steps: 5000
74
+
75
+ callbacks:
76
+ metrics_over_trainsteps_checkpoint:
77
+ params:
78
+ every_n_train_steps: 25000
79
+
80
+ image_logger:
81
+ target: main.ImageLogger
82
+ params:
83
+ disabled: False
84
+ batch_frequency: 1000
85
+ max_images: 16
86
+ increase_log_steps: True
87
+ log_first_step: False
88
+ log_images_kwargs:
89
+ use_ema_scope: False
90
+ N: 16
91
+ n_rows: 4
92
+
93
+ trainer:
94
+ devices: 0,
95
+ benchmark: True
96
+ num_sanity_val_steps: 0
97
+ accumulate_grad_batches: 1
98
+ max_epochs: 20
configs/example_training/toy/mnist_cond_discrete_eps.yaml ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ denoiser_config:
6
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
7
+ params:
8
+ num_idx: 1000
9
+
10
+ scaling_config:
11
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
12
+ discretization_config:
13
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
14
+
15
+ network_config:
16
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
17
+ params:
18
+ in_channels: 1
19
+ out_channels: 1
20
+ model_channels: 32
21
+ attention_resolutions: []
22
+ num_res_blocks: 4
23
+ channel_mult: [1, 2, 2]
24
+ num_head_channels: 32
25
+ num_classes: sequential
26
+ adm_in_channels: 128
27
+
28
+ conditioner_config:
29
+ target: sgm.modules.GeneralConditioner
30
+ params:
31
+ emb_models:
32
+ - is_trainable: True
33
+ input_key: cls
34
+ ucg_rate: 0.2
35
+ target: sgm.modules.encoders.modules.ClassEmbedder
36
+ params:
37
+ embed_dim: 128
38
+ n_classes: 10
39
+
40
+ first_stage_config:
41
+ target: sgm.models.autoencoder.IdentityFirstStage
42
+
43
+ loss_fn_config:
44
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
45
+ params:
46
+ loss_weighting_config:
47
+ target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
48
+ sigma_sampler_config:
49
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
50
+ params:
51
+ num_idx: 1000
52
+
53
+ discretization_config:
54
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
55
+
56
+ sampler_config:
57
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
58
+ params:
59
+ num_steps: 50
60
+
61
+ discretization_config:
62
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
63
+
64
+ guider_config:
65
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
66
+ params:
67
+ scale: 5.0
68
+
69
+ data:
70
+ target: sgm.data.mnist.MNISTLoader
71
+ params:
72
+ batch_size: 512
73
+ num_workers: 1
74
+
75
+ lightning:
76
+ modelcheckpoint:
77
+ params:
78
+ every_n_train_steps: 5000
79
+
80
+ callbacks:
81
+ metrics_over_trainsteps_checkpoint:
82
+ params:
83
+ every_n_train_steps: 25000
84
+
85
+ image_logger:
86
+ target: main.ImageLogger
87
+ params:
88
+ disabled: False
89
+ batch_frequency: 1000
90
+ max_images: 16
91
+ increase_log_steps: True
92
+ log_first_step: False
93
+ log_images_kwargs:
94
+ use_ema_scope: False
95
+ N: 16
96
+ n_rows: 4
97
+
98
+ trainer:
99
+ devices: 0,
100
+ benchmark: True
101
+ num_sanity_val_steps: 0
102
+ accumulate_grad_batches: 1
103
+ max_epochs: 20
configs/example_training/toy/mnist_cond_l1_loss.yaml ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ denoiser_config:
6
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
7
+ params:
8
+ scaling_config:
9
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10
+ params:
11
+ sigma_data: 1.0
12
+
13
+ network_config:
14
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15
+ params:
16
+ in_channels: 1
17
+ out_channels: 1
18
+ model_channels: 32
19
+ attention_resolutions: []
20
+ num_res_blocks: 4
21
+ channel_mult: [1, 2, 2]
22
+ num_head_channels: 32
23
+ num_classes: sequential
24
+ adm_in_channels: 128
25
+
26
+ conditioner_config:
27
+ target: sgm.modules.GeneralConditioner
28
+ params:
29
+ emb_models:
30
+ - is_trainable: True
31
+ input_key: cls
32
+ ucg_rate: 0.2
33
+ target: sgm.modules.encoders.modules.ClassEmbedder
34
+ params:
35
+ embed_dim: 128
36
+ n_classes: 10
37
+
38
+ first_stage_config:
39
+ target: sgm.models.autoencoder.IdentityFirstStage
40
+
41
+ loss_fn_config:
42
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
43
+ params:
44
+ loss_type: l1
45
+ loss_weighting_config:
46
+ target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
47
+ params:
48
+ sigma_data: 1.0
49
+ sigma_sampler_config:
50
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
51
+
52
+ sampler_config:
53
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
54
+ params:
55
+ num_steps: 50
56
+
57
+ discretization_config:
58
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
59
+
60
+ guider_config:
61
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
62
+ params:
63
+ scale: 3.0
64
+
65
+ data:
66
+ target: sgm.data.mnist.MNISTLoader
67
+ params:
68
+ batch_size: 512
69
+ num_workers: 1
70
+
71
+ lightning:
72
+ modelcheckpoint:
73
+ params:
74
+ every_n_train_steps: 5000
75
+
76
+ callbacks:
77
+ metrics_over_trainsteps_checkpoint:
78
+ params:
79
+ every_n_train_steps: 25000
80
+
81
+ image_logger:
82
+ target: main.ImageLogger
83
+ params:
84
+ disabled: False
85
+ batch_frequency: 1000
86
+ max_images: 64
87
+ increase_log_steps: True
88
+ log_first_step: False
89
+ log_images_kwargs:
90
+ use_ema_scope: False
91
+ N: 64
92
+ n_rows: 8
93
+
94
+ trainer:
95
+ devices: 0,
96
+ benchmark: True
97
+ num_sanity_val_steps: 0
98
+ accumulate_grad_batches: 1
99
+ max_epochs: 20
configs/example_training/toy/mnist_cond_with_ema.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ use_ema: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
9
+ params:
10
+ scaling_config:
11
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
12
+ params:
13
+ sigma_data: 1.0
14
+
15
+ network_config:
16
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
17
+ params:
18
+ in_channels: 1
19
+ out_channels: 1
20
+ model_channels: 32
21
+ attention_resolutions: []
22
+ num_res_blocks: 4
23
+ channel_mult: [1, 2, 2]
24
+ num_head_channels: 32
25
+ num_classes: sequential
26
+ adm_in_channels: 128
27
+
28
+ conditioner_config:
29
+ target: sgm.modules.GeneralConditioner
30
+ params:
31
+ emb_models:
32
+ - is_trainable: True
33
+ input_key: cls
34
+ ucg_rate: 0.2
35
+ target: sgm.modules.encoders.modules.ClassEmbedder
36
+ params:
37
+ embed_dim: 128
38
+ n_classes: 10
39
+
40
+ first_stage_config:
41
+ target: sgm.models.autoencoder.IdentityFirstStage
42
+
43
+ loss_fn_config:
44
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
45
+ params:
46
+ loss_weighting_config:
47
+ target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
48
+ params:
49
+ sigma_data: 1.0
50
+ sigma_sampler_config:
51
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
52
+
53
+ sampler_config:
54
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
55
+ params:
56
+ num_steps: 50
57
+
58
+ discretization_config:
59
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
60
+
61
+ guider_config:
62
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
63
+ params:
64
+ scale: 3.0
65
+
66
+ data:
67
+ target: sgm.data.mnist.MNISTLoader
68
+ params:
69
+ batch_size: 512
70
+ num_workers: 1
71
+
72
+ lightning:
73
+ modelcheckpoint:
74
+ params:
75
+ every_n_train_steps: 5000
76
+
77
+ callbacks:
78
+ metrics_over_trainsteps_checkpoint:
79
+ params:
80
+ every_n_train_steps: 25000
81
+
82
+ image_logger:
83
+ target: main.ImageLogger
84
+ params:
85
+ disabled: False
86
+ batch_frequency: 1000
87
+ max_images: 64
88
+ increase_log_steps: True
89
+ log_first_step: False
90
+ log_images_kwargs:
91
+ use_ema_scope: False
92
+ N: 64
93
+ n_rows: 8
94
+
95
+ trainer:
96
+ devices: 0,
97
+ benchmark: True
98
+ num_sanity_val_steps: 0
99
+ accumulate_grad_batches: 1
100
+ max_epochs: 20
configs/example_training/txt2img-clipl-legacy-ucg-training.yaml ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ scale_factor: 0.13025
6
+ disable_first_stage_autocast: True
7
+ log_keys:
8
+ - txt
9
+
10
+ scheduler_config:
11
+ target: sgm.lr_scheduler.LambdaLinearScheduler
12
+ params:
13
+ warm_up_steps: [10000]
14
+ cycle_lengths: [10000000000000]
15
+ f_start: [1.e-6]
16
+ f_max: [1.]
17
+ f_min: [1.]
18
+
19
+ denoiser_config:
20
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
21
+ params:
22
+ num_idx: 1000
23
+
24
+ scaling_config:
25
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
26
+ discretization_config:
27
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
28
+
29
+ network_config:
30
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ use_checkpoint: True
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [1, 2, 4]
37
+ num_res_blocks: 2
38
+ channel_mult: [1, 2, 4, 4]
39
+ num_head_channels: 64
40
+ num_classes: sequential
41
+ adm_in_channels: 1792
42
+ num_heads: 1
43
+ transformer_depth: 1
44
+ context_dim: 768
45
+ spatial_transformer_attn_type: softmax-xformers
46
+
47
+ conditioner_config:
48
+ target: sgm.modules.GeneralConditioner
49
+ params:
50
+ emb_models:
51
+ - is_trainable: True
52
+ input_key: txt
53
+ ucg_rate: 0.1
54
+ legacy_ucg_value: ""
55
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
56
+ params:
57
+ always_return_pooled: True
58
+
59
+ - is_trainable: False
60
+ ucg_rate: 0.1
61
+ input_key: original_size_as_tuple
62
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
63
+ params:
64
+ outdim: 256
65
+
66
+ - is_trainable: False
67
+ input_key: crop_coords_top_left
68
+ ucg_rate: 0.1
69
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
70
+ params:
71
+ outdim: 256
72
+
73
+ first_stage_config:
74
+ target: sgm.models.autoencoder.AutoencoderKL
75
+ params:
76
+ ckpt_path: CKPT_PATH
77
+ embed_dim: 4
78
+ monitor: val/rec_loss
79
+ ddconfig:
80
+ attn_type: vanilla-xformers
81
+ double_z: true
82
+ z_channels: 4
83
+ resolution: 256
84
+ in_channels: 3
85
+ out_ch: 3
86
+ ch: 128
87
+ ch_mult: [ 1, 2, 4, 4 ]
88
+ num_res_blocks: 2
89
+ attn_resolutions: [ ]
90
+ dropout: 0.0
91
+ lossconfig:
92
+ target: torch.nn.Identity
93
+
94
+ loss_fn_config:
95
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
96
+ params:
97
+ loss_weighting_config:
98
+ target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
99
+ sigma_sampler_config:
100
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
101
+ params:
102
+ num_idx: 1000
103
+
104
+ discretization_config:
105
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
106
+
107
+ sampler_config:
108
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
109
+ params:
110
+ num_steps: 50
111
+
112
+ discretization_config:
113
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
114
+
115
+ guider_config:
116
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
117
+ params:
118
+ scale: 7.5
119
+
120
+ data:
121
+ target: sgm.data.dataset.StableDataModuleFromConfig
122
+ params:
123
+ train:
124
+ datapipeline:
125
+ urls:
126
+ # USER: adapt this path to the root of your custom dataset
127
+ - DATA_PATH
128
+ pipeline_config:
129
+ shardshuffle: 10000
130
+ sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
131
+
132
+ decoders:
133
+ - pil
134
+
135
+ postprocessors:
136
+ - target: sdata.mappers.TorchVisionImageTransforms
137
+ params:
138
+ key: jpg # USER: you might wanna adapt this for your custom dataset
139
+ transforms:
140
+ - target: torchvision.transforms.Resize
141
+ params:
142
+ size: 256
143
+ interpolation: 3
144
+ - target: torchvision.transforms.ToTensor
145
+ - target: sdata.mappers.Rescaler
146
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
147
+ # USER: you might wanna use non-default parameters due to your custom dataset
148
+
149
+ loader:
150
+ batch_size: 64
151
+ num_workers: 6
152
+
153
+ lightning:
154
+ modelcheckpoint:
155
+ params:
156
+ every_n_train_steps: 5000
157
+
158
+ callbacks:
159
+ metrics_over_trainsteps_checkpoint:
160
+ params:
161
+ every_n_train_steps: 25000
162
+
163
+ image_logger:
164
+ target: main.ImageLogger
165
+ params:
166
+ disabled: False
167
+ enable_autocast: False
168
+ batch_frequency: 1000
169
+ max_images: 8
170
+ increase_log_steps: True
171
+ log_first_step: False
172
+ log_images_kwargs:
173
+ use_ema_scope: False
174
+ N: 8
175
+ n_rows: 2
176
+
177
+ trainer:
178
+ devices: 0,
179
+ benchmark: True
180
+ num_sanity_val_steps: 0
181
+ accumulate_grad_batches: 1
182
+ max_epochs: 1000
configs/example_training/txt2img-clipl.yaml ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ scale_factor: 0.13025
6
+ disable_first_stage_autocast: True
7
+ log_keys:
8
+ - txt
9
+
10
+ scheduler_config:
11
+ target: sgm.lr_scheduler.LambdaLinearScheduler
12
+ params:
13
+ warm_up_steps: [10000]
14
+ cycle_lengths: [10000000000000]
15
+ f_start: [1.e-6]
16
+ f_max: [1.]
17
+ f_min: [1.]
18
+
19
+ denoiser_config:
20
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
21
+ params:
22
+ num_idx: 1000
23
+
24
+ scaling_config:
25
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
26
+ discretization_config:
27
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
28
+
29
+ network_config:
30
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ use_checkpoint: True
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [1, 2, 4]
37
+ num_res_blocks: 2
38
+ channel_mult: [1, 2, 4, 4]
39
+ num_head_channels: 64
40
+ num_classes: sequential
41
+ adm_in_channels: 1792
42
+ num_heads: 1
43
+ transformer_depth: 1
44
+ context_dim: 768
45
+ spatial_transformer_attn_type: softmax-xformers
46
+
47
+ conditioner_config:
48
+ target: sgm.modules.GeneralConditioner
49
+ params:
50
+ emb_models:
51
+ - is_trainable: True
52
+ input_key: txt
53
+ ucg_rate: 0.1
54
+ legacy_ucg_value: ""
55
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
56
+ params:
57
+ always_return_pooled: True
58
+
59
+ - is_trainable: False
60
+ ucg_rate: 0.1
61
+ input_key: original_size_as_tuple
62
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
63
+ params:
64
+ outdim: 256
65
+
66
+ - is_trainable: False
67
+ input_key: crop_coords_top_left
68
+ ucg_rate: 0.1
69
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
70
+ params:
71
+ outdim: 256
72
+
73
+ first_stage_config:
74
+ target: sgm.models.autoencoder.AutoencoderKL
75
+ params:
76
+ ckpt_path: CKPT_PATH
77
+ embed_dim: 4
78
+ monitor: val/rec_loss
79
+ ddconfig:
80
+ attn_type: vanilla-xformers
81
+ double_z: true
82
+ z_channels: 4
83
+ resolution: 256
84
+ in_channels: 3
85
+ out_ch: 3
86
+ ch: 128
87
+ ch_mult: [1, 2, 4, 4]
88
+ num_res_blocks: 2
89
+ attn_resolutions: []
90
+ dropout: 0.0
91
+ lossconfig:
92
+ target: torch.nn.Identity
93
+
94
+ loss_fn_config:
95
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
96
+ params:
97
+ loss_weighting_config:
98
+ target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
99
+ sigma_sampler_config:
100
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
101
+ params:
102
+ num_idx: 1000
103
+
104
+ discretization_config:
105
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
106
+
107
+ sampler_config:
108
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
109
+ params:
110
+ num_steps: 50
111
+
112
+ discretization_config:
113
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
114
+
115
+ guider_config:
116
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
117
+ params:
118
+ scale: 7.5
119
+
120
+ data:
121
+ target: sgm.data.dataset.StableDataModuleFromConfig
122
+ params:
123
+ train:
124
+ datapipeline:
125
+ urls:
126
+ # USER: adapt this path to the root of your custom dataset
127
+ - DATA_PATH
128
+ pipeline_config:
129
+ shardshuffle: 10000
130
+ sample_shuffle: 10000
131
+
132
+
133
+ decoders:
134
+ - pil
135
+
136
+ postprocessors:
137
+ - target: sdata.mappers.TorchVisionImageTransforms
138
+ params:
139
+ key: jpg # USER: you might wanna adapt this for your custom dataset
140
+ transforms:
141
+ - target: torchvision.transforms.Resize
142
+ params:
143
+ size: 256
144
+ interpolation: 3
145
+ - target: torchvision.transforms.ToTensor
146
+ - target: sdata.mappers.Rescaler
147
+ # USER: you might wanna use non-default parameters due to your custom dataset
148
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
149
+ # USER: you might wanna use non-default parameters due to your custom dataset
150
+
151
+ loader:
152
+ batch_size: 64
153
+ num_workers: 6
154
+
155
+ lightning:
156
+ modelcheckpoint:
157
+ params:
158
+ every_n_train_steps: 5000
159
+
160
+ callbacks:
161
+ metrics_over_trainsteps_checkpoint:
162
+ params:
163
+ every_n_train_steps: 25000
164
+
165
+ image_logger:
166
+ target: main.ImageLogger
167
+ params:
168
+ disabled: False
169
+ enable_autocast: False
170
+ batch_frequency: 1000
171
+ max_images: 8
172
+ increase_log_steps: True
173
+ log_first_step: False
174
+ log_images_kwargs:
175
+ use_ema_scope: False
176
+ N: 8
177
+ n_rows: 2
178
+
179
+ trainer:
180
+ devices: 0,
181
+ benchmark: True
182
+ num_sanity_val_steps: 0
183
+ accumulate_grad_batches: 1
184
+ max_epochs: 1000
configs/inference/sd_2_1.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.18215
5
+ disable_first_stage_autocast: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9
+ params:
10
+ num_idx: 1000
11
+
12
+ scaling_config:
13
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
14
+ discretization_config:
15
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
16
+
17
+ network_config:
18
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
19
+ params:
20
+ use_checkpoint: True
21
+ in_channels: 4
22
+ out_channels: 4
23
+ model_channels: 320
24
+ attention_resolutions: [4, 2, 1]
25
+ num_res_blocks: 2
26
+ channel_mult: [1, 2, 4, 4]
27
+ num_head_channels: 64
28
+ use_linear_in_transformer: True
29
+ transformer_depth: 1
30
+ context_dim: 1024
31
+
32
+ conditioner_config:
33
+ target: sgm.modules.GeneralConditioner
34
+ params:
35
+ emb_models:
36
+ - is_trainable: False
37
+ input_key: txt
38
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
39
+ params:
40
+ freeze: true
41
+ layer: penultimate
42
+
43
+ first_stage_config:
44
+ target: sgm.models.autoencoder.AutoencoderKL
45
+ params:
46
+ embed_dim: 4
47
+ monitor: val/rec_loss
48
+ ddconfig:
49
+ double_z: true
50
+ z_channels: 4
51
+ resolution: 256
52
+ in_channels: 3
53
+ out_ch: 3
54
+ ch: 128
55
+ ch_mult: [1, 2, 4, 4]
56
+ num_res_blocks: 2
57
+ attn_resolutions: []
58
+ dropout: 0.0
59
+ lossconfig:
60
+ target: torch.nn.Identity
configs/inference/sd_2_1_768.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.18215
5
+ disable_first_stage_autocast: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9
+ params:
10
+ num_idx: 1000
11
+
12
+ scaling_config:
13
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
14
+ discretization_config:
15
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
16
+
17
+ network_config:
18
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
19
+ params:
20
+ use_checkpoint: True
21
+ in_channels: 4
22
+ out_channels: 4
23
+ model_channels: 320
24
+ attention_resolutions: [4, 2, 1]
25
+ num_res_blocks: 2
26
+ channel_mult: [1, 2, 4, 4]
27
+ num_head_channels: 64
28
+ use_linear_in_transformer: True
29
+ transformer_depth: 1
30
+ context_dim: 1024
31
+
32
+ conditioner_config:
33
+ target: sgm.modules.GeneralConditioner
34
+ params:
35
+ emb_models:
36
+ - is_trainable: False
37
+ input_key: txt
38
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
39
+ params:
40
+ freeze: true
41
+ layer: penultimate
42
+
43
+ first_stage_config:
44
+ target: sgm.models.autoencoder.AutoencoderKL
45
+ params:
46
+ embed_dim: 4
47
+ monitor: val/rec_loss
48
+ ddconfig:
49
+ double_z: true
50
+ z_channels: 4
51
+ resolution: 256
52
+ in_channels: 3
53
+ out_ch: 3
54
+ ch: 128
55
+ ch_mult: [1, 2, 4, 4]
56
+ num_res_blocks: 2
57
+ attn_resolutions: []
58
+ dropout: 0.0
59
+ lossconfig:
60
+ target: torch.nn.Identity
configs/inference/sd_xl_base.yaml ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.13025
5
+ disable_first_stage_autocast: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9
+ params:
10
+ num_idx: 1000
11
+
12
+ scaling_config:
13
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
14
+ discretization_config:
15
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
16
+
17
+ network_config:
18
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
19
+ params:
20
+ adm_in_channels: 2816
21
+ num_classes: sequential
22
+ use_checkpoint: True
23
+ in_channels: 4
24
+ out_channels: 4
25
+ model_channels: 320
26
+ attention_resolutions: [4, 2]
27
+ num_res_blocks: 2
28
+ channel_mult: [1, 2, 4]
29
+ num_head_channels: 64
30
+ use_linear_in_transformer: True
31
+ transformer_depth: [1, 2, 10]
32
+ context_dim: 2048
33
+ spatial_transformer_attn_type: softmax-xformers
34
+
35
+ conditioner_config:
36
+ target: sgm.modules.GeneralConditioner
37
+ params:
38
+ emb_models:
39
+ - is_trainable: False
40
+ input_key: txt
41
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
42
+ params:
43
+ layer: hidden
44
+ layer_idx: 11
45
+
46
+ - is_trainable: False
47
+ input_key: txt
48
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
49
+ params:
50
+ arch: ViT-bigG-14
51
+ version: laion2b_s39b_b160k
52
+ freeze: True
53
+ layer: penultimate
54
+ always_return_pooled: True
55
+ legacy: False
56
+
57
+ - is_trainable: False
58
+ input_key: original_size_as_tuple
59
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
60
+ params:
61
+ outdim: 256
62
+
63
+ - is_trainable: False
64
+ input_key: crop_coords_top_left
65
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
66
+ params:
67
+ outdim: 256
68
+
69
+ - is_trainable: False
70
+ input_key: target_size_as_tuple
71
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
72
+ params:
73
+ outdim: 256
74
+
75
+ first_stage_config:
76
+ target: sgm.models.autoencoder.AutoencoderKL
77
+ params:
78
+ embed_dim: 4
79
+ monitor: val/rec_loss
80
+ ddconfig:
81
+ attn_type: vanilla-xformers
82
+ double_z: true
83
+ z_channels: 4
84
+ resolution: 256
85
+ in_channels: 3
86
+ out_ch: 3
87
+ ch: 128
88
+ ch_mult: [1, 2, 4, 4]
89
+ num_res_blocks: 2
90
+ attn_resolutions: []
91
+ dropout: 0.0
92
+ lossconfig:
93
+ target: torch.nn.Identity
configs/inference/sd_xl_refiner.yaml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.13025
5
+ disable_first_stage_autocast: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9
+ params:
10
+ num_idx: 1000
11
+
12
+ scaling_config:
13
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
14
+ discretization_config:
15
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
16
+
17
+ network_config:
18
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
19
+ params:
20
+ adm_in_channels: 2560
21
+ num_classes: sequential
22
+ use_checkpoint: True
23
+ in_channels: 4
24
+ out_channels: 4
25
+ model_channels: 384
26
+ attention_resolutions: [4, 2]
27
+ num_res_blocks: 2
28
+ channel_mult: [1, 2, 4, 4]
29
+ num_head_channels: 64
30
+ use_linear_in_transformer: True
31
+ transformer_depth: 4
32
+ context_dim: [1280, 1280, 1280, 1280]
33
+ spatial_transformer_attn_type: softmax-xformers
34
+
35
+ conditioner_config:
36
+ target: sgm.modules.GeneralConditioner
37
+ params:
38
+ emb_models:
39
+ - is_trainable: False
40
+ input_key: txt
41
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
42
+ params:
43
+ arch: ViT-bigG-14
44
+ version: laion2b_s39b_b160k
45
+ legacy: False
46
+ freeze: True
47
+ layer: penultimate
48
+ always_return_pooled: True
49
+
50
+ - is_trainable: False
51
+ input_key: original_size_as_tuple
52
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
53
+ params:
54
+ outdim: 256
55
+
56
+ - is_trainable: False
57
+ input_key: crop_coords_top_left
58
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
59
+ params:
60
+ outdim: 256
61
+
62
+ - is_trainable: False
63
+ input_key: aesthetic_score
64
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
65
+ params:
66
+ outdim: 256
67
+
68
+ first_stage_config:
69
+ target: sgm.models.autoencoder.AutoencoderKL
70
+ params:
71
+ embed_dim: 4
72
+ monitor: val/rec_loss
73
+ ddconfig:
74
+ attn_type: vanilla-xformers
75
+ double_z: true
76
+ z_channels: 4
77
+ resolution: 256
78
+ in_channels: 3
79
+ out_ch: 3
80
+ ch: 128
81
+ ch_mult: [1, 2, 4, 4]
82
+ num_res_blocks: 2
83
+ attn_resolutions: []
84
+ dropout: 0.0
85
+ lossconfig:
86
+ target: torch.nn.Identity
configs/inference/svd.yaml ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.18215
5
+ disable_first_stage_autocast: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
9
+ params:
10
+ scaling_config:
11
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
12
+
13
+ network_config:
14
+ target: sgm.modules.diffusionmodules.video_model.VideoUNet
15
+ params:
16
+ adm_in_channels: 768
17
+ num_classes: sequential
18
+ use_checkpoint: True
19
+ in_channels: 8
20
+ out_channels: 4
21
+ model_channels: 320
22
+ attention_resolutions: [4, 2, 1]
23
+ num_res_blocks: 2
24
+ channel_mult: [1, 2, 4, 4]
25
+ num_head_channels: 64
26
+ use_linear_in_transformer: True
27
+ transformer_depth: 1
28
+ context_dim: 1024
29
+ spatial_transformer_attn_type: softmax-xformers
30
+ extra_ff_mix_layer: True
31
+ use_spatial_context: True
32
+ merge_strategy: learned_with_images
33
+ video_kernel_size: [3, 1, 1]
34
+
35
+ conditioner_config:
36
+ target: sgm.modules.GeneralConditioner
37
+ params:
38
+ emb_models:
39
+ - is_trainable: False
40
+ input_key: cond_frames_without_noise
41
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
42
+ params:
43
+ n_cond_frames: 1
44
+ n_copies: 1
45
+ open_clip_embedding_config:
46
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
47
+ params:
48
+ freeze: True
49
+
50
+ - input_key: fps_id
51
+ is_trainable: False
52
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
53
+ params:
54
+ outdim: 256
55
+
56
+ - input_key: motion_bucket_id
57
+ is_trainable: False
58
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
59
+ params:
60
+ outdim: 256
61
+
62
+ - input_key: cond_frames
63
+ is_trainable: False
64
+ target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
65
+ params:
66
+ disable_encoder_autocast: True
67
+ n_cond_frames: 1
68
+ n_copies: 1
69
+ is_ae: True
70
+ encoder_config:
71
+ target: sgm.models.autoencoder.AutoencoderKLModeOnly
72
+ params:
73
+ embed_dim: 4
74
+ monitor: val/rec_loss
75
+ ddconfig:
76
+ attn_type: vanilla-xformers
77
+ double_z: True
78
+ z_channels: 4
79
+ resolution: 256
80
+ in_channels: 3
81
+ out_ch: 3
82
+ ch: 128
83
+ ch_mult: [1, 2, 4, 4]
84
+ num_res_blocks: 2
85
+ attn_resolutions: []
86
+ dropout: 0.0
87
+ lossconfig:
88
+ target: torch.nn.Identity
89
+
90
+ - input_key: cond_aug
91
+ is_trainable: False
92
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
93
+ params:
94
+ outdim: 256
95
+
96
+ first_stage_config:
97
+ target: sgm.models.autoencoder.AutoencodingEngine
98
+ params:
99
+ loss_config:
100
+ target: torch.nn.Identity
101
+ regularizer_config:
102
+ target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
103
+ encoder_config:
104
+ target: sgm.modules.diffusionmodules.model.Encoder
105
+ params:
106
+ attn_type: vanilla
107
+ double_z: True
108
+ z_channels: 4
109
+ resolution: 256
110
+ in_channels: 3
111
+ out_ch: 3
112
+ ch: 128
113
+ ch_mult: [1, 2, 4, 4]
114
+ num_res_blocks: 2
115
+ attn_resolutions: []
116
+ dropout: 0.0
117
+ decoder_config:
118
+ target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
119
+ params:
120
+ attn_type: vanilla
121
+ double_z: True
122
+ z_channels: 4
123
+ resolution: 256
124
+ in_channels: 3
125
+ out_ch: 3
126
+ ch: 128
127
+ ch_mult: [1, 2, 4, 4]
128
+ num_res_blocks: 2
129
+ attn_resolutions: []
130
+ dropout: 0.0
131
+ video_kernel_size: [3, 1, 1]
configs/inference/svd_image_decoder.yaml ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.18215
5
+ disable_first_stage_autocast: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
9
+ params:
10
+ scaling_config:
11
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
12
+
13
+ network_config:
14
+ target: sgm.modules.diffusionmodules.video_model.VideoUNet
15
+ params:
16
+ adm_in_channels: 768
17
+ num_classes: sequential
18
+ use_checkpoint: True
19
+ in_channels: 8
20
+ out_channels: 4
21
+ model_channels: 320
22
+ attention_resolutions: [4, 2, 1]
23
+ num_res_blocks: 2
24
+ channel_mult: [1, 2, 4, 4]
25
+ num_head_channels: 64
26
+ use_linear_in_transformer: True
27
+ transformer_depth: 1
28
+ context_dim: 1024
29
+ spatial_transformer_attn_type: softmax-xformers
30
+ extra_ff_mix_layer: True
31
+ use_spatial_context: True
32
+ merge_strategy: learned_with_images
33
+ video_kernel_size: [3, 1, 1]
34
+
35
+ conditioner_config:
36
+ target: sgm.modules.GeneralConditioner
37
+ params:
38
+ emb_models:
39
+ - is_trainable: False
40
+ input_key: cond_frames_without_noise
41
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
42
+ params:
43
+ n_cond_frames: 1
44
+ n_copies: 1
45
+ open_clip_embedding_config:
46
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
47
+ params:
48
+ freeze: True
49
+
50
+ - input_key: fps_id
51
+ is_trainable: False
52
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
53
+ params:
54
+ outdim: 256
55
+
56
+ - input_key: motion_bucket_id
57
+ is_trainable: False
58
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
59
+ params:
60
+ outdim: 256
61
+
62
+ - input_key: cond_frames
63
+ is_trainable: False
64
+ target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
65
+ params:
66
+ disable_encoder_autocast: True
67
+ n_cond_frames: 1
68
+ n_copies: 1
69
+ is_ae: True
70
+ encoder_config:
71
+ target: sgm.models.autoencoder.AutoencoderKLModeOnly
72
+ params:
73
+ embed_dim: 4
74
+ monitor: val/rec_loss
75
+ ddconfig:
76
+ attn_type: vanilla-xformers
77
+ double_z: True
78
+ z_channels: 4
79
+ resolution: 256
80
+ in_channels: 3
81
+ out_ch: 3
82
+ ch: 128
83
+ ch_mult: [1, 2, 4, 4]
84
+ num_res_blocks: 2
85
+ attn_resolutions: []
86
+ dropout: 0.0
87
+ lossconfig:
88
+ target: torch.nn.Identity
89
+
90
+ - input_key: cond_aug
91
+ is_trainable: False
92
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
93
+ params:
94
+ outdim: 256
95
+
96
+ first_stage_config:
97
+ target: sgm.models.autoencoder.AutoencoderKL
98
+ params:
99
+ embed_dim: 4
100
+ monitor: val/rec_loss
101
+ ddconfig:
102
+ attn_type: vanilla-xformers
103
+ double_z: True
104
+ z_channels: 4
105
+ resolution: 256
106
+ in_channels: 3
107
+ out_ch: 3
108
+ ch: 128
109
+ ch_mult: [1, 2, 4, 4]
110
+ num_res_blocks: 2
111
+ attn_resolutions: []
112
+ dropout: 0.0
113
+ lossconfig:
114
+ target: torch.nn.Identity
configs/inference/svd_mv.yaml ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-05
3
+ target: sgm.models.video_diffusion.DiffusionEngine
4
+ params:
5
+ ckpt_path: ckpts/svd_xt.safetensors
6
+ scale_factor: 0.18215
7
+ disable_first_stage_autocast: true
8
+ scheduler_config:
9
+ target: sgm.lr_scheduler.LambdaLinearScheduler
10
+ params:
11
+ warm_up_steps:
12
+ - 1
13
+ cycle_lengths:
14
+ - 10000000000000
15
+ f_start:
16
+ - 1.0e-06
17
+ f_max:
18
+ - 1.0
19
+ f_min:
20
+ - 1.0
21
+ denoiser_config:
22
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
23
+ params:
24
+ scaling_config:
25
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
26
+ network_config:
27
+ target: sgm.modules.diffusionmodules.video_model.VideoUNet
28
+ params:
29
+ adm_in_channels: 768
30
+ num_classes: sequential
31
+ use_checkpoint: true
32
+ in_channels: 8
33
+ out_channels: 4
34
+ model_channels: 320
35
+ attention_resolutions:
36
+ - 4
37
+ - 2
38
+ - 1
39
+ num_res_blocks: 2
40
+ channel_mult:
41
+ - 1
42
+ - 2
43
+ - 4
44
+ - 4
45
+ num_head_channels: 64
46
+ use_linear_in_transformer: true
47
+ transformer_depth: 1
48
+ context_dim: 1024
49
+ spatial_transformer_attn_type: softmax-xformers
50
+ extra_ff_mix_layer: true
51
+ use_spatial_context: true
52
+ merge_strategy: learned_with_images
53
+ video_kernel_size:
54
+ - 3
55
+ - 1
56
+ - 1
57
+ conditioner_config:
58
+ target: sgm.modules.GeneralConditioner
59
+ params:
60
+ emb_models:
61
+ - is_trainable: false
62
+ ucg_rate: 0.2
63
+ input_key: cond_frames_without_noise
64
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
65
+ params:
66
+ n_cond_frames: 1
67
+ n_copies: 1
68
+ open_clip_embedding_config:
69
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
70
+ params:
71
+ freeze: true
72
+ - input_key: fps_id
73
+ is_trainable: true
74
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
75
+ params:
76
+ outdim: 256
77
+ - input_key: motion_bucket_id
78
+ is_trainable: true
79
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
80
+ params:
81
+ outdim: 256
82
+ - input_key: cond_frames
83
+ is_trainable: false
84
+ ucg_rate: 0.2
85
+ target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
86
+ params:
87
+ disable_encoder_autocast: true
88
+ n_cond_frames: 1
89
+ n_copies: 1
90
+ is_ae: true
91
+ encoder_config:
92
+ target: sgm.models.autoencoder.AutoencoderKLModeOnly
93
+ params:
94
+ embed_dim: 4
95
+ monitor: val/rec_loss
96
+ ddconfig:
97
+ attn_type: vanilla-xformers
98
+ double_z: true
99
+ z_channels: 4
100
+ resolution: 256
101
+ in_channels: 3
102
+ out_ch: 3
103
+ ch: 128
104
+ ch_mult:
105
+ - 1
106
+ - 2
107
+ - 4
108
+ - 4
109
+ num_res_blocks: 2
110
+ attn_resolutions: []
111
+ dropout: 0.0
112
+ lossconfig:
113
+ target: torch.nn.Identity
114
+ - input_key: cond_aug
115
+ is_trainable: true
116
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
117
+ params:
118
+ outdim: 256
119
+ first_stage_config:
120
+ target: sgm.models.autoencoder.AutoencodingEngine
121
+ params:
122
+ loss_config:
123
+ target: torch.nn.Identity
124
+ regularizer_config:
125
+ target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
126
+ encoder_config:
127
+ target: sgm.modules.diffusionmodules.model.Encoder
128
+ params:
129
+ attn_type: vanilla
130
+ double_z: true
131
+ z_channels: 4
132
+ resolution: 256
133
+ in_channels: 3
134
+ out_ch: 3
135
+ ch: 128
136
+ ch_mult:
137
+ - 1
138
+ - 2
139
+ - 4
140
+ - 4
141
+ num_res_blocks: 2
142
+ attn_resolutions: []
143
+ dropout: 0.0
144
+ decoder_config:
145
+ target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
146
+ params:
147
+ attn_type: vanilla
148
+ double_z: true
149
+ z_channels: 4
150
+ resolution: 256
151
+ in_channels: 3
152
+ out_ch: 3
153
+ ch: 128
154
+ ch_mult:
155
+ - 1
156
+ - 2
157
+ - 4
158
+ - 4
159
+ num_res_blocks: 2
160
+ attn_resolutions: []
161
+ dropout: 0.0
162
+ video_kernel_size:
163
+ - 3
164
+ - 1
165
+ - 1
166
+ sampler_config:
167
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
168
+ params:
169
+ num_steps: 30
170
+ discretization_config:
171
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
172
+ params:
173
+ sigma_max: 700.0
174
+ guider_config:
175
+ target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
176
+ params:
177
+ max_scale: 2.5
178
+ min_scale: 1.0
179
+ num_frames: 24
180
+ loss_fn_config:
181
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
182
+ params:
183
+ batch2model_keys:
184
+ - num_video_frames
185
+ - image_only_indicator
186
+ loss_weighting_config:
187
+ target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
188
+ params:
189
+ sigma_data: 1.0
190
+ sigma_sampler_config:
191
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
192
+ params:
193
+ p_mean: 0.3
194
+ p_std: 1.2
195
+ data:
196
+ target: sgm.data.objaverse.ObjaverseSpiralDataset
197
+ params:
198
+ root_dir: /mnt/mfs/zilong.chen/Downloads/objaverse-ndd-samples
199
+ random_front: true
200
+ batch_size: 2
201
+ num_workers: 16
202
+ cond_aug_mean: -0.0
mesh_recon/configs/neuralangelo-ortho-wmask.yaml ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: ${basename:${dataset.scene}}
2
+ tag: ""
3
+ seed: 42
4
+
5
+ dataset:
6
+ name: ortho
7
+ root_dir: /home/xiaoxiao/Workplace/wonder3Dplus/outputs/joint-twice/aigc/cropsize-224-cfg1.0
8
+ cam_pose_dir: null
9
+ scene: scene_name
10
+ imSize: [1024, 1024] # should use larger res, otherwise the exported mesh has wrong colors
11
+ camera_type: ortho
12
+ apply_mask: true
13
+ camera_params: null
14
+ view_weights: [1.0, 0.8, 0.2, 1.0, 0.4, 0.7] #['front', 'front_right', 'right', 'back', 'left', 'front_left']
15
+ # view_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
16
+
17
+ model:
18
+ name: neus
19
+ radius: 1.0
20
+ num_samples_per_ray: 1024
21
+ train_num_rays: 256
22
+ max_train_num_rays: 8192
23
+ grid_prune: true
24
+ grid_prune_occ_thre: 0.001
25
+ dynamic_ray_sampling: true
26
+ batch_image_sampling: true
27
+ randomized: true
28
+ ray_chunk: 2048
29
+ cos_anneal_end: 20000
30
+ learned_background: false
31
+ background_color: black
32
+ variance:
33
+ init_val: 0.3
34
+ modulate: false
35
+ geometry:
36
+ name: volume-sdf
37
+ radius: ${model.radius}
38
+ feature_dim: 13
39
+ grad_type: finite_difference
40
+ finite_difference_eps: progressive
41
+ isosurface:
42
+ method: mc
43
+ resolution: 192
44
+ chunk: 2097152
45
+ threshold: 0.
46
+ xyz_encoding_config:
47
+ otype: ProgressiveBandHashGrid
48
+ n_levels: 10 # 12 modify
49
+ n_features_per_level: 2
50
+ log2_hashmap_size: 19
51
+ base_resolution: 32
52
+ per_level_scale: 1.3195079107728942
53
+ include_xyz: true
54
+ start_level: 4
55
+ start_step: 0
56
+ update_steps: 1000
57
+ mlp_network_config:
58
+ otype: VanillaMLP
59
+ activation: ReLU
60
+ output_activation: none
61
+ n_neurons: 64
62
+ n_hidden_layers: 1
63
+ sphere_init: true
64
+ sphere_init_radius: 0.5
65
+ weight_norm: true
66
+ texture:
67
+ name: volume-radiance
68
+ input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input
69
+ dir_encoding_config:
70
+ otype: SphericalHarmonics
71
+ degree: 4
72
+ mlp_network_config:
73
+ otype: VanillaMLP
74
+ activation: ReLU
75
+ output_activation: none
76
+ n_neurons: 64
77
+ n_hidden_layers: 2
78
+ color_activation: sigmoid
79
+
80
+ system:
81
+ name: ortho-neus-system
82
+ loss:
83
+ lambda_rgb_mse: 0.5
84
+ lambda_rgb_l1: 0.
85
+ lambda_mask: 1.0
86
+ lambda_eikonal: 0.2 # must not be too large, otherwise it causes holes in thin objects
87
+ lambda_normal: 1.0 # cannot be too large
88
+ lambda_3d_normal_smooth: 1.0
89
+ # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup
90
+ lambda_curvature: 0.
91
+ lambda_sparsity: 0.5
92
+ lambda_distortion: 0.0
93
+ lambda_distortion_bg: 0.0
94
+ lambda_opaque: 0.0
95
+ sparsity_scale: 100.0
96
+ geo_aware: true
97
+ rgb_p_ratio: 0.8
98
+ normal_p_ratio: 0.8
99
+ mask_p_ratio: 0.9
100
+ optimizer:
101
+ name: AdamW
102
+ args:
103
+ lr: 0.01
104
+ betas: [0.9, 0.99]
105
+ eps: 1.e-15
106
+ params:
107
+ geometry:
108
+ lr: 0.001
109
+ texture:
110
+ lr: 0.01
111
+ variance:
112
+ lr: 0.001
113
+ constant_steps: 500
114
+ scheduler:
115
+ name: SequentialLR
116
+ interval: step
117
+ milestones:
118
+ - ${system.constant_steps}
119
+ schedulers:
120
+ - name: ConstantLR
121
+ args:
122
+ factor: 1.0
123
+ total_iters: ${system.constant_steps}
124
+ - name: ExponentialLR
125
+ args:
126
+ gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}}
127
+
128
+ checkpoint:
129
+ save_top_k: -1
130
+ every_n_train_steps: ${trainer.max_steps}
131
+
132
+ export:
133
+ chunk_size: 2097152
134
+ export_vertex_color: True
135
+ ortho_scale: 1.35 #modify
136
+
137
+ trainer:
138
+ max_steps: 3000
139
+ log_every_n_steps: 100
140
+ num_sanity_val_steps: 0
141
+ val_check_interval: 4000
142
+ limit_train_batches: 1.0
143
+ limit_val_batches: 2
144
+ enable_progress_bar: true
145
+ precision: 16
mesh_recon/configs/v3d.yaml ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: ${basename:${dataset.scene}}
2
+ tag: ""
3
+ seed: 42
4
+
5
+ dataset:
6
+ name: v3d
7
+ root_dir: ./spirals
8
+ cam_pose_dir: null
9
+ scene: pizza_man
10
+ apply_mask: true
11
+ train_split: train
12
+ test_split: train
13
+ val_split: train
14
+ img_wh: [1024, 1024]
15
+
16
+ model:
17
+ name: neus
18
+ radius: 1.0 ## check this
19
+ num_samples_per_ray: 1024
20
+ train_num_rays: 256
21
+ max_train_num_rays: 8192
22
+ grid_prune: true
23
+ grid_prune_occ_thre: 0.001
24
+ dynamic_ray_sampling: true
25
+ batch_image_sampling: true
26
+ randomized: true
27
+ ray_chunk: 2048
28
+ cos_anneal_end: 20000
29
+ learned_background: false
30
+ background_color: black
31
+ variance:
32
+ init_val: 0.3
33
+ modulate: false
34
+ geometry:
35
+ name: volume-sdf
36
+ radius: ${model.radius}
37
+ feature_dim: 13
38
+ grad_type: finite_difference
39
+ finite_difference_eps: progressive
40
+ isosurface:
41
+ method: mc
42
+ resolution: 384
43
+ chunk: 2097152
44
+ threshold: 0.
45
+ xyz_encoding_config:
46
+ otype: ProgressiveBandHashGrid
47
+ n_levels: 10 # 12 modify
48
+ n_features_per_level: 2
49
+ log2_hashmap_size: 19
50
+ base_resolution: 32
51
+ per_level_scale: 1.3195079107728942
52
+ include_xyz: true
53
+ start_level: 4
54
+ start_step: 0
55
+ update_steps: 1000
56
+ mlp_network_config:
57
+ otype: VanillaMLP
58
+ activation: ReLU
59
+ output_activation: none
60
+ n_neurons: 64
61
+ n_hidden_layers: 1
62
+ sphere_init: true
63
+ sphere_init_radius: 0.5
64
+ weight_norm: true
65
+ texture:
66
+ name: volume-radiance
67
+ input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input
68
+ dir_encoding_config:
69
+ otype: SphericalHarmonics
70
+ degree: 4
71
+ mlp_network_config:
72
+ otype: VanillaMLP
73
+ activation: ReLU
74
+ output_activation: none
75
+ n_neurons: 64
76
+ n_hidden_layers: 2
77
+ color_activation: sigmoid
78
+
79
+ system:
80
+ name: videonvs-neus-system
81
+ loss:
82
+ lambda_rgb_mse: 0.5
83
+ lambda_rgb_l1: 0.
84
+ lambda_mask: 1.0
85
+ lambda_eikonal: 0.2 # must not be too large, otherwise it causes holes in thin objects
86
+ lambda_normal: 0.0 # cannot be too large
87
+ lambda_3d_normal_smooth: 1.0
88
+ # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup
89
+ lambda_curvature: 0.
90
+ lambda_sparsity: 0.5
91
+ lambda_distortion: 0.0
92
+ lambda_distortion_bg: 0.0
93
+ lambda_opaque: 0.0
94
+ sparsity_scale: 100.0
95
+ geo_aware: true
96
+ rgb_p_ratio: 0.8
97
+ normal_p_ratio: 0.8
98
+ mask_p_ratio: 0.9
99
+ optimizer:
100
+ name: AdamW
101
+ args:
102
+ lr: 0.01
103
+ betas: [0.9, 0.99]
104
+ eps: 1.e-15
105
+ params:
106
+ geometry:
107
+ lr: 0.001
108
+ texture:
109
+ lr: 0.01
110
+ variance:
111
+ lr: 0.001
112
+ constant_steps: 500
113
+ scheduler:
114
+ name: SequentialLR
115
+ interval: step
116
+ milestones:
117
+ - ${system.constant_steps}
118
+ schedulers:
119
+ - name: ConstantLR
120
+ args:
121
+ factor: 1.0
122
+ total_iters: ${system.constant_steps}
123
+ - name: ExponentialLR
124
+ args:
125
+ gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}}
126
+
127
+ checkpoint:
128
+ save_top_k: -1
129
+ every_n_train_steps: ${trainer.max_steps}
130
+
131
+ export:
132
+ chunk_size: 2097152
133
+ export_vertex_color: True
134
+ ortho_scale: null #modify
135
+
136
+ trainer:
137
+ max_steps: 3000
138
+ log_every_n_steps: 100
139
+ num_sanity_val_steps: 0
140
+ val_check_interval: 3000
141
+ limit_train_batches: 1.0
142
+ limit_val_batches: 2
143
+ enable_progress_bar: true
144
+ precision: 16
mesh_recon/configs/videonvs.yaml ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: ${basename:${dataset.scene}}
2
+ tag: ""
3
+ seed: 42
4
+
5
+ dataset:
6
+ name: videonvs
7
+ root_dir: ./spirals
8
+ cam_pose_dir: null
9
+ scene: pizza_man
10
+ apply_mask: true
11
+ train_split: train
12
+ test_split: train
13
+ val_split: train
14
+ img_wh: [1024, 1024]
15
+
16
+ model:
17
+ name: neus
18
+ radius: 1.0 ## check this
19
+ num_samples_per_ray: 1024
20
+ train_num_rays: 256
21
+ max_train_num_rays: 8192
22
+ grid_prune: true
23
+ grid_prune_occ_thre: 0.001
24
+ dynamic_ray_sampling: true
25
+ batch_image_sampling: true
26
+ randomized: true
27
+ ray_chunk: 2048
28
+ cos_anneal_end: 20000
29
+ learned_background: false
30
+ background_color: black
31
+ variance:
32
+ init_val: 0.3
33
+ modulate: false
34
+ geometry:
35
+ name: volume-sdf
36
+ radius: ${model.radius}
37
+ feature_dim: 13
38
+ grad_type: finite_difference
39
+ finite_difference_eps: progressive
40
+ isosurface:
41
+ method: mc
42
+ resolution: 384
43
+ chunk: 2097152
44
+ threshold: 0.
45
+ xyz_encoding_config:
46
+ otype: ProgressiveBandHashGrid
47
+ n_levels: 10 # 12 modify
48
+ n_features_per_level: 2
49
+ log2_hashmap_size: 19
50
+ base_resolution: 32
51
+ per_level_scale: 1.3195079107728942
52
+ include_xyz: true
53
+ start_level: 4
54
+ start_step: 0
55
+ update_steps: 1000
56
+ mlp_network_config:
57
+ otype: VanillaMLP
58
+ activation: ReLU
59
+ output_activation: none
60
+ n_neurons: 64
61
+ n_hidden_layers: 1
62
+ sphere_init: true
63
+ sphere_init_radius: 0.5
64
+ weight_norm: true
65
+ texture:
66
+ name: volume-radiance
67
+ input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input
68
+ dir_encoding_config:
69
+ otype: SphericalHarmonics
70
+ degree: 4
71
+ mlp_network_config:
72
+ otype: VanillaMLP
73
+ activation: ReLU
74
+ output_activation: none
75
+ n_neurons: 64
76
+ n_hidden_layers: 2
77
+ color_activation: sigmoid
78
+
79
+ system:
80
+ name: videonvs-neus-system
81
+ loss:
82
+ lambda_rgb_mse: 0.5
83
+ lambda_rgb_l1: 0.
84
+ lambda_mask: 1.0
85
+ lambda_eikonal: 0.2 # must not be too large, otherwise it causes holes in thin objects
86
+ lambda_normal: 1.0 # cannot be too large
87
+ lambda_3d_normal_smooth: 1.0
88
+ # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup
89
+ lambda_curvature: 0.
90
+ lambda_sparsity: 0.5
91
+ lambda_distortion: 0.0
92
+ lambda_distortion_bg: 0.0
93
+ lambda_opaque: 0.0
94
+ sparsity_scale: 100.0
95
+ geo_aware: true
96
+ rgb_p_ratio: 0.8
97
+ normal_p_ratio: 0.8
98
+ mask_p_ratio: 0.9
99
+ optimizer:
100
+ name: AdamW
101
+ args:
102
+ lr: 0.01
103
+ betas: [0.9, 0.99]
104
+ eps: 1.e-15
105
+ params:
106
+ geometry:
107
+ lr: 0.001
108
+ texture:
109
+ lr: 0.01
110
+ variance:
111
+ lr: 0.001
112
+ constant_steps: 500
113
+ scheduler:
114
+ name: SequentialLR
115
+ interval: step
116
+ milestones:
117
+ - ${system.constant_steps}
118
+ schedulers:
119
+ - name: ConstantLR
120
+ args:
121
+ factor: 1.0
122
+ total_iters: ${system.constant_steps}
123
+ - name: ExponentialLR
124
+ args:
125
+ gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}}
126
+
127
+ checkpoint:
128
+ save_top_k: -1
129
+ every_n_train_steps: ${trainer.max_steps}
130
+
131
+ export:
132
+ chunk_size: 2097152
133
+ export_vertex_color: True
134
+ ortho_scale: null #modify
135
+
136
+ trainer:
137
+ max_steps: 3000
138
+ log_every_n_steps: 100
139
+ num_sanity_val_steps: 0
140
+ val_check_interval: 3000
141
+ limit_train_batches: 1.0
142
+ limit_val_batches: 2
143
+ enable_progress_bar: true
144
+ precision: 16
mesh_recon/datasets/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Registry mapping a dataset name to the class registered under that name.
datasets = {}


def register(name):
    """Return a class decorator that records the class in ``datasets``.

    Args:
        name: Key under which the decorated class is stored.

    Returns:
        A decorator that stores the class as ``datasets[name]`` and returns
        the class unchanged, so it can still be used directly.
    """

    def decorator(cls):
        datasets[name] = cls
        return cls

    return decorator


def make(name, config):
    """Instantiate the class registered under ``name``.

    Args:
        name: Key previously passed to :func:`register`.
        config: Single positional argument forwarded to the class.

    Returns:
        The object produced by calling the registered class with ``config``.

    Raises:
        KeyError: If nothing has been registered under ``name``.
    """
    try:
        dataset_cls = datasets[name]
    except KeyError:
        # Same exception type as the plain lookup, but with an actionable message.
        raise KeyError(
            f"Unknown dataset '{name}'. Registered datasets: {list(datasets)}"
        ) from None
    return dataset_cls(config)
15
+
16
+
17
+ from . import blender, colmap, dtu, ortho, videonvs, videonvs_co3d, v3d
mesh_recon/datasets/blender.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import math
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset, DataLoader, IterableDataset
9
+ import torchvision.transforms.functional as TF
10
+
11
+ import pytorch_lightning as pl
12
+
13
+ import datasets
14
+ from models.ray_utils import get_ray_directions
15
+ from utils.misc import get_rank
16
+
17
+
18
class BlenderDatasetBase:
    """Shared loader for a Blender-style ``transforms_<split>.json`` scene."""

    def setup(self, config, split):
        """Read the split's metadata and images and store them as tensors.

        Args:
            config: Config object; this method reads ``root_dir``, one of
                ``img_wh`` / ``img_downscale``, ``near_plane`` and
                ``far_plane`` from it.
            split: Split name used to build ``transforms_{split}.json``.

        Raises:
            KeyError: If neither ``img_wh`` nor ``img_downscale`` is configured.
            AssertionError: If ``img_wh`` does not keep the source aspect ratio.
        """
        self.config = config
        self.split = split
        self.rank = get_rank()

        self.has_mask = True
        self.apply_mask = True

        with open(
            os.path.join(self.config.root_dir, f"transforms_{self.split}.json"), "r"
        ) as f:
            meta = json.load(f)

        # Fall back to 800x800 when the metadata does not carry a size.
        if "w" in meta and "h" in meta:
            W, H = int(meta["w"]), int(meta["h"])
        else:
            W, H = 800, 800

        if "img_wh" in self.config:
            w, h = self.config.img_wh
            assert round(W / w * h) == H, "img_wh must keep the source aspect ratio"
        elif "img_downscale" in self.config:
            w, h = W // self.config.img_downscale, H // self.config.img_downscale
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (self.w, self.h)

        self.near, self.far = self.config.near_plane, self.config.far_plane

        # focal length scaled to the target width
        self.focal = 0.5 * w / math.tan(0.5 * meta["camera_angle_x"])

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(
            self.w, self.h, self.focal, self.focal, self.w // 2, self.h // 2
        ).to(self.rank)  # (h, w, 3)

        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        for frame in meta["frames"]:
            c2w = torch.from_numpy(np.array(frame["transform_matrix"])[:3, :4])
            self.all_c2w.append(c2w)

            img_path = os.path.join(self.config.root_dir, f"{frame['file_path']}.png")
            # Close the file handle as soon as the resized copy exists.
            with Image.open(img_path) as src:
                img = src.resize(self.img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0)  # (4, h, w) => (h, w, 4)

            # NOTE(review): the last channel is taken as the foreground mask,
            # which presumably expects RGBA input -- confirm for RGB-only data.
            self.all_fg_masks.append(img[..., -1])  # (h, w)
            self.all_images.append(img[..., :3])

        self.all_c2w, self.all_images, self.all_fg_masks = (
            torch.stack(self.all_c2w, dim=0).float().to(self.rank),
            torch.stack(self.all_images, dim=0).float().to(self.rank),
            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank),
        )
80
+
81
+
82
class BlenderDataset(Dataset, BlenderDatasetBase):
    """Map-style dataset whose items are only image indices."""

    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.all_images)

    def __getitem__(self, index):
        """Return the index wrapped in a dict; no pixel data is returned here."""
        return {"index": index}
91
+
92
+
93
class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
    """Iterable dataset that yields empty dicts forever."""

    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        # Endless stream of empty batches; the loaded tensors live on the instance.
        while True:
            yield {}
100
+
101
+
102
@datasets.register("blender")
class VideoNVSDataModule(pl.LightningDataModule):
    """Lightning data module registered under the name ``"blender"``.

    NOTE(review): the class name looks copied from another module; it is kept
    unchanged here so existing references keep working.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (all of them when ``None``)."""
        if stage in [None, "fit"]:
            self.train_dataset = BlenderIterableDataset(
                self.config, self.config.train_split
            )
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BlenderDataset(self.config, self.config.val_split)
        if stage in [None, "test"]:
            self.test_dataset = BlenderDataset(self.config, self.config.test_split)
        if stage in [None, "predict"]:
            self.predict_dataset = BlenderDataset(self.config, self.config.train_split)

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        """Wrap ``dataset`` in a ``DataLoader`` with one worker per CPU core."""
        return DataLoader(
            dataset,
            num_workers=os.cpu_count(),
            batch_size=batch_size,
            pin_memory=True,
            sampler=None,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
mesh_recon/datasets/colmap.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch.utils.data import Dataset, DataLoader, IterableDataset
9
+ import torchvision.transforms.functional as TF
10
+
11
+ import pytorch_lightning as pl
12
+
13
+ import datasets
14
+ from datasets.colmap_utils import \
15
+ read_cameras_binary, read_images_binary, read_points3d_binary
16
+ from models.ray_utils import get_ray_directions
17
+ from utils.misc import get_rank
18
+
19
+
20
def get_center(pts):
    """Return the mean of ``pts`` after discarding distance outliers.

    A point is kept when its distance to the plain mean lies within both
    1.5 standard deviations and 1.5 inter-quartile ranges of the mean distance.

    Args:
        pts: Tensor of points indexed as ``(num_points, dims)``.

    Returns:
        The mean of the kept points. If the filter rejects every point (for
        example when all points are identical or equidistant, so the spread is
        zero), the plain mean is returned instead of NaN.
    """
    center = pts.mean(0)
    dis = (pts - center[None, :]).norm(p=2, dim=-1)
    mean, std = dis.mean(), dis.std()
    q25, q75 = torch.quantile(dis, 0.25), torch.quantile(dis, 0.75)
    iqr = q75 - q25
    valid = (
        (dis > mean - 1.5 * std)
        & (dis < mean + 1.5 * std)
        & (dis > mean - iqr * 1.5)
        & (dis < mean + iqr * 1.5)
    )
    if not valid.any():
        # Averaging an empty selection would produce NaN.
        return center
    return pts[valid].mean(0)
28
+
29
def normalize_poses(poses, pts, up_est_method, center_est_method):
    """Re-center, re-orient and re-scale camera poses and a point cloud.

    The scene is rotated so that the estimated up direction becomes the z axis,
    translated so that the estimated center becomes the origin, and scaled so
    that the closest camera lies at distance 1 from the origin.

    Args:
        poses: Tensor indexed as ``(N, 3, 4)``; column 3 is used as the camera
            position and ``-z`` of the 3x3 part as the viewing direction.
        pts: Tensor of 3D points indexed as ``(M, 3)``.
        up_est_method: ``'ground'`` (RANSAC plane fit on ``pts`` via
            ``pyransac3d``) or ``'camera'``.
        center_est_method: ``'camera'``, ``'lookat'`` or ``'point'``.

    Returns:
        Tuple ``(poses_norm, pts)`` with the transformed ``(N, 3, 4)`` poses and
        the transformed points.

    Raises:
        NotImplementedError: If either method name is unknown.
    """
    if center_est_method == 'camera':
        # estimate the scene center as the average of all camera positions
        center = poses[...,3].mean(0)
    elif center_est_method == 'lookat':
        # estimate the scene center as the average of the (least-squares)
        # intersections of each camera ray with the previous camera's ray
        cams_ori = poses[...,3]
        cams_dir = poses[:,:3,:3] @ torch.as_tensor([0.,0.,-1.])
        cams_dir = F.normalize(cams_dir, dim=-1)
        A = torch.stack([cams_dir, -cams_dir.roll(1,0)], dim=-1)
        b = -cams_ori + cams_ori.roll(1,0)
        t = torch.linalg.lstsq(A, b).solution
        center = (torch.stack([cams_dir, cams_dir.roll(1,0)], dim=-1) * t[:,None,:] + torch.stack([cams_ori, cams_ori.roll(1,0)], dim=-1)).mean((0,2))
    elif center_est_method == 'point':
        # first estimate the scene center as the average of all camera positions;
        # later the center of all points bounded by the cameras becomes the final scene center
        center = poses[...,3].mean(0)
    else:
        raise NotImplementedError(f'Unknown center estimation method: {center_est_method}')

    if up_est_method == 'ground':
        # estimate up direction as the normal of the estimated ground plane
        # use RANSAC to estimate the ground plane in the point cloud
        import pyransac3d as pyrsc
        ground = pyrsc.Plane()
        plane_eq, inliers = ground.fit(pts.numpy(), thresh=0.01) # TODO: determine thresh based on scene scale
        plane_eq = torch.as_tensor(plane_eq) # A, B, C, D in Ax + By + Cz + D = 0
        z = F.normalize(plane_eq[:3], dim=-1) # plane normal as up direction
        signed_distance = (torch.cat([pts, torch.ones_like(pts[...,0:1])], dim=-1) * plane_eq).sum(-1)
        if signed_distance.mean() < 0:
            z = -z # flip the direction if points lie under the plane
    elif up_est_method == 'camera':
        # estimate up direction as the mean offset of the camera positions from the center
        # NOTE(review): when the center is itself the mean camera position
        # ('camera' / 'point'), this offset is ~0 and z degenerates -- confirm intent.
        z = F.normalize((poses[...,3] - center).mean(0), dim=0)
    else:
        raise NotImplementedError(f'Unknown up estimation method: {up_est_method}')

    # new axis
    # NOTE(review): y_ is the zero vector when z is parallel to the world z axis,
    # which makes x and y degenerate -- confirm this case cannot occur.
    y_ = torch.as_tensor([z[1], -z[0], 0.])
    x = F.normalize(y_.cross(z, dim=0), dim=0)
    y = z.cross(x, dim=0)

    if center_est_method == 'point':
        # rotation
        Rc = torch.stack([x, y, z], dim=1)
        R = Rc.T
        poses_homo = torch.cat([poses, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses.shape[0], -1, -1)], dim=1)
        inv_trans = torch.cat([torch.cat([R, torch.as_tensor([[0.,0.,0.]]).T], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0)
        poses_norm = (inv_trans @ poses_homo)[:,:3]
        pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0]

        # translation and scaling
        poses_min, poses_max = poses_norm[...,3].min(0)[0], poses_norm[...,3].max(0)[0]
        # keep only the points inside the xy bounding box of the camera positions
        pts_fg = pts[(poses_min[0] < pts[:,0]) & (pts[:,0] < poses_max[0]) & (poses_min[1] < pts[:,1]) & (pts[:,1] < poses_max[1])]
        center = get_center(pts_fg)
        tc = center.reshape(3, 1)
        t = -tc
        poses_homo = torch.cat([poses_norm, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses_norm.shape[0], -1, -1)], dim=1)
        inv_trans = torch.cat([torch.cat([torch.eye(3), t], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0)
        poses_norm = (inv_trans @ poses_homo)[:,:3]
        scale = poses_norm[...,3].norm(p=2, dim=-1).min()
        poses_norm[...,3] /= scale
        pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0]
        pts = pts / scale
    else:
        # rotation and translation
        Rc = torch.stack([x, y, z], dim=1)
        tc = center.reshape(3, 1)
        R, t = Rc.T, -Rc.T @ tc
        poses_homo = torch.cat([poses, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses.shape[0], -1, -1)], dim=1)
        inv_trans = torch.cat([torch.cat([R, t], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0)
        poses_norm = (inv_trans @ poses_homo)[:,:3] # (N_images, 3, 4)

        # scaling: the closest camera ends up at distance 1 from the origin
        scale = poses_norm[...,3].norm(p=2, dim=-1).min()
        poses_norm[...,3] /= scale

        # apply the transformation to the point cloud
        pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0]
        pts = pts / scale

    return poses_norm, pts
111
+
112
def create_spheric_poses(cameras, n_steps=120):
    """Generate poses on a horizontal circle that look at the origin.

    The circle height is the mean z of ``cameras`` and its radius is derived
    from the mean distance of ``cameras`` to the origin.

    Args:
        cameras: Tensor of camera positions indexed as ``(N, 3)``.
        n_steps: Number of poses, sampled at angles from 0 to 2*pi inclusive
            (so the first and last pose coincide).

    Returns:
        Tensor of shape ``(n_steps, 3, 4)``; each pose is ``[s, u, -l | pos]``
        with ``l`` the unit vector from the camera to the origin.
    """
    center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device)
    mean_d = (cameras - center[None,:]).norm(p=2, dim=-1).mean()
    mean_h = cameras[:,2].mean()
    r = (mean_d**2 - mean_h**2).sqrt()
    up = torch.as_tensor([0., 0., 1.], dtype=center.dtype, device=center.device)

    # Build the angles on the same device/dtype as the inputs so that no
    # cross-device scalars are mixed inside the loop.
    thetas = torch.linspace(0, 2 * math.pi, n_steps, dtype=cameras.dtype, device=cameras.device)

    all_c2w = []
    for theta in thetas:
        cam_pos = torch.stack([r * theta.cos(), r * theta.sin(), mean_h])
        look = F.normalize(center - cam_pos, p=2, dim=0)
        side = F.normalize(look.cross(up, dim=0), p=2, dim=0)
        cam_up = F.normalize(side.cross(look, dim=0), p=2, dim=0)
        c2w = torch.cat([torch.stack([side, cam_up, -look], dim=1), cam_pos[:,None]], dim=1)
        all_c2w.append(c2w)

    return torch.stack(all_c2w, dim=0)
131
+
132
class ColmapDatasetBase():
    """Shared loader for a COLMAP reconstruction under ``config.root_dir``.

    The parsed scene is cached on the class, so only the first ``setup`` call
    in a process reads the files.

    NOTE(review): the cache is keyed on nothing -- a second config with a
    different ``root_dir`` would silently reuse the first scene. Confirm that
    one process only ever loads one scene.
    """

    # the data only has to be processed once
    initialized = False
    properties = {}

    def setup(self, config, split):
        """Load (or reuse) the scene and prepare per-split tensors.

        Args:
            config: Config object; this method reads ``root_dir``, one of
                ``img_wh`` / ``img_downscale``, ``apply_mask``,
                ``load_data_on_gpu``, ``up_est_method``, ``center_est_method``
                and, for the test split, ``n_test_traj_steps``.
            split: ``'train'`` / ``'val'`` load images; ``'test'`` generates a
                circular trajectory with zero-filled images and masks.

        Raises:
            KeyError: If neither ``img_wh`` nor ``img_downscale`` is configured.
            ValueError: If the camera model is not handled.
            RuntimeError: If a non-test split is requested but the cached
                scene holds no images.
        """
        self.config = config
        self.split = split
        self.rank = get_rank()

        if not ColmapDatasetBase.initialized:
            camdata = read_cameras_binary(os.path.join(self.config.root_dir, 'sparse/0/cameras.bin'))

            # Intrinsics are taken from the camera with id 1 only.
            H = int(camdata[1].height)
            W = int(camdata[1].width)

            if 'img_wh' in self.config:
                w, h = self.config.img_wh
                assert round(W / w * h) == H, 'img_wh must keep the source aspect ratio'
            elif 'img_downscale' in self.config:
                w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5)
            else:
                raise KeyError("Either img_wh or img_downscale should be specified.")

            img_wh = (w, h)
            factor = w / W

            if camdata[1].model == 'SIMPLE_RADIAL':
                fx = fy = camdata[1].params[0] * factor
                cx = camdata[1].params[1] * factor
                cy = camdata[1].params[2] * factor
            elif camdata[1].model in ['PINHOLE', 'OPENCV']:
                fx = camdata[1].params[0] * factor
                fy = camdata[1].params[1] * factor
                cx = camdata[1].params[2] * factor
                cy = camdata[1].params[3] * factor
            else:
                raise ValueError(f"Please parse the intrinsics for camera model {camdata[1].model}!")

            directions = get_ray_directions(w, h, fx, fy, cx, cy).to(self.rank)

            imdata = read_images_binary(os.path.join(self.config.root_dir, 'sparse/0/images.bin'))

            mask_dir = os.path.join(self.config.root_dir, 'masks')
            has_mask = os.path.exists(mask_dir) # TODO: support partial masks
            apply_mask = has_mask and self.config.apply_mask

            all_c2w, all_images, all_fg_masks = [], [], []

            for d in imdata.values():
                R = d.qvec2rotmat()
                t = d.tvec.reshape(3, 1)
                # invert the world-to-camera transform [R | t]
                c2w = torch.from_numpy(np.concatenate([R.T, -R.T@t], axis=1)).float()
                c2w[:,1:3] *= -1. # COLMAP => OpenGL
                all_c2w.append(c2w)
                if self.split in ['train', 'val']:
                    img_path = os.path.join(self.config.root_dir, 'images', d.name)
                    # Close the file handle as soon as the resized copy exists.
                    with Image.open(img_path) as src:
                        img = src.resize(img_wh, Image.BICUBIC)
                    img = TF.to_tensor(img).permute(1, 2, 0)[...,:3]
                    img = img.to(self.rank) if self.config.load_data_on_gpu else img.cpu()
                    if has_mask:
                        # the mask may be stored under the image name or that name minus its first 3 characters
                        mask_paths = [os.path.join(mask_dir, d.name), os.path.join(mask_dir, d.name[3:])]
                        mask_paths = list(filter(os.path.exists, mask_paths))
                        assert len(mask_paths) == 1
                        with Image.open(mask_paths[0]) as src:
                            mask = src.convert('L').resize(img_wh, Image.BICUBIC)
                        mask = TF.to_tensor(mask)[0]
                    else:
                        mask = torch.ones_like(img[...,0], device=img.device)
                    all_fg_masks.append(mask) # (h, w)
                    all_images.append(img)

            all_c2w = torch.stack(all_c2w, dim=0)

            pts3d = read_points3d_binary(os.path.join(self.config.root_dir, 'sparse/0/points3D.bin'))
            pts3d = torch.from_numpy(np.array([pts3d[k].xyz for k in pts3d])).float()
            all_c2w, pts3d = normalize_poses(all_c2w, pts3d, up_est_method=self.config.up_est_method, center_est_method=self.config.center_est_method)

            ColmapDatasetBase.properties = {
                'w': w,
                'h': h,
                'img_wh': img_wh,
                'factor': factor,
                'has_mask': has_mask,
                'apply_mask': apply_mask,
                'directions': directions,
                'pts3d': pts3d,
                'all_c2w': all_c2w,
                'all_images': all_images,
                'all_fg_masks': all_fg_masks
            }

            ColmapDatasetBase.initialized = True

        for k, v in ColmapDatasetBase.properties.items():
            setattr(self, k, v)

        if self.split == 'test':
            self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps)
            self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32)
            self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32)
        else:
            if not self.all_images:
                # Images are only read while the cache is being filled by a
                # 'train'/'val' split; torch.stack would fail obscurely here.
                raise RuntimeError(
                    f"No images are cached for split '{self.split}': the COLMAP scene was "
                    "first loaded with a split that skips image loading."
                )
            self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0).float(), torch.stack(self.all_fg_masks, dim=0).float()

        """
        # for debug use
        from models.ray_utils import get_rays
        rays_o, rays_d = get_rays(self.directions.cpu(), self.all_c2w, keepdim=True)
        pts_out = []
        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 0.0 0.0' for l in rays_o[:,0,0].reshape(-1, 3).tolist()]))

        t_vals = torch.linspace(0, 1, 8)
        z_vals = 0.05 * (1 - t_vals) + 0.5 * t_vals

        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,0][..., None, :])
        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 0.0' for l in ray_pts.view(-1, 3).tolist()]))

        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,self.h-1,0][..., None, :])
        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 1.0' for l in ray_pts.view(-1, 3).tolist()]))

        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,self.w-1][..., None, :])
        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()]))

        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,self.h-1,self.w-1][..., None, :])
        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()]))

        open('cameras.txt', 'w').write('\n'.join(pts_out))
        open('scene.txt', 'w').write('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 0.0' for l in self.pts3d.view(-1, 3).tolist()]))

        exit(1)
        """

        self.all_c2w = self.all_c2w.float().to(self.rank)
        if self.config.load_data_on_gpu:
            self.all_images = self.all_images.to(self.rank)
            self.all_fg_masks = self.all_fg_masks.to(self.rank)
269
+
270
+
271
class ColmapDataset(Dataset, ColmapDatasetBase):
    """Map-style dataset whose items are only image indices."""

    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        """Return the number of images (or generated test poses)."""
        return len(self.all_images)

    def __getitem__(self, index):
        """Return the index wrapped in a dict; no pixel data is returned here."""
        return {
            'index': index
        }
282
+
283
+
284
class ColmapIterableDataset(IterableDataset, ColmapDatasetBase):
    """Iterable dataset that yields empty dicts forever."""

    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        # Endless stream of empty batches; the loaded tensors live on the instance.
        while True:
            yield {}
291
+
292
+
293
@datasets.register('colmap')
class ColmapDataModule(pl.LightningDataModule):
    """Lightning data module registered under the name ``'colmap'``."""

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (all of them when ``None``).

        Validation defaults to the ``'train'`` split and testing to ``'test'``
        unless ``val_split`` / ``test_split`` are set in the config.
        """
        if stage in [None, 'fit']:
            self.train_dataset = ColmapIterableDataset(self.config, 'train')
        if stage in [None, 'fit', 'validate']:
            self.val_dataset = ColmapDataset(self.config, self.config.get('val_split', 'train'))
        if stage in [None, 'test']:
            self.test_dataset = ColmapDataset(self.config, self.config.get('test_split', 'test'))
        if stage in [None, 'predict']:
            self.predict_dataset = ColmapDataset(self.config, 'train')

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        """Wrap ``dataset`` in a ``DataLoader`` with one worker per CPU core."""
        return DataLoader(
            dataset,
            num_workers=os.cpu_count(),
            batch_size=batch_size,
            pin_memory=True,
            sampler=None
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
mesh_recon/datasets/colmap_utils.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2
+ # All rights reserved.
3
+ #
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+ #
7
+ # * Redistributions of source code must retain the above copyright
8
+ # notice, this list of conditions and the following disclaimer.
9
+ #
10
+ # * Redistributions in binary form must reproduce the above copyright
11
+ # notice, this list of conditions and the following disclaimer in the
12
+ # documentation and/or other materials provided with the distribution.
13
+ #
14
+ # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15
+ # its contributors may be used to endorse or promote products derived
16
+ # from this software without specific prior written permission.
17
+ #
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28
+ # POSSIBILITY OF SUCH DAMAGE.
29
+ #
30
+ # Author: Johannes L. Schoenberger (jsch at inf.ethz.ch)
31
+
32
+ import os
33
+ import collections
34
+ import numpy as np
35
+ import struct
36
+
37
+
38
# Immutable record types used by the readers below.
# Description of one camera model: numeric id, name, parameter count.
CameraModel = collections.namedtuple(
    "CameraModel",
    ("model_id", "model_name", "num_params"),
)
# One camera: id, model name, image size and its parameter array.
Camera = collections.namedtuple(
    "Camera",
    ("id", "model", "width", "height", "params"),
)
# Raw image record; the type name is "Image" although it is bound to
# BaseImage here.
BaseImage = collections.namedtuple(
    "Image",
    ("id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"),
)
# One 3D point together with its observation track.
Point3D = collections.namedtuple(
    "Point3D",
    ("id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"),
)
46
+
47
class Image(BaseImage):
    """``BaseImage`` record extended with a rotation-matrix helper."""

    def qvec2rotmat(self):
        """Return ``qvec2rotmat`` applied to this record's ``qvec``."""
        rotation = qvec2rotmat(self.qvec)
        return rotation
50
+
51
+
52
# Set of known camera models, built from (model_id, model_name, num_params)
# triples.
CAMERA_MODELS = {
    CameraModel(model_id=model_id, model_name=model_name, num_params=num_params)
    for model_id, model_name, num_params in (
        (0, "SIMPLE_PINHOLE", 3),
        (1, "PINHOLE", 4),
        (2, "SIMPLE_RADIAL", 4),
        (3, "RADIAL", 5),
        (4, "OPENCV", 8),
        (5, "OPENCV_FISHEYE", 8),
        (6, "FULL_OPENCV", 12),
        (7, "FOV", 5),
        (8, "SIMPLE_RADIAL_FISHEYE", 4),
        (9, "RADIAL_FISHEYE", 5),
        (10, "THIN_PRISM_FISHEYE", 12),
    )
}
# Lookup table from numeric model id to its CameraModel record.
CAMERA_MODEL_IDS = {
    camera_model.model_id: camera_model for camera_model in CAMERA_MODELS
}
67
+
68
+
69
def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
    """Consume ``num_bytes`` from a binary stream and unpack them.

    :param fid: Object with a ``read(n)`` method returning bytes.
    :param num_bytes: Number of bytes to read; must equal the size implied
        by ``format_char_sequence``.
    :param format_char_sequence: ``struct`` format characters, without the
        byte-order prefix.
    :param endian_character: Any of {@, =, <, >, !}
    :return: Tuple of unpacked values.
    """
    raw = fid.read(num_bytes)
    layout = endian_character + format_char_sequence
    return struct.unpack(layout, raw)
79
+
80
+
81
def read_cameras_text(path):
    """
    Parse a text cameras file into a dict of ``Camera`` records.

    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasText(const std::string& path)
        void Reconstruction::ReadCamerasText(const std::string& path)

    Blank lines and lines starting with ``#`` are ignored. Every other
    line is split on whitespace as: id, model, width, height, params...

    :param path: Path of the text file.
    :return: Dict mapping camera id to ``Camera``.
    """
    cameras = {}
    with open(path, "r") as fid:
        for raw_line in fid:
            line = raw_line.strip()
            if not line or line[0] == "#":
                continue
            fields = line.split()
            camera_id = int(fields[0])
            # Everything after the first four columns is a model parameter.
            params = np.array(tuple(map(float, fields[4:])))
            cameras[camera_id] = Camera(id=camera_id, model=fields[1],
                                        width=int(fields[2]),
                                        height=int(fields[3]),
                                        params=params)
    return cameras
105
+
106
+
107
def read_cameras_binary(path_to_model_file):
    """
    Parse a binary cameras file into a dict of ``Camera`` records.

    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasBinary(const std::string& path)
        void Reconstruction::ReadCamerasBinary(const std::string& path)

    :param path_to_model_file: Path of the binary file.
    :return: Dict mapping camera id to ``Camera``.
    """
    cameras = {}
    with open(path_to_model_file, "rb") as fid:
        # Leading uint64: number of camera records that follow.
        num_cameras = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_cameras):
            header = read_next_bytes(
                fid, num_bytes=24, format_char_sequence="iiQQ")
            camera_id, model_id = header[0], header[1]
            width, height = header[2], header[3]
            model_name = CAMERA_MODEL_IDS[model_id].model_name
            # The parameter count depends on the camera model.
            num_params = CAMERA_MODEL_IDS[model_id].num_params
            params = read_next_bytes(fid, num_bytes=8*num_params,
                                     format_char_sequence="d"*num_params)
            cameras[camera_id] = Camera(id=camera_id,
                                        model=model_name,
                                        width=width,
                                        height=height,
                                        params=np.array(params))
        assert len(cameras) == num_cameras
    return cameras
134
+
135
+
136
def read_images_text(path):
    """
    Parse a text images file into a dict of ``Image`` records.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesText(const std::string& path)
        void Reconstruction::WriteImagesText(const std::string& path)

    Each image occupies two consecutive lines: a header (id, qvec, tvec,
    camera id, name) followed by a flat list of (x, y, point3D id)
    triples.

    :param path: Path of the text file.
    :return: Dict mapping image id to ``Image``.
    """
    images = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if not line or line[0] == "#":
                continue
            header = line.split()
            image_id = int(header[0])
            qvec = np.array(tuple(map(float, header[1:5])))
            tvec = np.array(tuple(map(float, header[5:8])))
            camera_id = int(header[8])
            image_name = header[9]
            # The line right after the header holds this image's 2D points.
            observations = fid.readline().split()
            xs = tuple(map(float, observations[0::3]))
            ys = tuple(map(float, observations[1::3]))
            xys = np.column_stack([xs, ys])
            point3D_ids = np.array(tuple(map(int, observations[2::3])))
            images[image_id] = Image(
                id=image_id, qvec=qvec, tvec=tvec,
                camera_id=camera_id, name=image_name,
                xys=xys, point3D_ids=point3D_ids)
    return images
165
+
166
+
167
def read_images_binary(path_to_model_file):
    """
    Parse a binary images file into a dict of ``Image`` records.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesBinary(const std::string& path)
        void Reconstruction::WriteImagesBinary(const std::string& path)

    :param path_to_model_file: Path of the binary file.
    :return: Dict mapping image id to ``Image``.
    """
    images = {}
    with open(path_to_model_file, "rb") as fid:
        # Leading uint64: number of image records that follow.
        num_reg_images = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_reg_images):
            binary_image_properties = read_next_bytes(
                fid, num_bytes=64, format_char_sequence="idddddddi")
            image_id = binary_image_properties[0]
            qvec = np.array(binary_image_properties[1:5])
            tvec = np.array(binary_image_properties[5:8])
            camera_id = binary_image_properties[8]
            # The name is a NUL-terminated byte string. Collect the raw
            # bytes and decode once: decoding byte by byte raises on any
            # multi-byte UTF-8 character.
            binary_image_name = b""
            current_char = read_next_bytes(fid, 1, "c")[0]
            while current_char != b"\x00":
                binary_image_name += current_char
                current_char = read_next_bytes(fid, 1, "c")[0]
            image_name = binary_image_name.decode("utf-8")
            num_points2D = read_next_bytes(fid, num_bytes=8,
                                           format_char_sequence="Q")[0]
            # Flat sequence of (x, y, point3D id) triples.
            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
                                       format_char_sequence="ddq"*num_points2D)
            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
                                   tuple(map(float, x_y_id_s[1::3]))])
            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
            images[image_id] = Image(
                id=image_id, qvec=qvec, tvec=tvec,
                camera_id=camera_id, name=image_name,
                xys=xys, point3D_ids=point3D_ids)
    return images
200
+
201
+
202
def read_points3D_text(path):
    """
    Parse a text points3D file into a dict of ``Point3D`` records.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DText(const std::string& path)
        void Reconstruction::WritePoints3DText(const std::string& path)

    Blank lines and lines starting with ``#`` are ignored. Every other
    line is: id, xyz (3), rgb (3), error, then alternating
    (image id, point2D index) pairs.

    :param path: Path of the text file.
    :return: Dict mapping point id to ``Point3D``.
    """
    points3D = {}
    with open(path, "r") as fid:
        for raw_line in fid:
            line = raw_line.strip()
            if not line or line[0] == "#":
                continue
            fields = line.split()
            point3D_id = int(fields[0])
            xyz = np.array(tuple(map(float, fields[1:4])))
            rgb = np.array(tuple(map(int, fields[4:7])))
            error = float(fields[7])
            # The track alternates image ids and 2D point indices.
            track = fields[8:]
            image_ids = np.array(tuple(map(int, track[0::2])))
            point2D_idxs = np.array(tuple(map(int, track[1::2])))
            points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
                                           error=error, image_ids=image_ids,
                                           point2D_idxs=point2D_idxs)
    return points3D
227
+
228
+
229
def read_points3d_binary(path_to_model_file):
    """
    Parse a binary points3D file into a dict of ``Point3D`` records.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)

    :param path_to_model_file: Path of the binary file.
    :return: Dict mapping point id to ``Point3D``.
    """
    points3D = {}
    with open(path_to_model_file, "rb") as fid:
        # Leading uint64: number of point records that follow.
        num_points = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_points):
            record = read_next_bytes(
                fid, num_bytes=43, format_char_sequence="QdddBBBd")
            point3D_id = record[0]
            xyz = np.array(record[1:4])
            rgb = np.array(record[4:7])
            error = np.array(record[7])
            track_length = read_next_bytes(
                fid, num_bytes=8, format_char_sequence="Q")[0]
            # The track is a flat run of (image id, point2D index) pairs.
            track_elems = read_next_bytes(
                fid, num_bytes=8*track_length,
                format_char_sequence="ii"*track_length)
            image_ids = np.array(tuple(map(int, track_elems[0::2])))
            point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
            points3D[point3D_id] = Point3D(
                id=point3D_id, xyz=xyz, rgb=rgb,
                error=error, image_ids=image_ids,
                point2D_idxs=point2D_idxs)
    return points3D
257
+
258
+
259
def read_model(path, ext):
    """Read the cameras, images and points3D files found in ``path``.

    :param path: Directory containing ``cameras``, ``images`` and
        ``points3D`` files with extension ``ext``.
    :param ext: ``".txt"`` selects the text readers; any other value
        selects the binary readers.
    :return: Tuple ``(cameras, images, points3D)``.
    """
    if ext == ".txt":
        readers = (read_cameras_text, read_images_text, read_points3D_text)
    else:
        readers = (read_cameras_binary, read_images_binary, read_points3d_binary)
    load_cameras, load_images, load_points = readers

    cameras = load_cameras(os.path.join(path, "cameras" + ext))
    images = load_images(os.path.join(path, "images" + ext))
    points3D = load_points(os.path.join(path, "points3D") + ext)
    return cameras, images, points3D
269
+
270
+
271
def qvec2rotmat(qvec):
    """Build a 3x3 rotation matrix from a quaternion.

    :param qvec: Indexable with at least four entries; entry 0 is used as
        the scalar part and entries 1..3 as the vector part.
    :return: ``np.ndarray`` of shape (3, 3).
    """
    w, x, y, z = qvec[0], qvec[1], qvec[2], qvec[3]
    row0 = [1 - 2 * y**2 - 2 * z**2,
            2 * x * y - 2 * w * z,
            2 * z * x + 2 * w * y]
    row1 = [2 * x * y + 2 * w * z,
            1 - 2 * x**2 - 2 * z**2,
            2 * y * z - 2 * w * x]
    row2 = [2 * z * x - 2 * w * y,
            2 * y * z + 2 * w * x,
            1 - 2 * x**2 - 2 * y**2]
    return np.array([row0, row1, row2])
282
+
283
+
284
def rotmat2qvec(R):
    """Recover a quaternion from a 3x3 rotation matrix.

    The quaternion is taken from the eigenvector belonging to the largest
    eigenvalue of a symmetric 4x4 matrix built from the entries of ``R``,
    and its sign is chosen so that the first component is non-negative.

    :param R: Array with nine entries, read through ``R.flat``.
    :return: ``np.ndarray`` of four entries, first entry >= 0.
    """
    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
    # Only the lower triangle is filled; np.linalg.eigh reads that half.
    lower_rows = [
        [Rxx - Ryy - Rzz, 0, 0, 0],
        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz],
    ]
    K = np.array(lower_rows) / 3.0
    eigvals, eigvecs = np.linalg.eigh(K)
    dominant = np.argmax(eigvals)
    # Reorder so that the last eigenvector component comes first.
    qvec = eigvecs[[3, 0, 1, 2], dominant]
    if qvec[0] < 0:
        qvec *= -1
    return qvec
mesh_recon/datasets/dtu.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import math
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch.utils.data import Dataset, DataLoader, IterableDataset
11
+ import torchvision.transforms.functional as TF
12
+
13
+ import pytorch_lightning as pl
14
+
15
+ import datasets
16
+ from models.ray_utils import get_ray_directions
17
+ from utils.misc import get_rank
18
+
19
+
20
def load_K_Rt_from_P(P=None):
    """Split a projection matrix into 4x4 intrinsics and a 4x4 pose.

    Uses ``cv2.decomposeProjectionMatrix``. The calibration matrix is
    scaled so that its bottom-right entry is 1. The pose is assembled from
    the transposed rotation and the de-homogenised third output of the
    decomposition.

    :param P: Projection matrix passed straight to OpenCV.
    :return: Tuple ``(intrinsics, pose)``; ``pose`` is float32.
    """
    decomposition = cv2.decomposeProjectionMatrix(P)
    calib, rotation, homog = decomposition[0], decomposition[1], decomposition[2]

    calib = calib / calib[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = calib

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = rotation.transpose()
    # Divide by the fourth homogeneous coordinate and drop the column axis.
    pose[:3, 3] = (homog[:3] / homog[3])[:, 0]

    return intrinsics, pose
35
+
36
def create_spheric_poses(cameras, n_steps=120):
    """Generate look-at-origin camera poses sweeping across the input cameras.

    Each pose sits at ``cam_center * cos(theta) + rot_dir * sin(theta)``
    and looks at the origin. ``theta`` runs over ``n_steps`` evenly spaced
    values in ``[-max_angle, max_angle]``, where ``max_angle`` is the
    largest angle between any input position and the mean position.

    :param cameras: Tensor of camera positions, one 3-vector per row.
    :param n_steps: Number of poses to generate.
    :return: Tensor of shape ``(n_steps, 3, 4)``; the columns of each pose
        are ``[s, u, -l, position]``.
    """
    center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device)
    cam_center = F.normalize(cameras.mean(0), p=2, dim=-1) * cameras.mean(0).norm(2)
    eigvecs = torch.linalg.eig(cameras.T @ cameras).eigenvectors
    rot_axis = F.normalize(eigvecs[:,1].real.float(), p=2, dim=-1)
    up = rot_axis
    # Pass dim explicitly: calling cross without it is deprecated in torch.
    rot_dir = torch.cross(rot_axis, cam_center, dim=-1)
    max_angle = (F.normalize(cameras, p=2, dim=-1) * F.normalize(cam_center, p=2, dim=-1)).sum(-1).acos().max()

    all_c2w = []
    for theta in torch.linspace(-max_angle, max_angle, n_steps):
        cam_pos = cam_center * math.cos(theta) + rot_dir * math.sin(theta)
        # Look-at frame: l points at the origin, s to the side, u upwards.
        l = F.normalize(center - cam_pos, p=2, dim=0)
        s = F.normalize(torch.cross(l, up, dim=0), p=2, dim=0)
        u = F.normalize(torch.cross(s, l, dim=0), p=2, dim=0)
        c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:,None]], dim=1)
        all_c2w.append(c2w)

    all_c2w = torch.stack(all_c2w, dim=0)

    return all_c2w
57
+
58
class DTUDatasetBase():
    """Shared loader for the DTU-style dataset classes in this file.

    ``setup`` reads the camera archive, derives per-image ray directions
    and camera-to-world matrices, optionally loads images and masks, and
    moves the resulting tensors to ``self.rank``.
    """

    def setup(self, config, split):
        """Load cameras (and, for train/val, images and masks) for ``split``.

        Reads from ``config``: ``root_dir``, ``cameras_file``, either
        ``img_wh`` or ``img_downscale``, ``apply_mask`` and, for the test
        split, ``n_test_traj_steps``.

        Raises ``KeyError`` when neither ``img_wh`` nor ``img_downscale``
        is present in the config.
        """
        self.config = config
        self.split = split
        self.rank = get_rank()

        cams = np.load(os.path.join(self.config.root_dir, self.config.cameras_file))

        # The first image is only read to learn the native resolution.
        img_sample = cv2.imread(os.path.join(self.config.root_dir, 'image', '000000.png'))
        H, W = img_sample.shape[0], img_sample.shape[1]

        if 'img_wh' in self.config:
            w, h = self.config.img_wh
            # The requested size must keep the native aspect ratio.
            assert round(W / w * h) == H
        elif 'img_downscale' in self.config:
            # Adding 0.5 before int() rounds to the nearest pixel count.
            w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5)
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (w, h)
        # Scale applied to the intrinsics to match the resized images.
        self.factor = w / W

        mask_dir = os.path.join(self.config.root_dir, 'mask')
        self.has_mask = True
        self.apply_mask = self.config.apply_mask

        self.directions = []
        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        # Archive keys end in '_<index>'; the largest index gives the count.
        n_images = max([int(k.split('_')[-1]) for k in cams.keys()]) + 1

        for i in range(n_images):
            world_mat, scale_mat = cams[f'world_mat_{i}'], cams[f'scale_mat_{i}']
            P = (world_mat @ scale_mat)[:3,:4]
            K, c2w = load_K_Rt_from_P(P)
            fx, fy, cx, cy = K[0,0] * self.factor, K[1,1] * self.factor, K[0,2] * self.factor, K[1,2] * self.factor
            directions = get_ray_directions(w, h, fx, fy, cx, cy)
            self.directions.append(directions)

            c2w = torch.from_numpy(c2w).float()

            # blender follows opengl camera coordinates (right up back)
            # NeuS DTU data coordinate system (right down front) is different from blender
            # https://github.com/Totoro97/NeuS/issues/9
            # for c2w, flip the sign of input camera coordinate yz
            c2w_ = c2w.clone()
            c2w_[:3,1:3] *= -1. # flip input sign
            self.all_c2w.append(c2w_[:3,:4])

            if self.split in ['train', 'val']:
                img_path = os.path.join(self.config.root_dir, 'image', f'{i:06d}.png')
                img = Image.open(img_path)
                img = img.resize(self.img_wh, Image.BICUBIC)
                # Channels-last layout; any alpha channel is dropped.
                img = TF.to_tensor(img).permute(1, 2, 0)[...,:3]

                # NOTE(review): mask names use 3 digits while image names
                # use 6 - confirm this matches the dataset layout.
                mask_path = os.path.join(mask_dir, f'{i:03d}.png')
                mask = Image.open(mask_path).convert('L') # (H, W, 1)
                mask = mask.resize(self.img_wh, Image.BICUBIC)
                mask = TF.to_tensor(mask)[0]

                self.all_fg_masks.append(mask) # (h, w)
                self.all_images.append(img)

        self.all_c2w = torch.stack(self.all_c2w, dim=0)

        if self.split == 'test':
            # The test split renders a synthetic trajectory: poses are
            # generated from the camera positions, images and masks are
            # zero placeholders, and the first view's directions are reused.
            self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps)
            self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32)
            self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32)
            self.directions = self.directions[0]
        else:
            # NOTE(review): images are only loaded for 'train'/'val'; any
            # other non-test split reaches torch.stack with empty lists.
            self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0), torch.stack(self.all_fg_masks, dim=0)
            self.directions = torch.stack(self.directions, dim=0)

        # Move everything to the device identified by self.rank.
        self.directions = self.directions.float().to(self.rank)
        self.all_c2w, self.all_images, self.all_fg_masks = \
            self.all_c2w.float().to(self.rank), \
            self.all_images.float().to(self.rank), \
            self.all_fg_masks.float().to(self.rank)
138
+
139
+
140
class DTUDataset(Dataset, DTUDatasetBase):
    """Map-style dataset with one entry per loaded image.

    All loading is delegated to ``setup``; each item is a dict that holds
    only the requested index.
    """

    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        # One sample per entry of the image collection built by ``setup``.
        n_images = len(self.all_images)
        return n_images

    def __getitem__(self, index):
        item = {'index': index}
        return item
151
+
152
+
153
class DTUIterableDataset(IterableDataset, DTUDatasetBase):
    """Iterable dataset that yields empty dicts forever.

    All loading is delegated to ``setup``; the stream itself carries no
    data.
    """

    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        # Never terminates: a fresh empty dict is produced on every step.
        while True:
            placeholder = {}
            yield placeholder
160
+
161
+
162
@datasets.register('dtu')
class DTUDataModule(pl.LightningDataModule):
    """Lightning data module registered under the name ``'dtu'``.

    Builds the train/val/test/predict datasets from one shared config and
    wraps each of them in a ``DataLoader`` with batch size 1.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (all of them if None)."""
        if stage in [None, 'fit']:
            self.train_dataset = DTUIterableDataset(self.config, 'train')
        if stage in [None, 'fit', 'validate']:
            self.val_dataset = DTUDataset(self.config, self.config.get('val_split', 'train'))
        if stage in [None, 'test']:
            self.test_dataset = DTUDataset(self.config, self.config.get('test_split', 'test'))
        if stage in [None, 'predict']:
            self.predict_dataset = DTUDataset(self.config, 'train')

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        """Wrap ``dataset`` in a ``DataLoader`` using every available CPU."""
        # os.cpu_count() returns None when the count cannot be determined;
        # DataLoader requires an int, so fall back to in-process loading.
        num_workers = os.cpu_count() or 0
        return DataLoader(
            dataset,
            num_workers=num_workers,
            batch_size=batch_size,
            pin_memory=True,
            sampler=None,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
mesh_recon/datasets/fixed_poses/000_back_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ -1.000000238418579102e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
2
+ 0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 1.746665105883948854e-07
3
+ 0.000000000000000000e+00 1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00
mesh_recon/datasets/fixed_poses/000_back_left_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ -7.071069478988647461e-01 -7.071068286895751953e-01 0.000000000000000000e+00 -1.192092895507812500e-07
2
+ 0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 9.863901340168013121e-08
3
+ -7.071068286895751953e-01 7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
mesh_recon/datasets/fixed_poses/000_back_right_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ -7.071069478988647461e-01 7.071068286895751953e-01 0.000000000000000000e+00 1.192092895507812500e-07
2
+ 0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 9.863901340168013121e-08
3
+ 7.071068286895751953e-01 7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
mesh_recon/datasets/fixed_poses/000_front_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
2
+ 0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 -1.746665105883948854e-07
3
+ 0.000000000000000000e+00 -1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00
mesh_recon/datasets/fixed_poses/000_front_left_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 7.071067690849304199e-01 -7.071068286895751953e-01 0.000000000000000000e+00 -1.192092895507812500e-07
2
+ 0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 -9.863901340168013121e-08
3
+ -7.071068286895751953e-01 -7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
mesh_recon/datasets/fixed_poses/000_front_right_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 7.071067690849304199e-01 7.071068286895751953e-01 0.000000000000000000e+00 1.192092895507812500e-07
2
+ 0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 -9.863901340168013121e-08
3
+ 7.071068286895751953e-01 -7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
mesh_recon/datasets/fixed_poses/000_left_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ -2.220446049250313081e-16 -1.000000000000000000e+00 0.000000000000000000e+00 -2.886579758146288598e-16
2
+ 0.000000000000000000e+00 -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00
3
+ -1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00
mesh_recon/datasets/fixed_poses/000_right_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00 2.886579758146288598e-16
2
+ 0.000000000000000000e+00 -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00
3
+ 1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00
mesh_recon/datasets/fixed_poses/000_top_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
2
+ 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
3
+ 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 -1.299999952316284180e+00
mesh_recon/datasets/ortho.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import math
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch.utils.data import Dataset, DataLoader, IterableDataset
11
+ import torchvision.transforms.functional as TF
12
+
13
+ import pytorch_lightning as pl
14
+
15
+ import datasets
16
+ from models.ray_utils import get_ortho_ray_directions_origins, get_ortho_rays, get_ray_directions
17
+ from utils.misc import get_rank
18
+
19
+ from glob import glob
20
+ import PIL.Image
21
+
22
+
23
def camNormal2worldNormal(rot_c2w, camNormal):
    """Rotate every normal of an (H, W, 3) map by ``rot_c2w``.

    :param rot_c2w: 3x3 matrix applied to each normal.
    :param camNormal: Array of shape (H, W, 3).
    :return: Array of shape (H, W, 3) holding the rotated normals.
    """
    height, width, _ = camNormal.shape
    # One column vector per pixel, multiplied by the broadcast rotation.
    column_normals = camNormal.reshape(-1,3)[:, :, None]
    rotated = np.matmul(rot_c2w[None, :, :], column_normals)
    normal_img = rotated.reshape([height, width, 3])

    return normal_img
28
+
29
def worldNormal2camNormal(rot_w2c, worldNormal):
    """Rotate every normal of an (H, W, 3) map by ``rot_w2c``.

    :param rot_w2c: 3x3 matrix applied to each normal.
    :param worldNormal: Array of shape (H, W, 3).
    :return: Array of shape (H, W, 3) holding the rotated normals.
    """
    height, width, _ = worldNormal.shape
    # One column vector per pixel, multiplied by the broadcast rotation.
    column_normals = worldNormal.reshape(-1,3)[:, :, None]
    rotated = np.matmul(rot_w2c[None, :, :], column_normals)
    normal_img = rotated.reshape([height, width, 3])

    return normal_img
34
+
35
def trans_normal(normal, RT_w2c, RT_w2c_target):
    """Re-express a normal map from one camera frame in another.

    The normals are taken to world space with the inverse of the source
    rotation, then into the target camera with the target rotation.

    :param normal: Array of shape (H, W, 3) in the source camera frame.
    :param RT_w2c: Source matrix; only its top-left 3x3 block is used.
    :param RT_w2c_target: Target matrix; only its top-left 3x3 block is used.
    :return: Array of shape (H, W, 3) in the target camera frame.
    """
    source_rot_c2w = np.linalg.inv(RT_w2c[:3,:3])
    normal_world = camNormal2worldNormal(source_rot_c2w, normal)

    target_rot_w2c = RT_w2c_target[:3,:3]
    normal_target_cam = worldNormal2camNormal(target_rot_w2c, normal_world)

    return normal_target_cam
41
+
42
def img2normal(img):
    """Map pixel values from [0, 255] to normal components in [-1, 1]."""
    unit_range = img/255.
    return unit_range*2-1
44
+
45
def normal2img(normal):
    """Map normal components from [-1, 1] to uint8 pixel values."""
    scaled = (normal*0.5+0.5)*255
    # np.uint8 truncates the fractional part.
    return np.uint8(scaled)
47
+
48
def norm_normalize(normal, dim=-1):
    """Scale vectors to (almost) unit length along axis ``dim``.

    A small constant (1e-6) is added to the norm so that zero vectors do
    not cause a division by zero.
    """
    length = np.linalg.norm(normal, axis=dim, keepdims=True)
    normal = normal/(length+1e-6)

    return normal
53
+
54
def RT_opengl2opencv(RT):
    """Flip the y and z camera axes of a world-to-camera matrix.

    Build the coordinate transform matrix from world to computer vision
    camera:
        R_world2cv = R_bcam2cv @ R_world2bcam
        T_world2cv = R_bcam2cv @ T_world2bcam

    :param RT: Matrix whose first three rows are ``[R | t]``.
    :return: 3x4 array ``[R_world2cv | t_world2cv]``.
    """
    rotation = RT[:3, :3]
    translation = RT[:3, 3]

    # Diagonal sign flip of the y and z axes.
    R_bcam2cv = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)

    R_world2cv = R_bcam2cv @ rotation
    t_world2cv = R_bcam2cv @ translation

    return np.concatenate([R_world2cv, t_world2cv[:,None]], 1)
70
+
71
def normal_opengl2opencv(normal):
    """Flip the y and z components of an (H, W, C) normal map.

    :param normal: Array-like with exactly three dimensions whose last
        axis broadcasts against three components.
    :return: Array of the same shape with channels 1 and 2 negated.
    :raises ValueError: If ``normal`` does not have three dimensions.
    """
    # Unpacking keeps the original requirement of a rank-3 input.
    _height, _width, _channels = np.shape(normal)
    axis_signs = np.array([1, -1, -1], np.float32)
    normal_cv = normal * axis_signs[None, None, :]

    return normal_cv
80
+
81
def inv_RT(RT):
    """Invert a 3x4 rigid transform and return the top 3x4 of the result.

    :param RT: 3x4 array ``[R | t]``.
    :return: First three rows of the inverse of the 4x4 matrix obtained
        by appending the row ``[0, 0, 0, 1]``.
    """
    bottom_row = np.array([[0,0,0,1]])
    homogeneous = np.concatenate([RT, bottom_row], axis=0)
    inverse = np.linalg.inv(homogeneous)

    return inverse[:3, :]
86
+
87
+
88
def load_a_prediction(root_dir, test_object, imSize, view_types, load_color=False, cam_pose_dir=None,
                      normal_system='front', erode_mask=True, camera_type='ortho', cam_params=None):
    """Load per-view normal maps, masks, poses and rays for one object.

    For every entry of ``view_types`` this reads
    ``<root_dir>/<test_object>/normals_000_<view>.png`` (RGBA; the alpha
    channel is the foreground mask), the alpha channel of
    ``masked_colors/rgb_000_<view>.png`` and the pose file
    ``<cam_pose_dir>/000_<view>_RT.txt``.

    :param root_dir: Directory that contains the object folder.
    :param test_object: Name of the object folder.
    :param imSize: Size passed to ``PIL.Image.resize`` and, as
        ``(W, H)``, to the ray helpers.
    :param view_types: Iterable of view names.
    :param load_color: When True, also load the RGB image of each view
        (same path as the normal map with "normals" replaced by "rgb").
        When False, the returned images are the world-space normals
        encoded as uint8.
    :param cam_pose_dir: Directory holding the ``*_RT.txt`` pose files.
    :param normal_system: ``'front'`` if the stored normals are expressed
        in the front view's camera frame, ``'self'`` if each map uses its
        own camera frame.
    :param erode_mask: Unused; kept for interface compatibility.
    :param camera_type: ``'ortho'`` or ``'pinhole'``.
    :param cam_params: ``[fx, fy, cx, cy]``; only read for ``'pinhole'``.
    :return: Tuple of stacked arrays ``(images, masks, normals_cam,
        normals_world, poses_c2w, w2cs, ray_origins, directions,
        color_masks)``.
    :raises ValueError: If ``normal_system`` is not ``'front'`` or ``'self'``.
    :raises Exception: If ``camera_type`` is not supported.
    """
    # Fail early with a clear message; previously an unknown value left
    # ``normal_world`` unbound inside the loop.
    if normal_system not in ('front', 'self'):
        raise ValueError("unknown normal_system: %r (expected 'front' or 'self')" % (normal_system,))

    all_images = []
    all_normals = []
    all_normals_world = []
    all_masks = []
    all_color_masks = []
    all_poses = []
    all_w2cs = []
    directions = []
    ray_origins = []

    RT_front = np.loadtxt(glob(os.path.join(cam_pose_dir, '*_%s_RT.txt'%( 'front')))[0])   # world2cam matrix
    RT_front_cv = RT_opengl2opencv(RT_front)   # convert normal from opengl to opencv
    for idx, view in enumerate(view_types):
        print(os.path.join(root_dir,test_object))
        normal_filepath = os.path.join(root_dir, test_object, 'normals_000_%s.png'%( view))
        # Load key frame
        image = None
        if load_color:  # use bgr
            image = np.array(PIL.Image.open(normal_filepath.replace("normals", "rgb")).resize(imSize))[:, :, :3]

        normal = np.array(PIL.Image.open(normal_filepath).resize(imSize))
        mask = normal[:, :, 3]
        normal = normal[:, :, :3]

        color_mask = np.array(PIL.Image.open(os.path.join(root_dir,test_object, 'masked_colors/rgb_000_%s.png'%( view))).resize(imSize))[:, :, 3]
        invalid_color_mask = color_mask < 255*0.5
        if image is not None:
            # A pixel is only discarded when it is both transparent in the
            # masked colour image and (almost) pure white in the RGB image.
            invalid_white_mask = (image[:, :, 0] > 250) & (image[:, :, 1] > 250) & (image[:, :, 2] > 250)
            invalid_color_mask_final = invalid_color_mask & invalid_white_mask
        else:
            # Without the RGB image the white test cannot be evaluated
            # (the original code crashed here), so rely on alpha alone.
            invalid_color_mask_final = invalid_color_mask
        color_mask = np.logical_not(invalid_color_mask_final)

        RT = np.loadtxt(os.path.join(cam_pose_dir, '000_%s_RT.txt'%( view)))  # world2cam matrix

        normal = img2normal(normal)

        # Zero the normals wherever the alpha channel is exactly 0, then
        # binarise the mask at half intensity.
        normal[mask==0] = [0,0,0]
        mask = mask> (0.5*255)
        if load_color:
            all_images.append(image)

        all_masks.append(mask)
        all_color_masks.append(color_mask)
        RT_cv = RT_opengl2opencv(RT)   # convert normal from opengl to opencv
        all_poses.append(inv_RT(RT_cv))   # cam2world
        all_w2cs.append(RT_cv)

        normal_cam_cv = normal_opengl2opencv(normal)

        if normal_system == 'front':
            print("the loaded normals are defined in the system of front view")
            normal_world = camNormal2worldNormal(inv_RT(RT_front_cv)[:3, :3], normal_cam_cv)
        else:  # normal_system == 'self', validated above
            print("the loaded normals are in their independent camera systems")
            normal_world = camNormal2worldNormal(inv_RT(RT_cv)[:3, :3], normal_cam_cv)
        all_normals.append(normal_cam_cv)
        all_normals_world.append(normal_world)

        if camera_type == 'ortho':
            origins, dirs = get_ortho_ray_directions_origins(W=imSize[0], H=imSize[1])
        elif camera_type == 'pinhole':
            dirs = get_ray_directions(W=imSize[0], H=imSize[1],
                                      fx=cam_params[0], fy=cam_params[1], cx=cam_params[2], cy=cam_params[3])
            origins = dirs # occupy a position
        else:
            raise Exception("not support camera type")
        ray_origins.append(origins)
        directions.append(dirs)

    if not load_color:
        all_images = [normal2img(x) for x in all_normals_world]

    return np.stack(all_images), np.stack(all_masks), np.stack(all_normals), \
        np.stack(all_normals_world), np.stack(all_poses), np.stack(all_w2cs), np.stack(ray_origins), np.stack(directions), np.stack(all_color_masks)
170
+
171
+
172
class OrthoDatasetBase():
    """Shared loader for the multi-view dataset classes in this file.

    ``setup`` copies the relevant config entries onto the instance, loads
    all views through ``load_a_prediction`` and moves the resulting
    tensors to ``self.rank``.
    """

    def setup(self, config, split):
        """Load every view of the configured scene.

        Reads from ``config``: ``root_dir``, ``scene``, ``imSize``,
        ``camera_type``, ``camera_params``, ``view_weights``,
        ``cam_pose_dir`` and ``apply_mask``. ``split`` is stored but does
        not influence what is loaded here.
        """
        self.config = config
        self.split = split
        self.rank = get_rank()

        self.data_dir = self.config.root_dir
        self.object_name = self.config.scene
        self.scene = self.config.scene
        self.imSize = self.config.imSize
        # Colour images are always loaded alongside the normal maps.
        self.load_color = True
        self.img_wh = [self.imSize[0], self.imSize[1]]
        self.w = self.img_wh[0]
        self.h = self.img_wh[1]
        self.camera_type = self.config.camera_type
        self.camera_params = self.config.camera_params  # [fx, fy, cx, cy]

        # Fixed set of views; the loader stacks its outputs in this order.
        self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']

        # One weight per view, expanded to a (n_weights, h, w) map.
        self.view_weights = torch.from_numpy(np.array(self.config.view_weights)).float().to(self.rank).view(-1)
        self.view_weights = self.view_weights.view(-1,1,1).repeat(1, self.h, self.w)

        # Fall back to the pose files shipped next to the dataset code.
        if self.config.cam_pose_dir is None:
            self.cam_pose_dir = "./datasets/fixed_poses"
        else:
            self.cam_pose_dir = self.config.cam_pose_dir

        # The unpacking order must match load_a_prediction's return tuple.
        self.images_np, self.masks_np, self.normals_cam_np, self.normals_world_np, \
            self.pose_all_np, self.w2c_all_np, self.origins_np, self.directions_np, self.rgb_masks_np = load_a_prediction(
                self.data_dir, self.object_name, self.imSize, self.view_types,
                self.load_color, self.cam_pose_dir, normal_system='front',
                camera_type=self.camera_type, cam_params=self.camera_params)

        self.has_mask = True
        self.apply_mask = self.config.apply_mask

        self.all_c2w = torch.from_numpy(self.pose_all_np)
        # Pixel values are rescaled from [0, 255] to [0, 1].
        self.all_images = torch.from_numpy(self.images_np) / 255.
        self.all_fg_masks = torch.from_numpy(self.masks_np)
        self.all_rgb_masks = torch.from_numpy(self.rgb_masks_np)
        self.all_normals_world = torch.from_numpy(self.normals_world_np)
        self.origins = torch.from_numpy(self.origins_np)
        self.directions = torch.from_numpy(self.directions_np)

        # Cast to float32 and move everything to the device for self.rank.
        self.directions = self.directions.float().to(self.rank)
        self.origins = self.origins.float().to(self.rank)
        self.all_rgb_masks = self.all_rgb_masks.float().to(self.rank)
        self.all_c2w, self.all_images, self.all_fg_masks, self.all_normals_world = \
            self.all_c2w.float().to(self.rank), \
            self.all_images.float().to(self.rank), \
            self.all_fg_masks.float().to(self.rank), \
            self.all_normals_world.float().to(self.rank)
224
+
225
+
226
class OrthoDataset(Dataset, OrthoDatasetBase):
    """Map-style dataset with one entry per loaded view.

    All loading is delegated to ``setup``; each item is a dict that holds
    only the requested index.
    """

    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        # One sample per entry of the image tensor built by ``setup``.
        n_images = len(self.all_images)
        return n_images

    def __getitem__(self, index):
        item = {'index': index}
        return item
237
+
238
+
239
class OrthoIterableDataset(IterableDataset, OrthoDatasetBase):
    """Iterable dataset that yields empty dicts forever.

    All loading is delegated to ``setup``; the stream itself carries no
    data.
    """

    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        # Never terminates: a fresh empty dict is produced on every step.
        while True:
            placeholder = {}
            yield placeholder
246
+
247
+
248
@datasets.register('ortho')
class OrthoDataModule(pl.LightningDataModule):
    """Lightning data module registered under the name ``'ortho'``.

    Builds the train/val/test/predict datasets from one shared config and
    wraps each of them in a ``DataLoader`` with batch size 1.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (all of them if None)."""
        if stage in [None, 'fit']:
            self.train_dataset = OrthoIterableDataset(self.config, 'train')
        if stage in [None, 'fit', 'validate']:
            self.val_dataset = OrthoDataset(self.config, self.config.get('val_split', 'train'))
        if stage in [None, 'test']:
            self.test_dataset = OrthoDataset(self.config, self.config.get('test_split', 'test'))
        if stage in [None, 'predict']:
            self.predict_dataset = OrthoDataset(self.config, 'train')

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        """Wrap ``dataset`` in a ``DataLoader`` using every available CPU."""
        # os.cpu_count() returns None when the count cannot be determined;
        # DataLoader requires an int, so fall back to in-process loading.
        num_workers = os.cpu_count() or 0
        return DataLoader(
            dataset,
            num_workers=num_workers,
            batch_size=batch_size,
            pin_memory=True,
            sampler=None,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
mesh_recon/datasets/utils.py ADDED
File without changes
mesh_recon/datasets/v3d.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import math
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset, DataLoader, IterableDataset
9
+ import torchvision.transforms.functional as TF
10
+ from torchvision.utils import make_grid, save_image
11
+ from einops import rearrange
12
+ from mediapy import read_video
13
+ from pathlib import Path
14
+ from rembg import remove, new_session
15
+
16
+ import pytorch_lightning as pl
17
+
18
+ import datasets
19
+ from models.ray_utils import get_ray_directions
20
+ from utils.misc import get_rank
21
+ from datasets.ortho import (
22
+ inv_RT,
23
+ camNormal2worldNormal,
24
+ RT_opengl2opencv,
25
+ normal_opengl2opencv,
26
+ )
27
+ from utils.dpt import DPT
28
+
29
+
30
def get_c2w_from_up_and_look_at(
    up,
    look_at,
    pos,
    opengl=False,
):
    """Build a 4x4 camera-to-world matrix from an up vector, target and position.

    The z axis points from ``pos`` towards ``look_at``. The y axis starts as
    the negated, normalised ``up`` and is then re-orthogonalised against z.

    Args:
        up: numpy array with the up direction (need not be unit length).
        look_at: numpy array with the point the camera looks at.
        pos: numpy array with the camera position.
        opengl: when true, the y and z axis columns are negated
            (OpenCV to OpenGL convention).

    Returns:
        ``np.float32`` array of shape (4, 4).
    """
    forward = look_at - pos
    forward = forward / np.linalg.norm(forward)

    down = -(up / np.linalg.norm(up))
    right = np.cross(down, forward)
    right = right / np.linalg.norm(right)
    # Recompute y so that the three axes are mutually orthogonal.
    down = np.cross(forward, right)

    c2w = np.eye(4, dtype=np.float32)
    c2w[:3, :3] = np.stack([right, down, forward], axis=1)
    c2w[:3, 3] = pos

    # opencv to opengl
    if opengl:
        c2w[:, 1:3] = -c2w[:, 1:3]

    return c2w
56
+
57
+
58
def get_uniform_poses(num_frames, radius, elevation, opengl=False):
    """Return camera-to-world matrices for cameras evenly spaced on a circle.

    Azimuths cover [0, 360) degrees in ``num_frames`` equal steps. Every
    camera sits at distance ``radius`` from the origin, at ``elevation``
    degrees, looks at the origin, and uses +z as the up direction.

    Args:
        num_frames: number of poses to generate.
        radius: camera distance from the origin.
        elevation: elevation angle in degrees, shared by all cameras.
        opengl: forwarded to ``get_c2w_from_up_and_look_at``.

    Returns:
        Array of shape (num_frames, 4, 4).
    """
    azimuth = np.deg2rad(np.linspace(0, 360, num_frames + 1)[:num_frames])
    elev = np.full_like(azimuth, np.deg2rad(elevation))
    dist = np.full_like(azimuth, radius)

    # Spherical to Cartesian conversion.
    planar = dist * np.cos(elev)
    cam_positions = np.stack(
        [
            planar * np.cos(azimuth),
            planar * np.sin(azimuth),
            dist * np.sin(elev),
        ],
        axis=-1,
    )

    origin = np.array([0, 0, 0], dtype=np.float32)
    world_up = np.array([0, 0, 1], dtype=np.float32)
    matrices = [
        get_c2w_from_up_and_look_at(world_up, origin, cam, opengl=opengl)
        for cam in cam_positions
    ]
    return np.stack(matrices, axis=0)
80
+
81
+
82
def blender2midas(img):
    """Negate channels 0, 1 and -1 of ``img`` in place and return it.

    Original note: Blender: rub, midas: lub.
    """
    for channel in (0, 1, -1):
        img[..., channel] = -img[..., channel]
    return img
90
+
91
+
92
def midas2blender(img):
    """Negate channels 0, 1 and -1 of ``img`` in place and return it.

    Original note: Blender: rub, midas: lub.
    """
    for channel in (0, 1, -1):
        img[..., channel] = -img[..., channel]
    return img
100
+
101
+
102
class BlenderDatasetBase:
    """Shared loader that turns a video into posed, masked views with normals.

    ``setup`` reads the video at ``config.root_dir / config.scene``, passes
    each frame through rembg's ``remove``, assigns the frames evenly spaced
    poses on a circle, runs the DPT helper on the images and stores the
    results as attributes on the instance.
    """

    def setup(self, config, split):
        """Load frames, poses, masks and normals for one scene.

        Args:
            config: dataset config; must provide ``root_dir``, ``scene`` and
                either ``img_wh`` or ``img_downscale``.
            split: accepted for interface compatibility; not used here.

        Raises:
            KeyError: neither ``img_wh`` nor ``img_downscale`` is configured.
            AssertionError: ``img_wh`` does not match the video aspect ratio.
        """
        self.config = config
        self.rank = get_rank()

        self.has_mask = True
        self.apply_mask = True

        # NOTE(review): DPT is a project helper; with mode="normal" it
        # presumably predicts normal maps in [0, 1] -- confirm in utils.dpt.
        dpt = DPT(device=self.rank, mode="normal")

        frames = read_video(Path(self.config.root_dir) / f"{self.config.scene}")
        rembg_session = new_session()
        num_frames, H, W = frames.shape[:3]

        if "img_wh" in self.config:
            w, h = self.config.img_wh
            assert round(W / w * h) == H
        elif "img_downscale" in self.config:
            w, h = W // self.config.img_downscale, H // self.config.img_downscale
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (self.w, self.h)

        # Focal length for a hard-coded 60 degree horizontal field of view.
        self.focal = 0.5 * w / math.tan(0.5 * np.deg2rad(60))  # scaled focal length

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(
            self.w, self.h, self.focal, self.focal, self.w // 2, self.h // 2
        ).to(self.rank)  # (h, w, 3)

        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        # Frames are assigned evenly spaced poses on a level circle.
        radius = 2.0
        elevation = 0.0
        poses = get_uniform_poses(num_frames, radius, elevation, opengl=True)
        for c2w, frame in zip(poses, frames):
            c2w = torch.from_numpy(np.array(c2w)[:3, :4])
            self.all_c2w.append(c2w)

            img = Image.fromarray(frame)
            img = remove(img, session=rembg_session)
            img = img.resize(self.img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0)  # (4, h, w) => (h, w, 4)

            self.all_fg_masks.append(img[..., -1])  # (h, w)
            self.all_images.append(img[..., :3])

        self.all_c2w, self.all_images, self.all_fg_masks = (
            torch.stack(self.all_c2w, dim=0).float().to(self.rank),
            torch.stack(self.all_images, dim=0).float().to(self.rank),
            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank),
        )

        self.normals = dpt(self.all_images)

        self.all_masks = self.all_fg_masks.cpu().numpy() > 0.1

        # Map to [-1, 1], flip axes, then zero out background pixels.
        self.normals = self.normals * 2.0 - 1.0
        self.normals = midas2blender(self.normals).cpu().numpy()
        self.normals[..., 0] *= -1
        self.normals[~self.all_masks] = [0, 0, 0]

        (
            self.all_poses,
            self.all_normals,
            self.all_normals_world,
            self.all_w2cs,
            self.all_color_masks,
        ) = ([], [], [], [], [])

        # Convert every pose and normal map from OpenGL to OpenCV convention.
        for c2w_opengl, normal in zip(self.all_c2w.cpu().numpy(), self.normals):
            RT_opengl = inv_RT(c2w_opengl)
            RT_opencv = RT_opengl2opencv(RT_opengl)
            c2w_opencv = inv_RT(RT_opencv)
            self.all_poses.append(c2w_opencv)
            self.all_w2cs.append(RT_opencv)
            normal = normal_opengl2opencv(normal)
            normal_world = camNormal2worldNormal(inv_RT(RT_opencv)[:3, :3], normal)
            self.all_normals.append(normal)
            self.all_normals_world.append(normal_world)

        self.directions = torch.stack([self.directions] * len(self.all_images))
        # NOTE(review): origins aliases the direction tensor -- confirm intended.
        self.origins = self.directions
        self.all_poses = np.stack(self.all_poses)
        self.all_normals = np.stack(self.all_normals)
        self.all_normals_world = np.stack(self.all_normals_world)
        self.all_w2cs = np.stack(self.all_w2cs)

        self.all_c2w = torch.from_numpy(self.all_poses).float().to(self.rank)
        self.all_images = self.all_images.to(self.rank)
        self.all_fg_masks = self.all_fg_masks.to(self.rank)
        self.all_rgb_masks = self.all_fg_masks.to(self.rank)
        self.all_normals_world = (
            torch.from_numpy(self.all_normals_world).float().to(self.rank)
        )
221
+
222
+
223
class BlenderDataset(Dataset, BlenderDatasetBase):
    """Map-style dataset whose items only carry the view index."""

    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        # One item per loaded image.
        return len(self.all_images)

    def __getitem__(self, index):
        return dict(index=index)
232
+
233
+
234
class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
    """Iterable dataset that yields empty dicts forever."""

    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        # Intentionally endless: the consumer decides when to stop.
        while True:
            yield dict()
241
+
242
+
243
@datasets.register("v3d")
class BlenderDataModule(pl.LightningDataModule):
    """Lightning data module registered under the name ``"v3d"``.

    Split names are read from ``config.train_split``, ``config.val_split``
    and ``config.test_split``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        """Create the datasets required by ``stage`` (all of them when ``None``)."""
        if stage in [None, "fit"]:
            self.train_dataset = BlenderIterableDataset(
                self.config, self.config.train_split
            )
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BlenderDataset(self.config, self.config.val_split)
        if stage in [None, "test"]:
            self.test_dataset = BlenderDataset(self.config, self.config.test_split)
        if stage in [None, "predict"]:
            self.predict_dataset = BlenderDataset(self.config, self.config.train_split)

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        """Wrap ``dataset`` in a ``DataLoader`` with pinned memory and no sampler."""
        # os.cpu_count() may return None when the count cannot be determined;
        # DataLoader needs an int, so fall back to loading in the main process.
        return DataLoader(
            dataset,
            num_workers=os.cpu_count() or 0,
            batch_size=batch_size,
            pin_memory=True,
            sampler=None,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
mesh_recon/datasets/videonvs.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import math
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset, DataLoader, IterableDataset
9
+ import torchvision.transforms.functional as TF
10
+ from torchvision.utils import make_grid, save_image
11
+ from einops import rearrange
12
+
13
+ import pytorch_lightning as pl
14
+
15
+ import datasets
16
+ from models.ray_utils import get_ray_directions
17
+ from utils.misc import get_rank
18
+ from datasets.ortho import (
19
+ inv_RT,
20
+ camNormal2worldNormal,
21
+ RT_opengl2opencv,
22
+ normal_opengl2opencv,
23
+ )
24
+ from utils.dpt import DPT
25
+
26
+
27
def blender2midas(img):
    """Flip the sign of channels 0, 1 and -1 of ``img`` in place; return ``img``.

    Original note: Blender: rub, midas: lub.
    """
    for axis in (0, 1, -1):
        img[..., axis] = -img[..., axis]
    return img
35
+
36
+
37
def midas2blender(img):
    """Flip the sign of channels 0, 1 and -1 of ``img`` in place; return ``img``.

    Original note: Blender: rub, midas: lub.
    """
    for axis in (0, 1, -1):
        img[..., axis] = -img[..., axis]
    return img
45
+
46
+
47
class BlenderDatasetBase:
    """Shared loader for a scene described by ``transforms_train.json``.

    ``setup`` reads the poses and image paths from the JSON file, loads and
    resizes the images, runs the DPT helper on them and stores the results
    as attributes on the instance.
    """

    def setup(self, config, split):
        """Load images, poses, masks and normals for one scene.

        Args:
            config: dataset config; must provide ``root_dir``, ``scene`` and
                either ``img_wh`` or ``img_downscale``.
            split: accepted for interface compatibility; not used here.

        Raises:
            KeyError: neither ``img_wh`` nor ``img_downscale`` is configured.
            AssertionError: ``img_wh`` does not match the source aspect ratio.
        """
        self.config = config
        self.rank = get_rank()

        self.has_mask = True
        self.apply_mask = True

        # NOTE(review): DPT is a project helper; with mode="normal" it
        # presumably predicts normal maps in [0, 1] -- confirm in utils.dpt.
        dpt = DPT(device=self.rank, mode="normal")

        meta_path = os.path.join(
            self.config.root_dir, self.config.scene, "transforms_train.json"
        )
        with open(meta_path, "r") as f:
            meta = json.load(f)

        # Source resolution; 800x800 is assumed when the file does not say.
        if "w" in meta and "h" in meta:
            W, H = int(meta["w"]), int(meta["h"])
        else:
            W, H = 800, 800

        if "img_wh" in self.config:
            w, h = self.config.img_wh
            assert round(W / w * h) == H
        elif "img_downscale" in self.config:
            w, h = W // self.config.img_downscale, H // self.config.img_downscale
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (self.w, self.h)

        self.focal = (
            0.5 * w / math.tan(0.5 * meta["camera_angle_x"])
        )  # scaled focal length

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(
            self.w, self.h, self.focal, self.focal, self.w // 2, self.h // 2
        ).to(self.rank)  # (h, w, 3)

        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        for frame in meta["frames"]:
            c2w = torch.from_numpy(np.array(frame["transform_matrix"])[:3, :4])
            self.all_c2w.append(c2w)

            img_path = os.path.join(
                self.config.root_dir,
                self.config.scene,
                f"{frame['file_path']}.png",
            )
            img = Image.open(img_path)
            img = img.resize(self.img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0)  # (4, h, w) => (h, w, 4)

            self.all_fg_masks.append(img[..., -1])  # (h, w)
            self.all_images.append(img[..., :3])

        self.all_c2w, self.all_images, self.all_fg_masks = (
            torch.stack(self.all_c2w, dim=0).float().to(self.rank),
            torch.stack(self.all_images, dim=0).float().to(self.rank),
            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank),
        )

        self.normals = dpt(self.all_images)

        self.all_masks = self.all_fg_masks.cpu().numpy() > 0.1

        # Map to [-1, 1], flip axes, then zero out background pixels.
        self.normals = self.normals * 2.0 - 1.0
        self.normals = midas2blender(self.normals).cpu().numpy()
        self.normals[..., 0] *= -1
        self.normals[~self.all_masks] = [0, 0, 0]

        # Debug dump of the masked normal maps. The folder is created first
        # because save_image does not create missing parent directories.
        debug_normals = rearrange(self.normals, "b h w c -> b c h w")
        debug_normals = torch.from_numpy(debug_normals * 0.5 + 0.5)
        os.makedirs("tmp", exist_ok=True)
        save_image(make_grid(debug_normals, nrow=4), "tmp/normals.png")

        (
            self.all_poses,
            self.all_normals,
            self.all_normals_world,
            self.all_w2cs,
            self.all_color_masks,
        ) = ([], [], [], [], [])

        # Convert every pose and normal map from OpenGL to OpenCV convention.
        for c2w_opengl, normal in zip(self.all_c2w.cpu().numpy(), self.normals):
            RT_opengl = inv_RT(c2w_opengl)
            RT_opencv = RT_opengl2opencv(RT_opengl)
            c2w_opencv = inv_RT(RT_opencv)
            self.all_poses.append(c2w_opencv)
            self.all_w2cs.append(RT_opencv)
            normal = normal_opengl2opencv(normal)
            normal_world = camNormal2worldNormal(inv_RT(RT_opencv)[:3, :3], normal)
            self.all_normals.append(normal)
            self.all_normals_world.append(normal_world)

        self.directions = torch.stack([self.directions] * len(self.all_images))
        # NOTE(review): origins aliases the direction tensor -- confirm intended.
        self.origins = self.directions
        self.all_poses = np.stack(self.all_poses)
        self.all_normals = np.stack(self.all_normals)
        self.all_normals_world = np.stack(self.all_normals_world)
        self.all_w2cs = np.stack(self.all_w2cs)

        self.all_c2w = torch.from_numpy(self.all_poses).float().to(self.rank)
        self.all_images = self.all_images.to(self.rank)
        self.all_fg_masks = self.all_fg_masks.to(self.rank)
        self.all_rgb_masks = self.all_fg_masks.to(self.rank)
        self.all_normals_world = (
            torch.from_numpy(self.all_normals_world).float().to(self.rank)
        )
193
+
194
+
195
class BlenderDataset(Dataset, BlenderDatasetBase):
    """Map-style dataset; each item is just ``{"index": i}``."""

    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        # One item per loaded image.
        return len(self.all_images)

    def __getitem__(self, index):
        return dict(index=index)
204
+
205
+
206
class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
    """Iterable dataset producing an unbounded stream of empty dicts."""

    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        # Intentionally endless: the consumer decides when to stop.
        while True:
            yield dict()
213
+
214
+
215
@datasets.register("videonvs")
class BlenderDataModule(pl.LightningDataModule):
    """Lightning data module registered under the name ``"videonvs"``.

    Split names are read from ``config.train_split``, ``config.val_split``
    and ``config.test_split``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        """Create the datasets required by ``stage`` (all of them when ``None``)."""
        if stage in [None, "fit"]:
            self.train_dataset = BlenderIterableDataset(
                self.config, self.config.train_split
            )
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BlenderDataset(self.config, self.config.val_split)
        if stage in [None, "test"]:
            self.test_dataset = BlenderDataset(self.config, self.config.test_split)
        if stage in [None, "predict"]:
            self.predict_dataset = BlenderDataset(self.config, self.config.train_split)

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        """Wrap ``dataset`` in a ``DataLoader`` with pinned memory and no sampler."""
        # os.cpu_count() may return None when the count cannot be determined;
        # DataLoader needs an int, so fall back to loading in the main process.
        return DataLoader(
            dataset,
            num_workers=os.cpu_count() or 0,
            batch_size=batch_size,
            pin_memory=True,
            sampler=None,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
mesh_recon/datasets/videonvs_co3d.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import math
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset, DataLoader, IterableDataset
9
+ import torchvision.transforms.functional as TF
10
+ from torchvision.utils import make_grid, save_image
11
+ from einops import rearrange
12
+ from rembg import remove, new_session
13
+
14
+ import pytorch_lightning as pl
15
+
16
+ import datasets
17
+ from models.ray_utils import get_ray_directions
18
+ from utils.misc import get_rank
19
+ from datasets.ortho import (
20
+ inv_RT,
21
+ camNormal2worldNormal,
22
+ RT_opengl2opencv,
23
+ normal_opengl2opencv,
24
+ )
25
+ from utils.dpt import DPT
26
+
27
+
28
def blender2midas(img):
    """Invert channels 0, 1 and -1 of ``img`` in place, then return ``img``.

    Original note: Blender: rub, midas: lub.
    """
    for idx in (0, 1, -1):
        img[..., idx] = -img[..., idx]
    return img
36
+
37
+
38
def midas2blender(img):
    """Invert channels 0, 1 and -1 of ``img`` in place, then return ``img``.

    Original note: Blender: rub, midas: lub.
    """
    for idx in (0, 1, -1):
        img[..., idx] = -img[..., idx]
    return img
46
+
47
+
48
class BlenderDatasetBase:
    """Shared loader for a scene described by ``transforms.json``.

    Each frame entry supplies its own pose, image path and intrinsics
    (``fl_x``, ``fl_y``, ``cx``, ``cy``). ``setup`` loads the images, passes
    them through rembg's ``remove``, runs the DPT helper on them and stores
    the results as attributes on the instance.
    """

    def setup(self, config, split):
        """Load images, poses, rays, masks and normals for one scene.

        Args:
            config: dataset config; must provide ``root_dir``, ``scene`` and
                either ``img_wh`` or ``img_downscale``.
            split: accepted for interface compatibility; not used here.

        Raises:
            KeyError: neither ``img_wh`` nor ``img_downscale`` is configured.
            AssertionError: ``img_wh`` does not match the source aspect ratio.
        """
        self.config = config
        self.rank = get_rank()

        self.has_mask = True
        self.apply_mask = True

        # NOTE(review): DPT is a project helper; with mode="normal" it
        # presumably predicts normal maps in [0, 1] -- confirm in utils.dpt.
        dpt = DPT(device=self.rank, mode="normal")

        self.directions = []
        meta_path = os.path.join(
            self.config.root_dir, self.config.scene, "transforms.json"
        )
        with open(meta_path, "r") as f:
            meta = json.load(f)

        # Source resolution; 800x800 is assumed when the file does not say.
        if "w" in meta and "h" in meta:
            W, H = int(meta["w"]), int(meta["h"])
        else:
            W, H = 800, 800

        if "img_wh" in self.config:
            w, h = self.config.img_wh
            assert round(W / w * h) == H
        elif "img_downscale" in self.config:
            w, h = W // self.config.img_downscale, H // self.config.img_downscale
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (self.w, self.h)

        _session = new_session()
        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        for frame in meta["frames"]:
            c2w = torch.from_numpy(np.array(frame["transform_matrix"])[:3, :4])
            self.all_c2w.append(c2w)

            img_path = os.path.join(
                self.config.root_dir,
                self.config.scene,
                f"{frame['file_path']}",
            )
            img = Image.open(img_path)
            src_w, src_h = img.size
            img = remove(img, session=_session)
            img = img.resize(self.img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0)  # (4, h, w) => (h, w, 4)

            # The stored intrinsics are assumed to be in pixels of the image
            # file, so they must follow the resize to (w, h). When no resize
            # happens both factors are exactly 1.
            scale_x = self.w / src_w
            scale_y = self.h / src_h
            fx = frame["fl_x"] * scale_x
            fy = frame["fl_y"] * scale_y
            cx = frame["cx"] * scale_x
            cy = frame["cy"] * scale_y

            self.all_fg_masks.append(img[..., -1])  # (h, w)
            self.all_images.append(img[..., :3])

            self.directions.append(get_ray_directions(self.w, self.h, fx, fy, cx, cy))

        self.all_c2w, self.all_images, self.all_fg_masks = (
            torch.stack(self.all_c2w, dim=0).float().to(self.rank),
            torch.stack(self.all_images, dim=0).float().to(self.rank),
            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank),
        )

        self.normals = dpt(self.all_images)

        self.all_masks = self.all_fg_masks.cpu().numpy() > 0.1

        # Map to [-1, 1], flip axes, then zero out background pixels.
        self.normals = self.normals * 2.0 - 1.0
        self.normals = midas2blender(self.normals).cpu().numpy()
        self.normals[..., 0] *= -1
        self.normals[~self.all_masks] = [0, 0, 0]

        # Debug dump of the masked normal maps. The folder is created first
        # because save_image does not create missing parent directories.
        debug_normals = rearrange(self.normals, "b h w c -> b c h w")
        debug_normals = torch.from_numpy(debug_normals * 0.5 + 0.5)
        os.makedirs("tmp", exist_ok=True)
        save_image(make_grid(debug_normals, nrow=4), "tmp/normals.png")

        (
            self.all_poses,
            self.all_normals,
            self.all_normals_world,
            self.all_w2cs,
            self.all_color_masks,
        ) = ([], [], [], [], [])

        # Convert every pose and normal map from OpenGL to OpenCV convention.
        for c2w_opengl, normal in zip(self.all_c2w.cpu().numpy(), self.normals):
            RT_opengl = inv_RT(c2w_opengl)
            RT_opencv = RT_opengl2opencv(RT_opengl)
            c2w_opencv = inv_RT(RT_opencv)
            self.all_poses.append(c2w_opencv)
            self.all_w2cs.append(RT_opencv)
            normal = normal_opengl2opencv(normal)
            normal_world = camNormal2worldNormal(inv_RT(RT_opencv)[:3, :3], normal)
            self.all_normals.append(normal)
            self.all_normals_world.append(normal_world)

        self.directions = torch.stack(self.directions).to(self.rank)
        # NOTE(review): origins aliases the direction tensor -- confirm intended.
        self.origins = self.directions
        self.all_poses = np.stack(self.all_poses)
        self.all_normals = np.stack(self.all_normals)
        self.all_normals_world = np.stack(self.all_normals_world)
        self.all_w2cs = np.stack(self.all_w2cs)

        self.all_c2w = torch.from_numpy(self.all_poses).float().to(self.rank)
        self.all_images = self.all_images.to(self.rank)
        self.all_fg_masks = self.all_fg_masks.to(self.rank)
        self.all_rgb_masks = self.all_fg_masks.to(self.rank)
        self.all_normals_world = (
            torch.from_numpy(self.all_normals_world).float().to(self.rank)
        )
189
+
190
+
191
class BlenderDataset(Dataset, BlenderDatasetBase):
    """Map-style dataset that hands out view indices only."""

    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        # One item per loaded image.
        return len(self.all_images)

    def __getitem__(self, index):
        return dict(index=index)
200
+
201
+
202
class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
    """Iterable dataset that never runs out of (empty) items."""

    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        # Intentionally endless: the consumer decides when to stop.
        while True:
            yield dict()
209
+
210
+
211
@datasets.register("videonvs-scene")
class VideoNVSScene(pl.LightningDataModule):
    """Lightning data module registered under the name ``"videonvs-scene"``.

    Split names are read from ``config.train_split``, ``config.val_split``
    and ``config.test_split``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        """Create the datasets required by ``stage`` (all of them when ``None``)."""
        if stage in [None, "fit"]:
            self.train_dataset = BlenderIterableDataset(
                self.config, self.config.train_split
            )
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BlenderDataset(self.config, self.config.val_split)
        if stage in [None, "test"]:
            self.test_dataset = BlenderDataset(self.config, self.config.test_split)
        if stage in [None, "predict"]:
            self.predict_dataset = BlenderDataset(self.config, self.config.train_split)

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        """Wrap ``dataset`` in a ``DataLoader`` with pinned memory and no sampler."""
        # os.cpu_count() may return None when the count cannot be determined;
        # DataLoader needs an int, so fall back to loading in the main process.
        return DataLoader(
            dataset,
            num_workers=os.cpu_count() or 0,
            batch_size=batch_size,
            pin_memory=True,
            sampler=None,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
mesh_recon/launch.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import argparse
3
+ import os
4
+ import time
5
+ import logging
6
+ from datetime import datetime
7
+
8
+
9
def main():
    """Command-line entry point: parse arguments, build the run, start the trainer.

    Exactly one of ``--train``, ``--validate``, ``--test`` or ``--predict``
    must be given. Unrecognised arguments are forwarded to ``load_config``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument("--gpu", default="0", help="GPU(s) to be used")
    parser.add_argument(
        "--resume", default=None, help="path to the weights to be resumed"
    )
    parser.add_argument(
        "--resume_weights_only",
        action="store_true",
        help="specify this argument to restore only the weights (w/o training states), e.g. --resume path/to/resume --resume_weights_only",
    )

    mode = parser.add_mutually_exclusive_group(required=True)
    for flag in ("--train", "--validate", "--test", "--predict"):
        mode.add_argument(flag, action="store_true")
    # TODO: a separate export action

    parser.add_argument("--exp_dir", default="./exp")
    parser.add_argument("--runs_dir", default="./runs")
    parser.add_argument(
        "--verbose", action="store_true", help="if true, set logging level to DEBUG"
    )

    args, unknown = parser.parse_known_args()

    # set CUDA_VISIBLE_DEVICES then import pytorch-lightning
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    n_gpus = args.gpu.count(",") + 1

    import datasets
    import systems
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
    from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
    from utils.callbacks import (
        CodeSnapshotCallback,
        ConfigSnapshotCallback,
        CustomProgressBar,
    )
    from utils.misc import load_config

    # parse YAML config to OmegaConf
    config = load_config(args.config, cli_args=unknown)
    config.cmd_args = vars(args)

    # Fill in run name and output folders unless the config already sets them.
    config.trial_name = config.get("trial_name") or (
        config.tag + datetime.now().strftime("@%Y%m%d-%H%M%S")
    )
    config.exp_dir = config.get("exp_dir") or os.path.join(args.exp_dir, config.name)
    trial_root = os.path.join(config.exp_dir, config.trial_name)
    for sub in ("save", "ckpt", "code", "config"):
        key = sub + "_dir"
        setattr(config, key, config.get(key) or os.path.join(trial_root, sub))

    logger = logging.getLogger("pytorch_lightning")
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    if "seed" not in config:
        # Derive a seed in [0, 1000) from the current time in milliseconds.
        config.seed = int(time.time() * 1000) % 1000
    pl.seed_everything(config.seed)

    dm = datasets.make(config.dataset.name, config.dataset)
    system = systems.make(
        config.system.name,
        config,
        load_from_checkpoint=args.resume if args.resume_weights_only else None,
    )

    # Callbacks and loggers are only attached for training runs.
    callbacks = (
        [
            ModelCheckpoint(dirpath=config.ckpt_dir, **config.checkpoint),
            LearningRateMonitor(logging_interval="step"),
            ConfigSnapshotCallback(config, config.config_dir, use_version=False),
            CustomProgressBar(refresh_rate=1),
        ]
        if args.train
        else []
    )

    loggers = (
        [
            TensorBoardLogger(
                args.runs_dir, name=config.name, version=config.trial_name
            ),
            CSVLogger(config.exp_dir, name=config.trial_name, version="csv_logs"),
        ]
        if args.train
        else []
    )

    if sys.platform == "win32":
        # does not support multi-gpu on windows
        strategy = "dp"
        assert n_gpus == 1
    else:
        strategy = "ddp_find_unused_parameters_false"

    trainer = Trainer(
        devices=n_gpus,
        accelerator="gpu",
        callbacks=callbacks,
        logger=loggers,
        strategy=strategy,
        **config.trainer
    )

    if args.train:
        full_resume = args.resume and not args.resume_weights_only
        if full_resume:
            # FIXME: different behavior in pytorch-lighting>1.9 ?
            trainer.fit(system, datamodule=dm, ckpt_path=args.resume)
        else:
            trainer.fit(system, datamodule=dm)
        trainer.test(system, datamodule=dm)
    elif args.validate:
        trainer.validate(system, datamodule=dm, ckpt_path=args.resume)
    elif args.test:
        trainer.test(system, datamodule=dm, ckpt_path=args.resume)
    elif args.predict:
        trainer.predict(system, datamodule=dm, ckpt_path=args.resume)


if __name__ == "__main__":
    main()
mesh_recon/mesh.py ADDED
@@ -0,0 +1,845 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import trimesh
5
+ import numpy as np
6
+
7
+ from kiui.op import safe_normalize, dot
8
+ from kiui.typing import *
9
+
10
class Mesh:
    """
    A torch-native trimesh class, with support for ``ply/obj/glb`` formats.

    Note:
        This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture).
    """

    def __init__(
        self,
        v: Optional[Tensor] = None,
        f: Optional[Tensor] = None,
        vn: Optional[Tensor] = None,
        fn: Optional[Tensor] = None,
        vt: Optional[Tensor] = None,
        ft: Optional[Tensor] = None,
        vc: Optional[Tensor] = None,  # vertex color
        albedo: Optional[Tensor] = None,
        metallicRoughness: Optional[Tensor] = None,
        device: Optional[torch.device] = None,
    ):
        """Init a mesh directly using all attributes.

        Args:
            v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None.
            f (Optional[Tensor]): faces, int [M, 3]. Defaults to None.
            vn (Optional[Tensor]): vertex normals, float [N, 3]. Defaults to None.
            fn (Optional[Tensor]): faces for normals, int [M, 3]. Defaults to None.
            vt (Optional[Tensor]): vertex uv coordinates, float [N, 2]. Defaults to None.
            ft (Optional[Tensor]): faces for uvs, int [M, 3]. Defaults to None.
            vc (Optional[Tensor]): vertex colors, float [N, 3]. Defaults to None.
            albedo (Optional[Tensor]): albedo texture, float [H, W, 3], RGB format. Defaults to None.
            metallicRoughness (Optional[Tensor]): metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]. Defaults to None.
            device (Optional[torch.device]): torch device. Defaults to None.
        """
        self.device = device
        self.v = v
        self.vn = vn
        self.vt = vt
        self.f = f
        self.fn = fn
        self.ft = ft
        # will first see if there is vertex color to use
        self.vc = vc
        # only support a single albedo image
        self.albedo = albedo
        # pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]
        # ref: https://registry.khronos.org/glTF/specs/2.0/glTF-2.0.html
        self.metallicRoughness = metallicRoughness

        # center / scale applied by auto_size(), kept so the normalization can be undone
        self.ori_center = 0
        self.ori_scale = 1

    @classmethod
    def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs):
        """load mesh from path.

        Args:
            path (str): path to mesh file, supports ply, obj, glb.
            clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False.
            resize (bool, optional): auto resize the mesh using ``bound`` into [-bound, bound]^3. Defaults to True.
            renormal (bool, optional): re-calc the vertex normals. Defaults to True.
            retex (bool, optional): re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False.
            bound (float, optional): bound to resize. Defaults to 0.9.
            front_dir (str, optional): front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to '+z'.
            **kwargs: forwarded to ``load_obj`` / ``load_trimesh`` (e.g. ``device``, and ``albedo_path`` for obj).

        Note:
            a ``device`` keyword argument can be provided to specify the torch device.
            If it's not provided, we will try to use ``'cuda'`` as the device if it's available.

        Returns:
            Mesh: the loaded Mesh object.
        """
        # obj supports face uv
        if path.endswith(".obj"):
            mesh = cls.load_obj(path, **kwargs)
        # trimesh only supports vertex uv, but can load more formats
        else:
            mesh = cls.load_trimesh(path, **kwargs)

        # clean
        if clean:
            from kiui.mesh_utils import clean_mesh
            vertices = mesh.v.detach().cpu().numpy()
            triangles = mesh.f.detach().cpu().numpy()
            vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
            # NOTE(review): only v/f are replaced here; any vn/vt loaded from the file
            # still index the pre-clean vertices -- pass renormal/retex to rebuild them.
            mesh.v = torch.from_numpy(vertices).contiguous().float().to(mesh.device)
            mesh.f = torch.from_numpy(triangles).contiguous().int().to(mesh.device)

        print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}")
        # auto-normalize
        if resize:
            mesh.auto_size(bound=bound)
        # auto-fix normal
        if renormal or mesh.vn is None:
            mesh.auto_normal()
            print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}")
        # auto-fix texcoords
        if retex or (mesh.albedo is not None and mesh.vt is None):
            mesh.auto_uv(cache_path=path)
            print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}")

        # rotate front dir to +z
        if front_dir != "+z":
            # axis switch
            if "-z" in front_dir:
                T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32)
            elif "+x" in front_dir:
                T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
            elif "-x" in front_dir:
                T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
            elif "+y" in front_dir:
                T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
            elif "-y" in front_dir:
                T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
            else:
                T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            # rotation (how many 90 degrees)
            if '1' in front_dir:
                T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            elif '2' in front_dir:
                T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            elif '3' in front_dir:
                T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            mesh.v @= T
            mesh.vn @= T

        return mesh

    # load from obj file
    @classmethod
    def load_obj(cls, path, albedo_path=None, device=None):
        """load an ``obj`` mesh.

        Args:
            path (str): path to mesh.
            albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None.
            device (torch.device, optional): torch device. Defaults to None.

        Note:
            We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension.
            The `usemtl` statement is ignored, and we only use the last material path in `mtl` file.
            Negative (relative) face indices are resolved against the elements declared so far.

        Returns:
            Mesh: the loaded Mesh object.
        """
        assert os.path.splitext(path)[-1] == ".obj"

        mesh = cls()

        # device
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        mesh.device = device

        # load obj
        with open(path, "r") as f:
            lines = f.readlines()

        vertices, texcoords, normals = [], [], []
        faces, tfaces, nfaces = [], [], []
        mtl_path = None

        def parse_index(token, count):
            # convert one obj index to 0-based; -1 means "not provided".
            # positive indices are 1-based, negative ones count back from the
            # last of the ``count`` elements declared so far.
            if token == "":
                return -1
            idx = int(token)
            return idx - 1 if idx >= 0 else count + idx

        def parse_f_v(fv):
            # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided)
            # supported forms:
            # f v1 v2 v3
            # f v1/vt1 v2/vt2 v3/vt3
            # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3
            # f v1//vn1 v2//vn2 v3//vn3
            xs = fv.split("/")
            xs.extend([""] * (3 - len(xs)))
            return (
                parse_index(xs[0], len(vertices)),
                parse_index(xs[1], len(texcoords)),
                parse_index(xs[2], len(normals)),
            )

        for line in lines:
            split_line = line.split()
            # empty line
            if len(split_line) == 0:
                continue
            prefix = split_line[0].lower()
            # mtllib
            if prefix == "mtllib":
                mtl_path = split_line[1]
            # usemtl
            elif prefix == "usemtl":
                pass  # ignored
            # v/vn/vt
            elif prefix == "v":
                vertices.append([float(v) for v in split_line[1:]])
            elif prefix == "vn":
                normals.append([float(v) for v in split_line[1:]])
            elif prefix == "vt":
                val = [float(v) for v in split_line[1:]]
                # flip v so that (0, 0) is the top-left corner of the image
                texcoords.append([val[0], 1.0 - val[1]])
            elif prefix == "f":
                vs = split_line[1:]
                nv = len(vs)
                v0, t0, n0 = parse_f_v(vs[0])
                for i in range(nv - 2):  # triangulate (assume vertices are ordered)
                    v1, t1, n1 = parse_f_v(vs[i + 1])
                    v2, t2, n2 = parse_f_v(vs[i + 2])
                    faces.append([v0, v1, v2])
                    tfaces.append([t0, t1, t2])
                    nfaces.append([n0, n1, n2])

        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
        mesh.vt = (
            torch.tensor(texcoords, dtype=torch.float32, device=device)
            if len(texcoords) > 0
            else None
        )
        mesh.vn = (
            torch.tensor(normals, dtype=torch.float32, device=device)
            if len(normals) > 0
            else None
        )

        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
        mesh.ft = (
            torch.tensor(tfaces, dtype=torch.int32, device=device)
            if len(texcoords) > 0
            else None
        )
        mesh.fn = (
            torch.tensor(nfaces, dtype=torch.int32, device=device)
            if len(normals) > 0
            else None
        )

        # see if there is vertex color ("v x y z r g b"); the ndim check keeps
        # a file without any vertex from raising on shape[1]
        use_vertex_color = False
        if mesh.v.ndim == 2 and mesh.v.shape[1] == 6:
            use_vertex_color = True
            mesh.vc = mesh.v[:, 3:]
            mesh.v = mesh.v[:, :3]
            print(f"[load_obj] use vertex color: {mesh.vc.shape}")

        # try to load texture image
        if not use_vertex_color:
            # try to retrieve mtl file
            mtl_path_candidates = []
            if mtl_path is not None:
                mtl_path_candidates.append(mtl_path)
                mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path))
            # swap only the extension: str.replace would also rewrite ".obj" inside directory names
            mtl_path_candidates.append(os.path.splitext(path)[0] + ".mtl")

            mtl_path = None
            for candidate in mtl_path_candidates:
                if os.path.exists(candidate):
                    mtl_path = candidate
                    break

            # if albedo_path is not provided, try retrieve it from mtl
            metallic_path = None
            roughness_path = None
            if mtl_path is not None and albedo_path is None:
                with open(mtl_path, "r") as f:
                    lines = f.readlines()

                for line in lines:
                    split_line = line.split()
                    # empty line
                    if len(split_line) == 0:
                        continue
                    prefix = split_line[0]

                    if "map_Kd" in prefix:
                        # assume relative path!
                        albedo_path = os.path.join(os.path.dirname(path), split_line[1])
                        print(f"[load_obj] use texture from: {albedo_path}")
                    elif "map_Pm" in prefix:
                        metallic_path = os.path.join(os.path.dirname(path), split_line[1])
                    elif "map_Pr" in prefix:
                        roughness_path = os.path.join(os.path.dirname(path), split_line[1])

            albedo = None
            if albedo_path is not None and os.path.exists(albedo_path):
                # cv2.imread signals failure by returning None instead of raising
                albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)

            # still not found albedo_path, or the path doesn't exist / isn't a readable image
            if albedo is None:
                # init an empty texture
                print(f"[load_obj] init empty albedo!")
                # albedo = np.random.rand(1024, 1024, 3).astype(np.float32)
                albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])  # default color
            else:
                albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)
                albedo = albedo.astype(np.float32) / 255
                print(f"[load_obj] load texture: {albedo.shape}")

            mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)

            # try to load metallic and roughness
            if metallic_path is not None and roughness_path is not None:
                print(f"[load_obj] load metallicRoughness from: {metallic_path}, {roughness_path}")
                metallic = cv2.imread(metallic_path, cv2.IMREAD_UNCHANGED)
                roughness = cv2.imread(roughness_path, cv2.IMREAD_UNCHANGED)
                if metallic is None or roughness is None:
                    print(f"[load_obj] failed to read metallic/roughness images, will ignore!")
                else:
                    metallic = metallic.astype(np.float32) / 255
                    roughness = roughness.astype(np.float32) / 255
                    # glTF channel layout: G = roughness, B = metallic (R unused)
                    metallicRoughness = np.stack([np.zeros_like(metallic), roughness, metallic], axis=-1)

                    mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()

        return mesh

    @classmethod
    def load_trimesh(cls, path, device=None):
        """load a mesh using ``trimesh.load()``.

        Can load various formats like ``glb`` and serves as a fallback.

        Note:
            We will try to merge all meshes if the glb contains more than one,
            but **this may cause the texture to lose**, since we only support one texture image!

        Args:
            path (str): path to the mesh file.
            device (torch.device, optional): torch device. Defaults to None.

        Returns:
            Mesh: the loaded Mesh object.
        """
        mesh = cls()

        # device
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        mesh.device = device

        # use trimesh to load ply/glb
        _data = trimesh.load(path)
        if isinstance(_data, trimesh.Scene):
            if len(_data.geometry) == 1:
                _mesh = list(_data.geometry.values())[0]
            else:
                print(f"[load_trimesh] concatenating {len(_data.geometry)} meshes.")
                _concat = []
                # loop the scene graph and apply transform to each mesh
                scene_graph = _data.graph.to_flattened()  # dict {name: {transform: 4x4 mat, geometry: str}}
                for k, v in scene_graph.items():
                    name = v['geometry']
                    if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh):
                        transform = v['transform']
                        # apply_transform works in place: transform a copy so a geometry
                        # shared by several nodes does not accumulate their transforms
                        _concat.append(_data.geometry[name].copy().apply_transform(transform))
                _mesh = trimesh.util.concatenate(_concat)
        else:
            _mesh = _data

        if _mesh.visual.kind == 'vertex':
            vertex_colors = _mesh.visual.vertex_colors
            vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255
            mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device)
            print(f"[load_trimesh] use vertex color: {mesh.vc.shape}")
        elif _mesh.visual.kind == 'texture':
            _material = _mesh.visual.material
            if isinstance(_material, trimesh.visual.material.PBRMaterial):
                texture = _material.baseColorTexture
                # load metallicRoughness if present
                if _material.metallicRoughnessTexture is not None:
                    metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255
                    mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
            elif isinstance(_material, trimesh.visual.material.SimpleMaterial):
                texture = _material.to_pbr().baseColorTexture
            else:
                raise NotImplementedError(f"material type {type(_material)} not supported!")
            if texture is None:
                # the material only carries color factors, no image: fall back to the default gray
                texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])
                print(f"[load_trimesh] material has no base color texture, init empty albedo!")
            else:
                texture = np.array(texture).astype(np.float32) / 255
                if texture.ndim == 2:
                    # single-channel image: replicate to RGB so the channel slice below is valid
                    texture = np.stack([texture] * 3, axis=-1)
            mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous()
            print(f"[load_trimesh] load texture: {texture.shape}")
        else:
            texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])
            mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)
            print(f"[load_trimesh] failed to load texture.")

        vertices = _mesh.vertices

        try:
            # copy first, so flipping v does not modify the array owned by trimesh
            texcoords = np.array(_mesh.visual.uv)
            texcoords[:, 1] = 1 - texcoords[:, 1]
        except Exception:
            # best effort: visuals without usable uv simply yield no texcoords
            texcoords = None

        try:
            normals = _mesh.vertex_normals
        except Exception:
            normals = None

        # trimesh only support vertex uv...
        faces = tfaces = nfaces = _mesh.faces

        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
        mesh.vt = (
            torch.tensor(texcoords, dtype=torch.float32, device=device)
            if texcoords is not None
            else None
        )
        mesh.vn = (
            torch.tensor(normals, dtype=torch.float32, device=device)
            if normals is not None
            else None
        )

        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
        mesh.ft = (
            torch.tensor(tfaces, dtype=torch.int32, device=device)
            if texcoords is not None
            else None
        )
        mesh.fn = (
            torch.tensor(nfaces, dtype=torch.int32, device=device)
            if normals is not None
            else None
        )

        return mesh

    # sample surface (using trimesh)
    def sample_surface(self, count: int):
        """sample points on the surface of the mesh.

        Args:
            count (int): number of points to sample.

        Returns:
            torch.Tensor: the sampled points, float [count, 3].
        """
        _mesh = trimesh.Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.f.detach().cpu().numpy())
        points, face_idx = trimesh.sample.sample_surface(_mesh, count)
        points = torch.from_numpy(points).float().to(self.device)
        return points

    # aabb
    def aabb(self):
        """get the axis-aligned bounding box of the mesh.

        Returns:
            Tuple[torch.Tensor]: the min xyz and max xyz of the mesh.
        """
        return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values

    # unit size
    @torch.no_grad()
    def auto_size(self, bound=0.9):
        """auto resize the mesh.

        The applied center and scale are stored in ``ori_center`` / ``ori_scale``.

        Args:
            bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9.
        """
        vmin, vmax = self.aabb()
        self.ori_center = (vmax + vmin) / 2
        # uniform scale: the longest bbox edge is mapped to 2 * bound
        self.ori_scale = 2 * bound / torch.max(vmax - vmin).item()
        self.v = (self.v - self.ori_center) * self.ori_scale

    def auto_normal(self):
        """auto calculate the vertex normals.

        Face normals are accumulated on their three vertices and normalized;
        ``fn`` is set to ``f``.
        """
        i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()
        v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]

        # dim must be explicit: without it torch.cross uses the first dimension
        # of size 3, which is the face axis when the mesh has exactly 3 faces
        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        # Splat face normals to vertices
        vn = torch.zeros_like(self.v)
        vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
        vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
        vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

        # Normalize, replace zero (degenerated) normals with some default value
        vn = torch.where(
            dot(vn, vn) > 1e-20,
            vn,
            torch.tensor([0.0, 0.0, 1.0], dtype=vn.dtype, device=vn.device),
        )
        vn = safe_normalize(vn)

        self.vn = vn
        self.fn = self.f

    def auto_uv(self, cache_path=None, vmap=True):
        """auto calculate the uv coordinates.

        Args:
            cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None.
            vmap (bool, optional): remap vertices based on uv coordinates, so each v correspond to a unique vt (necessary for formats like gltf).
                Usually this will duplicate the vertices on the edge of uv atlas. Defaults to True.
        """
        # try to load cache
        if cache_path is not None:
            cache_path = os.path.splitext(cache_path)[0] + "_uv.npz"
        if cache_path is not None and os.path.exists(cache_path):
            data = np.load(cache_path)
            vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"]
        else:
            import xatlas

            v_np = self.v.detach().cpu().numpy()
            f_np = self.f.detach().int().cpu().numpy()
            atlas = xatlas.Atlas()
            atlas.add_mesh(v_np, f_np)
            chart_options = xatlas.ChartOptions()
            # chart_options.max_iterations = 4
            atlas.generate(chart_options=chart_options)
            vmapping, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]

            # save to cache
            if cache_path is not None:
                np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping)

        vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)
        ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)
        self.vt = vt
        self.ft = ft

        if vmap:
            vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device)
            self.align_v_to_vt(vmapping)

    def align_v_to_vt(self, vmapping=None):
        """ remap v/f and vn/fn to vt/ft.

        Args:
            vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None.
        """
        if vmapping is None:
            ft = self.ft.view(-1).long()
            f = self.f.view(-1).long()
            vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device)
            vmapping[ft] = f  # scatter, randomly choose one if index is not unique

        self.v = self.v[vmapping]
        self.f = self.ft

        if self.vn is not None:
            self.vn = self.vn[vmapping]
            self.fn = self.ft

    def to(self, device):
        """move all tensor attributes to device.

        Args:
            device (torch.device): target device.

        Returns:
            Mesh: self.
        """
        self.device = device
        for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo", "vc", "metallicRoughness"]:
            tensor = getattr(self, name)
            if tensor is not None:
                setattr(self, name, tensor.to(device))
        # ori_center becomes a tensor after auto_size(); keep it on the same device as v
        if torch.is_tensor(self.ori_center):
            self.ori_center = self.ori_center.to(device)
        return self

    def write(self, path):
        """write the mesh to a path.

        Args:
            path (str): path to write, supports ply, obj and glb.

        Raises:
            NotImplementedError: if the extension is none of ply/obj/glb/gltf.
        """
        if path.endswith(".ply"):
            self.write_ply(path)
        elif path.endswith(".obj"):
            self.write_obj(path)
        elif path.endswith(".glb") or path.endswith(".gltf"):
            self.write_glb(path)
        else:
            raise NotImplementedError(f"format {path} not supported!")

    def write_ply(self, path):
        """write the mesh in ply format. Only for geometry!

        Args:
            path (str): path to write.
        """

        if self.albedo is not None:
            print(f'[WARN] ply format does not support exporting texture, will ignore!')

        v_np = self.v.detach().cpu().numpy()
        f_np = self.f.detach().cpu().numpy()

        _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)
        _mesh.export(path)

    def write_glb(self, path):
        """write the mesh in glb/gltf format.
          This will create a scene with a single mesh.

        Texture data is only written when both ``vt`` and ``albedo`` are set.

        Args:
            path (str): path to write.
        """

        # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]
        # gltf needs one uv per vertex, so remap v/f onto vt/ft when they differ
        if self.vt is not None and self.v.shape[0] != self.vt.shape[0]:
            self.align_v_to_vt()

        import pygltflib

        f_np = self.f.detach().cpu().numpy().astype(np.uint32)
        f_np_blob = f_np.flatten().tobytes()

        v_np = self.v.detach().cpu().numpy().astype(np.float32)
        v_np_blob = v_np.tobytes()

        # single binary buffer: [faces | positions | (texcoords | albedo png | mr png)]
        blob = f_np_blob + v_np_blob
        byteOffset = len(blob)

        # base mesh
        gltf = pygltflib.GLTF2(
            scene=0,
            scenes=[pygltflib.Scene(nodes=[0])],
            nodes=[pygltflib.Node(mesh=0)],
            meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive(
                # indices to accessors (0 is triangles)
                attributes=pygltflib.Attributes(
                    POSITION=1,
                ),
                indices=0,
            )])],
            buffers=[
                pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob))
            ],
            # buffer view (based on dtype)
            bufferViews=[
                # triangles; as flatten (element) array
                pygltflib.BufferView(
                    buffer=0,
                    byteLength=len(f_np_blob),
                    target=pygltflib.ELEMENT_ARRAY_BUFFER,  # GL_ELEMENT_ARRAY_BUFFER (34963)
                ),
                # positions; as vec3 array
                pygltflib.BufferView(
                    buffer=0,
                    byteOffset=len(f_np_blob),
                    byteLength=len(v_np_blob),
                    byteStride=12,  # vec3
                    target=pygltflib.ARRAY_BUFFER,  # GL_ARRAY_BUFFER (34962)
                ),
            ],
            accessors=[
                # 0 = triangles
                pygltflib.Accessor(
                    bufferView=0,
                    componentType=pygltflib.UNSIGNED_INT,  # GL_UNSIGNED_INT (5125)
                    count=f_np.size,
                    type=pygltflib.SCALAR,
                    max=[int(f_np.max())],
                    min=[int(f_np.min())],
                ),
                # 1 = positions
                pygltflib.Accessor(
                    bufferView=1,
                    componentType=pygltflib.FLOAT,  # GL_FLOAT (5126)
                    count=len(v_np),
                    type=pygltflib.VEC3,
                    max=v_np.max(axis=0).tolist(),
                    min=v_np.min(axis=0).tolist(),
                ),
            ],
        )

        if self.vt is not None and self.albedo is None:
            # e.g. an obj with vertex colors and uvs: there is no image to reference
            print(f'[WARN] mesh has uv coordinates but no albedo texture, will export geometry only!')

        # append texture info
        if self.vt is not None and self.albedo is not None:

            vt_np = self.vt.detach().cpu().numpy().astype(np.float32)
            vt_np_blob = vt_np.tobytes()

            albedo = self.albedo.detach().cpu().numpy()
            albedo = (albedo * 255).astype(np.uint8)
            # cv2 encodes from BGR, so swap to keep the stored png in RGB order
            albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)
            albedo_blob = cv2.imencode('.png', albedo)[1].tobytes()

            # update primitive
            gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2
            gltf.meshes[0].primitives[0].material = 0

            # update materials
            gltf.materials.append(pygltflib.Material(
                pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(
                    baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),
                    metallicFactor=0.0,
                    roughnessFactor=1.0,
                ),
                alphaMode=pygltflib.OPAQUE,
                alphaCutoff=None,
                doubleSided=True,
            ))

            gltf.textures.append(pygltflib.Texture(sampler=0, source=0))
            gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
            gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png"))

            # update buffers
            gltf.bufferViews.append(
                # index = 2, texcoords; as vec2 array
                pygltflib.BufferView(
                    buffer=0,
                    byteOffset=byteOffset,
                    byteLength=len(vt_np_blob),
                    byteStride=8,  # vec2
                    target=pygltflib.ARRAY_BUFFER,
                )
            )

            gltf.accessors.append(
                # 2 = texcoords
                pygltflib.Accessor(
                    bufferView=2,
                    componentType=pygltflib.FLOAT,
                    count=len(vt_np),
                    type=pygltflib.VEC2,
                    max=vt_np.max(axis=0).tolist(),
                    min=vt_np.min(axis=0).tolist(),
                )
            )

            blob += vt_np_blob
            byteOffset += len(vt_np_blob)

            gltf.bufferViews.append(
                # index = 3, albedo texture; as none target
                pygltflib.BufferView(
                    buffer=0,
                    byteOffset=byteOffset,
                    byteLength=len(albedo_blob),
                )
            )

            blob += albedo_blob
            byteOffset += len(albedo_blob)

            gltf.buffers[0].byteLength = byteOffset

            # append metallic roughness
            if self.metallicRoughness is not None:
                metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
                metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
                metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR)
                metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes()

                # update texture definition
                gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0
                gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0
                gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0)

                gltf.textures.append(pygltflib.Texture(sampler=1, source=1))
                gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
                gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png"))

                # update buffers
                gltf.bufferViews.append(
                    # index = 4, metallicRoughness texture; as none target
                    pygltflib.BufferView(
                        buffer=0,
                        byteOffset=byteOffset,
                        byteLength=len(metallicRoughness_blob),
                    )
                )

                blob += metallicRoughness_blob
                byteOffset += len(metallicRoughness_blob)

                gltf.buffers[0].byteLength = byteOffset

        # set actual data
        gltf.set_binary_blob(blob)

        # glb = b"".join(gltf.save_to_bytes())
        gltf.save(path)

    def write_obj(self, path):
        """write the mesh in obj format. Will also write the texture and mtl files.

        The mtl and texture images are written next to ``path``, sharing its
        base name (``<name>.mtl``, ``<name>_albedo.png``, ...).

        Args:
            path (str): path to write.
        """

        # derive sibling files from the base name only: str.replace would also
        # rewrite ".obj" inside directory names, and would leave a path with a
        # different extension unchanged (so the mtl would overwrite the obj)
        stem = os.path.splitext(path)[0]
        mtl_path = stem + ".mtl"
        albedo_path = stem + "_albedo.png"
        metallic_path = stem + "_metallic.png"
        roughness_path = stem + "_roughness.png"

        v_np = self.v.detach().cpu().numpy()
        vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None
        vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None
        f_np = self.f.detach().cpu().numpy()
        ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None
        fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None

        # only reference uv / normal indices whose elements are actually written
        has_uv = vt_np is not None and ft_np is not None
        has_normal = vn_np is not None and fn_np is not None

        def face_corner(i, j):
            # one face corner in the shortest valid obj form: v, v/vt, v//vn or v/vt/vn
            token = f"{f_np[i, j] + 1}"
            if has_uv:
                token += f"/{ft_np[i, j] + 1}"
            if has_normal:
                token += f"/{fn_np[i, j] + 1}" if has_uv else f"//{fn_np[i, j] + 1}"
            return token

        with open(path, "w") as fp:
            fp.write(f"mtllib {os.path.basename(mtl_path)} \n")

            fp.writelines(f"v {v[0]} {v[1]} {v[2]} \n" for v in v_np)

            if vt_np is not None:
                # flip v back to the obj convention (origin at the bottom-left)
                fp.writelines(f"vt {v[0]} {1 - v[1]} \n" for v in vt_np)

            if vn_np is not None:
                fp.writelines(f"vn {v[0]} {v[1]} {v[2]} \n" for v in vn_np)

            fp.write(f"usemtl defaultMat \n")
            fp.writelines(
                f"f {face_corner(i, 0)} {face_corner(i, 1)} {face_corner(i, 2)} \n"
                for i in range(len(f_np))
            )

        with open(mtl_path, "w") as fp:
            fp.write(f"newmtl defaultMat \n")
            fp.write(f"Ka 1 1 1 \n")
            fp.write(f"Kd 1 1 1 \n")
            fp.write(f"Ks 0 0 0 \n")
            fp.write(f"Tr 1 \n")
            fp.write(f"illum 1 \n")
            fp.write(f"Ns 0 \n")
            if self.albedo is not None:
                fp.write(f"map_Kd {os.path.basename(albedo_path)} \n")
            if self.metallicRoughness is not None:
                # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
                fp.write(f"map_Pm {os.path.basename(metallic_path)} \n")
                fp.write(f"map_Pr {os.path.basename(roughness_path)} \n")

        if self.albedo is not None:
            albedo = self.albedo.detach().cpu().numpy()
            albedo = (albedo * 255).astype(np.uint8)
            cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))

        if self.metallicRoughness is not None:
            metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
            metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
            # split the packed texture into single-channel images: B = metallic, G = roughness
            cv2.imwrite(metallic_path, metallicRoughness[..., 2])
            cv2.imwrite(roughness_path, metallicRoughness[..., 1])
845
+
mesh_recon/models/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Registry of model classes, keyed by the name passed to ``register``.
models = {}


def register(name):
    """Build a class decorator that stores the decorated class in ``models``.

    The decorated class is returned unchanged, so the decorator can be
    stacked on a normal class definition.
    """
    def _record(model_cls):
        models[name] = model_cls
        return model_cls
    return _record
9
+
10
+
11
def make(name, config):
    """Instantiate the class registered under ``name``, passing it ``config``.

    Raises ``KeyError`` when no class has been registered under ``name``.
    """
    model_cls = models[name]
    return model_cls(config)
14
+
15
+
16
+ from . import nerf, neus, geometry, texture
mesh_recon/models/base.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from utils.misc import get_rank
5
+
6
class BaseModel(nn.Module):
    """Common base class for the models defined in this package.

    The constructor stores the config and the process rank, calls
    :meth:`setup` (which subclasses must implement to build their
    sub-modules) and finally, if the config has a truthy ``weights``
    entry, loads that state dict into the freshly built module.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.rank = get_rank()
        self.setup()
        weights_path = self.config.get('weights', None)
        if weights_path:
            # Deserialize onto the CPU: load_state_dict() copies the values into
            # the parameters that setup() already created, so the checkpoint does
            # not have to be restored onto the device it was saved from (which may
            # not exist in this process).
            state_dict = torch.load(weights_path, map_location='cpu')
            self.load_state_dict(state_dict)

    def setup(self):
        """Create the sub-modules of the model; must be overridden."""
        raise NotImplementedError

    def update_step(self, epoch, global_step):
        """Hook for per-step updates; a no-op in the base class."""
        pass

    def train(self, mode=True):
        return super().train(mode=mode)

    def eval(self):
        return super().eval()

    def regularizations(self, out):
        """Return a dict of extra loss terms; empty in the base class."""
        return {}

    @torch.no_grad()
    def export(self, export_config):
        """Return the exported artefacts as a dict; empty in the base class."""
        return {}
mesh_recon/models/geometry.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from pytorch_lightning.utilities.rank_zero import rank_zero_info
7
+
8
+ import models
9
+ from models.base import BaseModel
10
+ from models.utils import scale_anything, get_activation, cleanup, chunk_batch
11
+ from models.network_utils import get_encoding, get_mlp, get_encoding_with_network
12
+ from utils.misc import get_rank
13
+ from systems.utils import update_module_step
14
+ from nerfacc import ContractionType
15
+
16
+
17
def contract_to_unisphere(x, radius, contraction_type):
    """Normalize points for the encoders according to ``contraction_type``.

    * ``ContractionType.AABB``: rescale ``[-radius, radius]`` to ``[0, 1]``.
    * ``ContractionType.UN_BOUNDED_SPHERE``: rescale to ``[-1, 1]``, pull every
      point whose norm exceeds 1 inside the radius-2 ball, then map the
      result to ``[0, 1]``.

    Any other contraction type raises ``NotImplementedError``.
    """
    if contraction_type == ContractionType.AABB:
        return scale_anything(x, (-radius, radius), (0, 1))
    if contraction_type != ContractionType.UN_BOUNDED_SPHERE:
        raise NotImplementedError

    pts = scale_anything(x, (-radius, radius), (0, 1))
    pts = pts * 2 - 1  # aabb is at [-1, 1]
    norm = pts.norm(dim=-1, keepdim=True)
    outside = norm.squeeze(-1) > 1
    # Only the points outside the unit ball are contracted.
    pts[outside] = (2 - 1 / norm[outside]) * (pts[outside] / norm[outside])
    return pts / 4 + 0.5  # [-inf, inf] is at [0, 1]
30
+
31
+
32
class MarchingCubeHelper(nn.Module):
    """Run marching cubes on a dense ``resolution``^3 grid of level values.

    ``use_torch`` selects the backend: ``torchmcubes`` when true, ``mcubes``
    otherwise. The backend is imported in the constructor, so only the
    selected one needs to be installed.
    """

    def __init__(self, resolution, use_torch=True):
        super().__init__()
        self.resolution = resolution
        self.use_torch = use_torch
        self.points_range = (0, 1)
        if use_torch:
            import torchmcubes
            self.mc_func = torchmcubes.marching_cubes
        else:
            import mcubes
            self.mc_func = mcubes.marching_cubes
        self.verts = None  # lazily built cache, see grid_vertices()

    def grid_vertices(self):
        """Return the (resolution^3, 3) grid sample positions, cached after the first call."""
        if self.verts is not None:
            return self.verts
        axes = [torch.linspace(*self.points_range, self.resolution) for _ in range(3)]
        gx, gy, gz = torch.meshgrid(*axes, indexing='ij')
        self.verts = torch.stack([gx.reshape(-1), gy.reshape(-1), gz.reshape(-1)], dim=-1)
        return self.verts

    def forward(self, level, threshold=0.):
        """Extract the iso-surface of ``level`` at ``threshold``.

        Returns a dict with ``'v_pos'`` (vertex positions divided by
        ``resolution - 1``) and ``'t_pos_idx'`` (int64 triangle indices),
        both as CPU tensors.
        """
        res = self.resolution
        volume = level.float().view(res, res, res)
        if self.use_torch:
            verts, faces = self.mc_func(volume.to(get_rank()), threshold)
            verts = verts.cpu()
            faces = faces.cpu().long()
        else:
            # The mcubes backend receives the negated field as a numpy array.
            verts, faces = self.mc_func(-volume.numpy(), threshold)
            verts = torch.from_numpy(verts.astype(np.float32))
            faces = torch.from_numpy(faces.astype(np.int64))
        return {
            'v_pos': verts / (res - 1.),
            't_pos_idx': faces
        }
67
+
68
+
69
class BaseImplicitGeometry(BaseModel):
    """Base class for implicit geometries that can be meshed with marching cubes.

    Subclasses implement :meth:`forward_level`, the scalar field whose
    iso-surface is extracted.
    """

    def __init__(self, config):
        super().__init__(config)
        iso_cfg = self.config.isosurface
        if iso_cfg is not None:
            assert iso_cfg.method in ['mc', 'mc-torch']
            if iso_cfg.method == 'mc-torch':
                raise NotImplementedError("Please do not use mc-torch. It currently has some scaling issues I haven't fixed yet.")
            self.helper = MarchingCubeHelper(iso_cfg.resolution, use_torch=iso_cfg.method == 'mc-torch')
        self.radius = self.config.radius
        self.contraction_type = None  # assigned in system

    def forward_level(self, points):
        """Evaluate the level field at ``points``; must be overridden."""
        raise NotImplementedError

    def isosurface_(self, vmin, vmax):
        """Extract a mesh inside the axis-aligned box ``[vmin, vmax]``."""
        def to_box(unit_points):
            # Rescale every axis from (0, 1) to its own (vmin, vmax) interval.
            return torch.stack([
                scale_anything(unit_points[..., axis], (0, 1), (vmin[axis], vmax[axis]))
                for axis in range(3)
            ], dim=-1)

        def batch_func(x):
            values = self.forward_level(to_box(x).to(self.rank)).cpu()
            cleanup()
            return values

        level = chunk_batch(batch_func, self.config.isosurface.chunk, True, self.helper.grid_vertices())
        mesh = self.helper(level, threshold=self.config.isosurface.threshold)
        mesh['v_pos'] = to_box(mesh['v_pos'])
        return mesh

    @torch.no_grad()
    def isosurface(self):
        """Two-pass extraction: a coarse pass over the full cube, then a pass
        over the coarse mesh's bounding box enlarged by 10% per side and
        clamped to ``[-radius, radius]``.
        """
        if self.config.isosurface is None:
            raise NotImplementedError
        r = self.radius
        coarse = self.isosurface_((-r, -r, -r), (r, r, r))
        lo = coarse['v_pos'].amin(dim=0)
        hi = coarse['v_pos'].amax(dim=0)
        margin = (hi - lo) * 0.1
        lo_padded = (lo - margin).clamp(-r, r)
        hi_padded = (hi + margin).clamp(-r, r)
        return self.isosurface_(lo_padded, hi_padded)
113
+
114
+
115
@models.register('volume-density')
class VolumeDensity(BaseImplicitGeometry):
    """Density field: channel 0 of the network output is the density."""

    def setup(self):
        self.n_input_dims = self.config.get('n_input_dims', 3)
        self.n_output_dims = self.config.feature_dim
        self.encoding_with_network = get_encoding_with_network(
            self.n_input_dims,
            self.n_output_dims,
            self.config.xyz_encoding_config,
            self.config.mlp_network_config,
        )

    def forward(self, points):
        """Return ``(density, feature)`` for ``points``.

        ``feature`` is the full network output (including channel 0); the
        optional ``density_activation`` / ``feature_activation`` from the
        config are applied when present.
        """
        points = contract_to_unisphere(points, self.radius, self.contraction_type)
        flat = points.view(-1, self.n_input_dims)
        out = self.encoding_with_network(flat).view(*points.shape[:-1], self.n_output_dims).float()
        density = out[..., 0]
        feature = out
        if 'density_activation' in self.config:
            activate = get_activation(self.config.density_activation)
            density = activate(density + float(self.config.density_bias))
        if 'feature_activation' in self.config:
            feature = get_activation(self.config.feature_activation)(feature)
        return density, feature

    def forward_level(self, points):
        """Return the negated (activated) density, used as the level field."""
        points = contract_to_unisphere(points, self.radius, self.contraction_type)
        flat = points.reshape(-1, self.n_input_dims)
        out = self.encoding_with_network(flat).reshape(*points.shape[:-1], self.n_output_dims)
        density = out[..., 0]
        if 'density_activation' in self.config:
            activate = get_activation(self.config.density_activation)
            density = activate(density + float(self.config.density_bias))
        return -density

    def update_step(self, epoch, global_step):
        update_module_step(self.encoding_with_network, epoch, global_step)
141
+
142
+
143
@models.register('volume-sdf')
class VolumeSDF(BaseImplicitGeometry):
    """Signed-distance field: channel 0 of the network output is the SDF.

    Gradients are computed either analytically (autograd) or with central
    finite differences, selected by ``config.grad_type``.
    """

    def setup(self):
        self.n_output_dims = self.config.feature_dim
        encoding = get_encoding(3, self.config.xyz_encoding_config)
        network = get_mlp(encoding.n_output_dims, self.n_output_dims, self.config.mlp_network_config)
        self.encoding, self.network = encoding, network
        self.grad_type = self.config.grad_type
        self.finite_difference_eps = self.config.get('finite_difference_eps', 1e-3)
        # the actual value used in training
        # will update at certain steps if finite_difference_eps="progressive"
        self._finite_difference_eps = None
        if self.grad_type == 'finite_difference':
            rank_zero_info(f"Using finite difference to compute gradients with eps={self.finite_difference_eps}")

    def forward(self, points, with_grad=True, with_feature=True, with_laplace=False):
        """Evaluate the field at ``points``.

        Returns ``sdf`` alone when it is the only requested output, otherwise
        a list ``[sdf, grad?, feature?, laplace?]`` in that order. Outputs are
        detached when the module is not in training mode.

        Raises:
            AssertionError: ``with_laplace`` is set but ``grad_type`` is not
                ``'finite_difference'``.
            ValueError: ``with_laplace`` is set without ``with_grad`` (the
                laplacian reuses the finite-difference samples of the gradient).
            RuntimeError: finite differences are requested with a progressive
                eps before ``update_step`` has initialized it.
        """
        if with_laplace:
            # Validate up front instead of after all the work has been done.
            assert self.config.grad_type == 'finite_difference', "Laplace computation is only supported with grad_type='finite_difference'"
            if not with_grad:
                raise ValueError("with_laplace=True requires with_grad=True")

        with torch.inference_mode(torch.is_inference_mode_enabled() and not (with_grad and self.grad_type == 'analytic')):
            with torch.set_grad_enabled(self.training or (with_grad and self.grad_type == 'analytic')):
                if with_grad and self.grad_type == 'analytic':
                    if not self.training:
                        points = points.clone() # points may be in inference mode, get a copy to enable grad
                    points.requires_grad_(True)

                points_ = points # points in the original scale
                points = contract_to_unisphere(points, self.radius, self.contraction_type) # points normalized to (0, 1)

                out = self.network(self.encoding(points.view(-1, 3))).view(*points.shape[:-1], self.n_output_dims).float()
                sdf, feature = out[...,0], out
                if 'sdf_activation' in self.config:
                    sdf = get_activation(self.config.sdf_activation)(sdf + float(self.config.sdf_bias))
                if 'feature_activation' in self.config:
                    feature = get_activation(self.config.feature_activation)(feature)
                if with_grad:
                    if self.grad_type == 'analytic':
                        grad = torch.autograd.grad(
                            sdf, points_, grad_outputs=torch.ones_like(sdf),
                            create_graph=True, retain_graph=True, only_inputs=True
                        )[0]
                    elif self.grad_type == 'finite_difference':
                        eps = self._finite_difference_eps
                        if eps is None:
                            # update_step() has not run yet. For a fixed float eps it
                            # would simply copy the configured value, so use that
                            # directly; a progressive eps depends on the step.
                            if isinstance(self.finite_difference_eps, float):
                                eps = self.finite_difference_eps
                            else:
                                raise RuntimeError(
                                    "finite_difference_eps has not been initialized; "
                                    "call update_step() before forward()"
                                )
                        # One +eps / -eps sample per axis, in x, y, z order.
                        offsets = torch.as_tensor(
                            [
                                [eps, 0.0, 0.0],
                                [-eps, 0.0, 0.0],
                                [0.0, eps, 0.0],
                                [0.0, -eps, 0.0],
                                [0.0, 0.0, eps],
                                [0.0, 0.0, -eps],
                            ]
                        ).to(points_)
                        points_d_ = (points_[...,None,:] + offsets).clamp(-self.radius, self.radius)
                        points_d = scale_anything(points_d_, (-self.radius, self.radius), (0, 1))
                        points_d_sdf = self.network(self.encoding(points_d.view(-1, 3)))[...,0].view(*points.shape[:-1], 6).float()
                        # Central difference: (f(x + eps) - f(x - eps)) / (2 * eps).
                        grad = 0.5 * (points_d_sdf[..., 0::2] - points_d_sdf[..., 1::2]) / eps

                        if with_laplace:
                            laplace = (points_d_sdf[..., 0::2] + points_d_sdf[..., 1::2] - 2 * sdf[..., None]).sum(-1) / (eps ** 2)

        rv = [sdf]
        if with_grad:
            rv.append(grad)
        if with_feature:
            rv.append(feature)
        if with_laplace:
            rv.append(laplace)
        rv = [v if self.training else v.detach() for v in rv]
        return rv[0] if len(rv) == 1 else rv

    def forward_level(self, points):
        """Return the (activated) SDF, used as the level field for meshing."""
        points = contract_to_unisphere(points, self.radius, self.contraction_type) # points normalized to (0, 1)
        sdf = self.network(self.encoding(points.view(-1, 3))).view(*points.shape[:-1], self.n_output_dims)[...,0]
        if 'sdf_activation' in self.config:
            sdf = get_activation(self.config.sdf_activation)(sdf + float(self.config.sdf_bias))
        return sdf

    def update_step(self, epoch, global_step):
        """Propagate the step to the sub-modules and refresh the finite-difference eps."""
        update_module_step(self.encoding, epoch, global_step)
        update_module_step(self.network, epoch, global_step)
        if self.grad_type == 'finite_difference':
            if isinstance(self.finite_difference_eps, float):
                self._finite_difference_eps = self.finite_difference_eps
            elif self.finite_difference_eps == 'progressive':
                hg_conf = self.config.xyz_encoding_config
                assert hg_conf.otype == "ProgressiveBandHashGrid", "finite_difference_eps='progressive' only works with ProgressiveBandHashGrid"
                current_level = min(
                    hg_conf.start_level + max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps,
                    hg_conf.n_levels
                )
                # eps is the cell size of the grid at the current level.
                grid_res = hg_conf.base_resolution * hg_conf.per_level_scale**(current_level - 1)
                grid_size = 2 * self.config.radius / grid_res
                if grid_size != self._finite_difference_eps:
                    rank_zero_info(f"Update finite_difference_eps to {grid_size}")
                self._finite_difference_eps = grid_size
            else:
                raise ValueError(f"Unknown finite_difference_eps={self.finite_difference_eps}")
mesh_recon/models/nerf.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ import models
8
+ from models.base import BaseModel
9
+ from models.utils import chunk_batch
10
+ from systems.utils import update_module_step
11
+ from nerfacc import ContractionType, OccupancyGrid, ray_marching, render_weight_from_density, accumulate_along_rays
12
+
13
+
14
@models.register('nerf')
class NeRFModel(BaseModel):
    """Radiance field made of a geometry (density) model and a texture model.

    Rays are sampled with ``nerfacc.ray_marching`` and composited with
    ``render_weight_from_density`` / ``accumulate_along_rays``.
    ``background_color`` starts as ``None`` and must be assigned by the caller
    before :meth:`forward` is used.
    """

    def setup(self):
        self.geometry = models.make(self.config.geometry.name, self.config.geometry)
        self.texture = models.make(self.config.texture.name, self.config.texture)
        self.register_buffer('scene_aabb', torch.as_tensor([-self.config.radius, -self.config.radius, -self.config.radius, self.config.radius, self.config.radius, self.config.radius], dtype=torch.float32))

        if self.config.learned_background:
            self.occupancy_grid_res = 256
            self.near_plane, self.far_plane = 0.2, 1e4
            self.cone_angle = 10**(math.log10(self.far_plane) / self.config.num_samples_per_ray) - 1. # approximate
            self.render_step_size = 0.01 # render_step_size = max(distance_to_camera * self.cone_angle, self.render_step_size)
            self.contraction_type = ContractionType.UN_BOUNDED_SPHERE
        else:
            self.occupancy_grid_res = 128
            self.near_plane, self.far_plane = None, None
            self.cone_angle = 0.0
            self.render_step_size = 1.732 * 2 * self.config.radius / self.config.num_samples_per_ray
            self.contraction_type = ContractionType.AABB

        # The geometry contracts points itself and needs to know which scheme is used.
        self.geometry.contraction_type = self.contraction_type

        if self.config.grid_prune:
            self.occupancy_grid = OccupancyGrid(
                roi_aabb=self.scene_aabb,
                resolution=self.occupancy_grid_res,
                contraction_type=self.contraction_type
            )
        self.randomized = self.config.randomized
        self.background_color = None

    def update_step(self, epoch, global_step):
        """Forward the step to the sub-models and, when training with grid
        pruning, refresh the occupancy grid."""
        update_module_step(self.geometry, epoch, global_step)
        update_module_step(self.texture, epoch, global_step)

        def occ_eval_fn(x):
            density, _ = self.geometry(x)
            # approximate for 1 - torch.exp(-density[...,None] * self.render_step_size) based on taylor series
            return density[...,None] * self.render_step_size

        if self.training and self.config.grid_prune:
            self.occupancy_grid.every_n_step(step=global_step, occ_eval_fn=occ_eval_fn)

    def isosurface(self):
        mesh = self.geometry.isosurface()
        return mesh

    def forward_(self, rays):
        """Render one batch of rays.

        ``rays[:, 0:3]`` are the origins and ``rays[:, 3:6]`` the directions.
        Returns a dict with ``comp_rgb``, ``opacity``, ``depth``,
        ``rays_valid`` and ``num_samples``; in training mode the per-sample
        ``weights``, ``points``, ``intervals`` and ``ray_indices`` are added.
        """
        n_rays = rays.shape[0]
        rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3)

        def sigma_fn(t_starts, t_ends, ray_indices):
            # Density at the midpoint of every candidate interval.
            ray_indices = ray_indices.long()
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            positions = t_origins + t_dirs * (t_starts + t_ends) / 2.
            density, _ = self.geometry(positions)
            return density[...,None]

        # Sample placement itself is not differentiated.
        with torch.no_grad():
            ray_indices, t_starts, t_ends = ray_marching(
                rays_o, rays_d,
                scene_aabb=None if self.config.learned_background else self.scene_aabb,
                grid=self.occupancy_grid if self.config.grid_prune else None,
                sigma_fn=sigma_fn,
                near_plane=self.near_plane, far_plane=self.far_plane,
                render_step_size=self.render_step_size,
                stratified=self.randomized,
                cone_angle=self.cone_angle,
                alpha_thre=0.0
            )

        ray_indices = ray_indices.long()
        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        midpoints = (t_starts + t_ends) / 2.
        positions = t_origins + t_dirs * midpoints
        intervals = t_ends - t_starts

        density, feature = self.geometry(positions)
        rgb = self.texture(feature, t_dirs)

        weights = render_weight_from_density(t_starts, t_ends, density[...,None], ray_indices=ray_indices, n_rays=n_rays)
        opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays)
        depth = accumulate_along_rays(weights, ray_indices, values=midpoints, n_rays=n_rays)
        comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays)
        # Blend the background in wherever the ray is not fully opaque.
        comp_rgb = comp_rgb + self.background_color * (1.0 - opacity)

        out = {
            'comp_rgb': comp_rgb,
            'opacity': opacity,
            'depth': depth,
            'rays_valid': opacity > 0,
            'num_samples': torch.as_tensor([len(t_starts)], dtype=torch.int32, device=rays.device)
        }

        if self.training:
            out.update({
                'weights': weights.view(-1),
                'points': midpoints.view(-1),
                'intervals': intervals.view(-1),
                'ray_indices': ray_indices.view(-1)
            })

        return out

    def forward(self, rays):
        """Render ``rays``; outside training the batch is processed in chunks
        of ``config.ray_chunk``."""
        if self.training:
            out = self.forward_(rays)
        else:
            out = chunk_batch(self.forward_, self.config.ray_chunk, True, rays)
        return {
            **out,
        }

    def train(self, mode=True):
        # Stratified sampling is only used in training mode.
        self.randomized = mode and self.config.randomized
        return super().train(mode=mode)

    def eval(self):
        self.randomized = False
        return super().eval()

    def regularizations(self, out):
        losses = {}
        losses.update(self.geometry.regularizations(out))
        losses.update(self.texture.regularizations(out))
        return losses

    @torch.no_grad()
    def export(self, export_config):
        """Extract the mesh and, if ``export_config.export_vertex_color`` is
        set, attach per-vertex colors under ``'v_rgb'``."""
        mesh = self.isosurface()
        if export_config.export_vertex_color:
            _, feature = chunk_batch(self.geometry, export_config.chunk_size, False, mesh['v_pos'].to(self.rank))
            viewdirs = torch.zeros(feature.shape[0], 3).to(feature)
            viewdirs[...,2] = -1. # set the viewing directions to be -z (looking down)
            rgb = self.texture(feature, viewdirs).clamp(0,1)
            mesh['v_rgb'] = rgb.cpu()
        return mesh
mesh_recon/models/network_utils.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import tinycudann as tcnn
7
+
8
+ from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info
9
+
10
+ from utils.misc import config_to_primitive, get_rank
11
+ from models.utils import get_activation
12
+ from systems.utils import update_module_step
13
+
14
class VanillaFrequency(nn.Module):
    """Sin/cos frequency encoding with an optional per-frequency mask.

    Frequencies are ``2**0 .. 2**(n_frequencies - 1)``. When
    ``n_masking_step`` is positive, :meth:`update_step` ramps the mask of
    each band from 0 to 1 as ``global_step`` grows; otherwise the mask is
    all ones.
    """

    def __init__(self, in_channels, config):
        super().__init__()
        self.N_freqs = config['n_frequencies']
        self.in_channels = in_channels
        self.n_input_dims = in_channels
        self.funcs = [torch.sin, torch.cos]
        self.freq_bands = 2**torch.linspace(0, self.N_freqs-1, self.N_freqs)
        self.n_output_dims = self.in_channels * (len(self.funcs) * self.N_freqs)
        self.n_masking_step = config.get('n_masking_step', 0)
        self.update_step(None, None) # mask should be updated at the beginning each step

    def forward(self, x):
        """Concatenate ``func(freq * x) * mask`` for every band and every function."""
        encoded = [
            fn(band * x) * weight
            for band, weight in zip(self.freq_bands, self.mask)
            for fn in self.funcs
        ]
        return torch.cat(encoded, -1)

    def update_step(self, epoch, global_step):
        """Recompute the band mask for ``global_step``."""
        masking_active = self.n_masking_step > 0 and global_step is not None
        if not masking_active:
            self.mask = torch.ones(self.N_freqs, dtype=torch.float32)
            return
        progress = global_step / self.n_masking_step * self.N_freqs
        ramp = (progress - torch.arange(0, self.N_freqs)).clamp(0, 1)
        # Cosine ease-in from 0 to 1 for each band.
        self.mask = (1. - torch.cos(math.pi * ramp)) / 2.
        rank_zero_debug(f'Update mask: {global_step}/{self.n_masking_step} {self.mask}')
38
+
39
+
40
class ProgressiveBandHashGrid(nn.Module):
    """tcnn ``HashGrid`` encoding whose levels are unmasked progressively.

    The output mask starts at zero; :meth:`update_step` sets the features of
    the first ``current_level`` levels to one.
    """

    def __init__(self, in_channels, config):
        super().__init__()
        self.n_input_dims = in_channels
        # tcnn does not know this otype; build a plain HashGrid from a copy of the config.
        hash_grid_config = config.copy()
        hash_grid_config['otype'] = 'HashGrid'
        with torch.cuda.device(get_rank()):
            self.encoding = tcnn.Encoding(in_channels, hash_grid_config)
        self.n_output_dims = self.encoding.n_output_dims
        self.n_level = config['n_levels']
        self.n_features_per_level = config['n_features_per_level']
        self.start_level = config['start_level']
        self.start_step = config['start_step']
        self.update_steps = config['update_steps']
        self.current_level = self.start_level
        self.mask = torch.zeros(self.n_level * self.n_features_per_level, dtype=torch.float32, device=get_rank())

    def forward(self, x):
        return self.encoding(x) * self.mask

    def update_step(self, epoch, global_step):
        """Raise the active level by one every ``update_steps`` steps after
        ``start_step``, capped at ``n_levels``."""
        elapsed = max(global_step - self.start_step, 0)
        target_level = min(self.start_level + elapsed // self.update_steps, self.n_level)
        if target_level > self.current_level:
            rank_zero_info(f'Update grid level to {target_level}')
        self.current_level = target_level
        self.mask[:self.current_level * self.n_features_per_level] = 1.
66
+
67
+
68
class CompositeEncoding(nn.Module):
    """Wrap an encoding and optionally prepend the affinely mapped input.

    With ``include_xyz`` the output is
    ``cat([x * xyz_scale + xyz_offset, encoding(x)])``; otherwise it is just
    ``encoding(x)``.
    """

    def __init__(self, encoding, include_xyz=False, xyz_scale=1., xyz_offset=0.):
        super().__init__()
        self.encoding = encoding
        self.include_xyz = include_xyz
        self.xyz_scale = xyz_scale
        self.xyz_offset = xyz_offset
        extra_dims = int(self.include_xyz) * self.encoding.n_input_dims
        self.n_output_dims = extra_dims + self.encoding.n_output_dims

    def forward(self, x, *args):
        if not self.include_xyz:
            return self.encoding(x, *args)
        raw = x * self.xyz_scale + self.xyz_offset
        return torch.cat([raw, self.encoding(x, *args)], dim=-1)

    def update_step(self, epoch, global_step):
        update_module_step(self.encoding, epoch, global_step)
80
+
81
+
82
def get_encoding(n_input_dims, config):
    """Build the encoding selected by ``config.otype``, wrapped in a
    :class:`CompositeEncoding`.

    ``VanillaFrequency`` and ``ProgressiveBandHashGrid`` are the local
    implementations; any other otype is handed to ``tcnn.Encoding``. The
    input is expected to be in the range [0, 1].
    """
    local_encodings = {
        'VanillaFrequency': VanillaFrequency,
        'ProgressiveBandHashGrid': ProgressiveBandHashGrid,
    }
    encoding_cls = local_encodings.get(config.otype)
    if encoding_cls is not None:
        inner = encoding_cls(n_input_dims, config_to_primitive(config))
    else:
        with torch.cuda.device(get_rank()):
            inner = tcnn.Encoding(n_input_dims, config_to_primitive(config))
    return CompositeEncoding(
        inner,
        include_xyz=config.get('include_xyz', False),
        xyz_scale=2.,
        xyz_offset=-1.,
    )
93
+
94
+
95
class VanillaMLP(nn.Module):
    """Plain PyTorch MLP with optional sphere initialization and weight norm.

    The network has ``n_hidden_layers`` hidden layers of ``n_neurons`` units
    and is always evaluated in float32 with CUDA autocast disabled.
    """

    def __init__(self, dim_in, dim_out, config):
        super().__init__()
        self.n_neurons = config['n_neurons']
        self.n_hidden_layers = config['n_hidden_layers']
        self.sphere_init = config.get('sphere_init', False)
        self.weight_norm = config.get('weight_norm', False)
        self.sphere_init_radius = config.get('sphere_init_radius', 0.5)

        stack = [
            self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False),
            self.make_activation(),
        ]
        for _ in range(self.n_hidden_layers - 1):
            stack.append(self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False))
            stack.append(self.make_activation())
        stack.append(self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True))
        self.layers = nn.Sequential(*stack)
        self.output_activation = get_activation(config['output_activation'])

    @torch.cuda.amp.autocast(False)
    def forward(self, x):
        hidden = self.layers(x.float())
        return self.output_activation(hidden)

    def make_linear(self, dim_in, dim_out, is_first, is_last):
        """Create one linear layer and initialize it for its position.

        Without ``sphere_init``: zero bias, Kaiming-uniform weights. With
        ``sphere_init``: the last layer gets bias ``-sphere_init_radius`` and
        near-constant weights, the first layer only feeds its first three
        input columns, and hidden layers use a zero-mean normal.
        """
        layer = nn.Linear(dim_in, dim_out, bias=True) # network without bias will degrade quality
        if not self.sphere_init:
            torch.nn.init.constant_(layer.bias, 0.0)
            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
        elif is_last:
            torch.nn.init.constant_(layer.bias, -self.sphere_init_radius)
            torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001)
        elif is_first:
            torch.nn.init.constant_(layer.bias, 0.0)
            torch.nn.init.constant_(layer.weight[:, 3:], 0.0)
            torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out))
        else:
            torch.nn.init.constant_(layer.bias, 0.0)
            torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out))

        if self.weight_norm:
            layer = nn.utils.weight_norm(layer)
        return layer

    def make_activation(self):
        """Softplus(beta=100) with sphere init, in-place ReLU otherwise."""
        return nn.Softplus(beta=100) if self.sphere_init else nn.ReLU(inplace=True)
140
+
141
+
142
def sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network):
    """Overwrite the parameters of a tcnn MLP so it roughly represents a sphere.

    From https://github.com/NVlabs/tiny-cuda-nn/issues/96: the parameter
    vector is the weight matrices of each layer laid out in row-major order
    and then concatenated. Input and output dimensions are padded to
    multiples of 8 (CutlassMLP) or 16 (FullyFusedMLP). The padded input
    dimensions get a constant value of 1.0, whereas the padded output
    dimensions are simply ignored, so the weights pertaining to those can
    have any value.
    """
    rank_zero_debug('Initialize tcnn MLP to approximately represent a sphere.')
    padto = 16 if config.otype == 'FullyFusedMLP' else 8

    def _padded(n):
        # Round n up to the next multiple of padto.
        return n + (padto - n % padto) % padto

    n_input_dims = _padded(n_input_dims)
    n_output_dims = _padded(n_output_dims)
    width = config.n_neurons

    data = list(network.parameters())[0].data
    assert data.shape[0] == (n_input_dims + n_output_dims) * config.n_neurons + (config.n_hidden_layers - 1) * config.n_neurons**2

    chunks = []
    # first layer: only the first three input columns are non-zero
    weight = torch.zeros((width, n_input_dims)).to(data)
    torch.nn.init.constant_(weight[:, 3:], 0.0)
    torch.nn.init.normal_(weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(width))
    chunks.append(weight.flatten())
    # hidden layers
    for _ in range(config.n_hidden_layers - 1):
        weight = torch.zeros((width, width)).to(data)
        torch.nn.init.normal_(weight, 0.0, math.sqrt(2) / math.sqrt(width))
        chunks.append(weight.flatten())
    # last layer
    weight = torch.zeros((n_output_dims, width)).to(data)
    torch.nn.init.normal_(weight, mean=math.sqrt(math.pi) / math.sqrt(width), std=0.0001)
    chunks.append(weight.flatten())

    data.copy_(torch.cat(chunks))
174
+
175
+
176
def get_mlp(n_input_dims, n_output_dims, config):
    """Build the MLP selected by ``config.otype``.

    ``'VanillaMLP'`` yields the PyTorch :class:`VanillaMLP`; any other otype
    is created through ``tcnn.Network`` and, when ``sphere_init`` is set in
    the config, re-initialized with :func:`sphere_init_tcnn_network`.
    """
    if config.otype == 'VanillaMLP':
        return VanillaMLP(n_input_dims, n_output_dims, config_to_primitive(config))

    with torch.cuda.device(get_rank()):
        network = tcnn.Network(n_input_dims, n_output_dims, config_to_primitive(config))
        if config.get('sphere_init', False):
            sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network)
    return network
185
+
186
+
187
class EncodingWithNetwork(nn.Module):
    """Chain an encoding and a network: ``network(encoding(x))``."""

    def __init__(self, encoding, network):
        super().__init__()
        self.encoding = encoding
        self.network = network

    def forward(self, x):
        encoded = self.encoding(x)
        return self.network(encoded)

    def update_step(self, epoch, global_step):
        # Forward the step to both parts, encoding first.
        for module in (self.encoding, self.network):
            update_module_step(module, epoch, global_step)
198
+
199
+
200
def get_encoding_with_network(n_input_dims, n_output_dims, encoding_config, network_config):
    """Build an encoding followed by an MLP.

    If either part is one of the local PyTorch implementations
    (``VanillaFrequency`` / ``ProgressiveBandHashGrid`` encodings or the
    ``VanillaMLP`` network), the two are built separately and chained with
    :class:`EncodingWithNetwork`; otherwise a single
    ``tcnn.NetworkWithInputEncoding`` is created. The input is expected to be
    in the range [0, 1].
    """
    uses_local_module = (
        encoding_config.otype in ('VanillaFrequency', 'ProgressiveBandHashGrid')
        or network_config.otype == 'VanillaMLP'
    )
    if uses_local_module:
        encoding = get_encoding(n_input_dims, encoding_config)
        network = get_mlp(encoding.n_output_dims, n_output_dims, network_config)
        return EncodingWithNetwork(encoding, network)

    with torch.cuda.device(get_rank()):
        return tcnn.NetworkWithInputEncoding(
            n_input_dims=n_input_dims,
            n_output_dims=n_output_dims,
            encoding_config=config_to_primitive(encoding_config),
            network_config=config_to_primitive(network_config)
        )