fffan committed
Commit 130118b · verified · 1 Parent(s): 72d3318

Upload 461 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +220 -0
  2. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/ckpts/epoch=0-step=1000.ckpt +3 -0
  3. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/ckpts/last.ckpt +3 -0
  4. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/cmd.txt +2 -0
  5. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/.gitignore +7 -0
  6. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/.gitmodules +15 -0
  7. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/README.md +83 -0
  8. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/assets/bear.gif +3 -0
  9. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/assets/cherry.gif +3 -0
  10. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/assets/teddy.png +3 -0
  11. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/check_output.py +42 -0
  12. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/.gitignore +196 -0
  13. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/.pre-commit-config.yaml +34 -0
  14. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/README.md +129 -0
  15. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/__init__.py +25 -0
  16. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/background/gaussian_mvdream_background.py +72 -0
  17. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/configs/gaussian_splatting.yaml +96 -0
  18. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/configs/gaussian_splatting_background.yaml +111 -0
  19. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/configs/gaussian_splatting_mvdream.yaml +131 -0
  20. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/configs/gaussian_splatting_shading.yaml +115 -0
  21. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/configs/gaussian_splatting_zero123.yaml +144 -0
  22. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/configs/scene_lang.yaml +138 -0
  23. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/geometry/exporter.py +44 -0
  24. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/geometry/gaussian_base.py +1469 -0
  25. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/geometry/gaussian_base.py.bak +1492 -0
  26. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/geometry/gaussian_dynamic.py +77 -0
  27. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/geometry/gaussian_io.py +327 -0
  28. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/geometry/mesh_utils.py +150 -0
  29. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/material/gaussian_material.py +116 -0
  30. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/renderer/diff_gaussian_rasterizer.py +151 -0
  31. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/renderer/diff_gaussian_rasterizer_advanced.py +152 -0
  32. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/renderer/diff_gaussian_rasterizer_background.py +145 -0
  33. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/renderer/diff_gaussian_rasterizer_shading.py +226 -0
  34. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/renderer/gaussian_batch_renderer.py +92 -0
  35. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/system/gaussian_mvdream.py +249 -0
  36. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/system/gaussian_splatting.py +223 -0
  37. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/system/gaussian_zero123.py +339 -0
  38. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/system/scene_lang.py +528 -0
  39. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/utils/__init__.py +0 -0
  40. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/utils/ae.py +63 -0
  41. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/utils/sam_clip.py +366 -0
  42. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/bear_background.png +3 -0
  43. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/bear_composite.png +3 -0
  44. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/bear_layers.png +0 -0
  45. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/boy_background.png +3 -0
  46. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/boy_composite.png +3 -0
  47. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/boy_layers.png +0 -0
  48. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/corgi_background.png +3 -0
  49. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/corgi_composite.png +3 -0
  50. 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/corgi_layers.png +0 -0
.gitattributes CHANGED
@@ -33,3 +33,223 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/assets/bear.gif filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/assets/cherry.gif filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/assets/teddy.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/bear_background.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/bear_composite.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/boy_background.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/boy_composite.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/corgi_background.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/corgi_composite.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/stairs_background.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/stairs_composite.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/tinycudann-1.7.post70240121-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/outpaint_0.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/outpaint_1.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1000-val.mp4 filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test.mp4 filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/0.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/1.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/10.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/100.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/101.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/102.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/103.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/104.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/105.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/106.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/107.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/108.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/109.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/11.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/110.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/111.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/112.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/113.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/114.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/115.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/116.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/117.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/118.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/119.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/12.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/13.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/14.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/15.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/16.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/17.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/18.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/19.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/2.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/20.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/21.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/22.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/23.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/24.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/25.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/26.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/27.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/28.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/29.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/3.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/30.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/31.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/32.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/33.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/34.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/35.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/36.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/37.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/38.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/39.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/4.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/40.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/41.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/42.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/43.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/44.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/45.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/46.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/47.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/48.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/49.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/5.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/50.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/51.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/52.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/53.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/54.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/55.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/56.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/57.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/58.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/59.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/6.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/60.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/61.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/62.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/63.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/64.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/65.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/66.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/67.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/68.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/69.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/7.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/70.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/71.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/72.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/73.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/74.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/75.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/76.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/77.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/78.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/79.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/8.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-test/80.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it1500-val.mp4 filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it500-val.mp4 filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat.mp4 filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/0.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/1.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/10.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/100.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/101.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/102.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/103.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/104.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/105.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/106.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/107.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/108.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/109.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/11.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/110.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/111.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/112.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/113.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/114.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/115.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/116.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/117.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/118.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/119.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/12.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/13.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/14.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/15.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/16.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/17.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/18.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/19.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/2.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/20.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/21.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/22.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/23.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/24.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/25.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/26.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/27.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/28.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/29.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/3.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/30.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/31.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/32.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/33.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/34.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/35.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/36.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/37.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/38.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/39.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/4.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/40.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/41.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/42.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/43.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/44.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/45.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/46.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/47.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/48.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/49.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/5.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/50.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/51.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/52.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/53.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/54.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/55.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/56.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/57.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/58.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/59.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/6.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/60.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/61.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/62.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/63.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/64.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/65.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/66.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/67.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/68.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/69.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/7.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/70.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/71.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/72.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/73.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/74.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/75.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/76.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/77.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/78.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/79.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/8.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/it800-feat/80.png filter=lfs diff=lfs merge=lfs -text
+ 000000000017.1/gs-sds-generation/3DitScene@20250207-015119/save/point_cloud.ply filter=lfs diff=lfs merge=lfs -text
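Every path added above is routed through Git LFS, so the repository itself stores only small text pointer stubs. As a minimal sketch, the following hypothetical helper (not part of this commit) detects such stubs by the spec line visible in the checkpoint diffs that follow:
```
def is_lfs_pointer(path: str) -> bool:
    # LFS pointer files are tiny text stubs: a spec line, then
    # "oid sha256:..." and "size ..." lines.
    try:
        with open(path, "rb") as f:
            head = f.read(200)
    except OSError:
        return False
    return head.startswith(b"version https://git-lfs.github.com/spec/v1")
```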
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/ckpts/epoch=0-step=1000.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:05b53135d3920aa7616777d0b9040ed4a12f5060d43f700c361aae7805f9d248
+ size 28888900
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/ckpts/last.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2762c6bc7f087ebcf67488e4e6b106a8d7975d5143b26d3294b8a8f75b65a777
+ size 28888900
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/cmd.txt ADDED
@@ -0,0 +1,2 @@
+ python launch.py --config custom/threestudio-3dgs/configs/scene_lang.yaml --train --gpu 2 exp_root_dir=outputs/mira_video_clips/000000000/000000000017.1 tag=3DitScene system.geometry.geometry_convert_from=depth:/mnt/hdd1/wufan/datasets/MiraData/data/video_frames/000000000/000000000017.1/0.jpg system.geometry.ooi_bbox=[599,250,692,452] system.prompt_processor.prompt=It is night time in a city with tall buildings and neon lights illuminating the streets. system.empty_prompt= The background is a city at night with tall buildings, out of focus system.side_prompt= The background is a city at night with tall buildings, out of focus
+ Namespace(config='custom/threestudio-3dgs/configs/scene_lang.yaml', gpu='2', train=True, validate=False, test=False, export=False, save_dir=None, gradio=False, verbose=False, typecheck=False)
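The first line of cmd.txt is the exact launch command and the second is the parsed argparse Namespace. A minimal sketch of how a launcher could persist such a record (save_cmd is a hypothetical helper introduced here; this commit does not show the repository's actual mechanism):
```
import os
import sys

def save_cmd(out_dir, args):
    # Record the invocation and parsed arguments for reproducibility.
    # Note: sys.argv starts at the script path, so the interpreter name
    # is prepended by hand here (an assumption, matching the file above).
    with open(os.path.join(out_dir, "cmd.txt"), "w") as f:
        f.write("python " + " ".join(sys.argv) + "\n")
        f.write(repr(args) + "\n")
```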
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/.gitignore ADDED
@@ -0,0 +1,7 @@
+ ckpts/
+ outputs/
+ .threestudio_cache/
+
+ *.pyc
+ *.DS_Store
+
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/.gitmodules ADDED
@@ -0,0 +1,15 @@
+ [submodule "submodules/MobileSAM-lang"]
+ path = submodules/MobileSAM-lang
+ url = https://github.com/zqh0253/MobileSAM-lang.git
+ [submodule "submodules/segment-anything-langsplat"]
+ path = submodules/segment-anything-langsplat
+ url = https://github.com/zqh0253/segment-anything-langsplat.git
+ [submodule "submodules/simple-knn"]
+ path = submodules/simple-knn
+ url = https://github.com/DSaurus/simple-knn.git
+ [submodule "submodules/diff-gaussian-rasterization"]
+ path = submodules/diff-gaussian-rasterization
+ url = https://github.com/zqh0253/diff-gaussian-rasterization-lang
+ [submodule "submodules/langsplat-rasterization"]
+ path = submodules/langsplat-rasterization
+ url = https://github.com/minghanqin/langsplat-rasterization.git
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/README.md ADDED
@@ -0,0 +1,83 @@
+ # 3DitScene: Editing Any Scene via Language-guided Disentangled Gaussian Splatting
+
+ [![Project Page](https://img.shields.io/badge/Project-Website-green)](https://zqh0253.github.io/3DitScene/)
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/qihang/3Dit-Scene/)
+ [![arXiv](https://img.shields.io/badge/arXiv-2405.18424-b31b1b.svg)](https://arxiv.org/abs/2405.18424)
+
+
+ <table class="center">
+ <tr style="line-height: 0">
+ <td width=35% style="border: none; text-align: center">Move the bear, and rotate the camera</td>
+ <td width=30% style="border: none; text-align: center">Move / remove the girl, and rotate the camera</td>
+ </tr>
+ <tr style="line-height: 0">
+ <td width=35% style="border: none"><img src="assets/bear.gif"></td>
+ <td width=30% style="border: none"><img src="assets/cherry.gif"></td>
+ </tr>
+ </table>
+
+ ## Installation
+
+ + Install `Python >= 3.8`.
+ + Install `torch >= 1.12`. We have tested on `torch==2.0.1+cu118`, but other versions should also work fine.
+ + Clone our repo:
+ ```
+ git clone https://github.com/zqh0253/3DitScene.git --recursive
+ ```
+ + Install dependencies:
+ ```
+ pip install -r requirements.txt
+ ```
+ + Install submodules:
+ ```
+ pip install ./submodules/segment-anything-langsplat
+ pip install ./submodules/MobileSAM-lang
+ pip install ./submodules/langsplat-rasterization
+ pip install ./submodules/simple-knn
+ ```
+ + Prepare weights for `SAM`:
+ ```
+ mkdir ckpts
+ wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -O ./ckpts/sam_vit_h_4b8939.pth
+ cp submodules/MobileSAM-lang/weights/mobile_sam.pt ./ckpts/
+ ```
+
+ ## Usage
+
+ Run the following command to launch the optimization procedure:
+ ```
+ python -u launch.py --config custom/threestudio-3dgs/configs/scene_lang.yaml --train --gpu 0 tag=3DitScene
+ system.geometry.geometry_convert_from=depth:${IMGPATH} system.geometry.ooi_bbox=${BBOX}
+ system.prompt_processor.prompt="${PROMPT}" system.empty_prompt="${EMPTY_PROMPT}" system.side_prompt="${SIDE_PROMPT}"
+ ```
+ You should specify the image path `IMGPATH`, the bounding box of the object of interest `BBOX`, and the prompts `PROMPT`, `EMPTY_PROMPT`, and `SIDE_PROMPT`. These prompts describe the image itself, the background area behind the image, and the content of the novel view region, respectively.
+
+ Here we provide an image (`./assets/teddy.png`) as an example:
+ ```
+ python -u launch.py --config custom/threestudio-3dgs/configs/scene_lang.yaml --train --gpu 0 tag=3DitScene
+ system.geometry.geometry_convert_from=depth:assets/teddy.png system.geometry.ooi_bbox=[122,119,387,495]
+ system.prompt_processor.prompt="a teddy bear in Times Square" system.empty_prompt="Times Square, out of focus" system.side_prompt="Times Square, out of focus"
+ ```
+
+ ## Huggingface demo
+
+ We provide a Hugging Face demo. You have two options to explore it:
+ (1) Visit our [online Hugging Face space](https://huggingface.co/spaces/qihang/3Dit-Scene).
+ (2) Deploy it locally by following these steps:
+ + Install the necessary packages and download required files as specified in our [Dockerfile](https://huggingface.co/spaces/qihang/3Dit-Scene/blob/main/Dockerfile).
+ + Run the following command to launch the service at `localhost:10091`:
+ ```
+ python gradio_app_single_process.py --listen --hf-space --port 10091
+ ```
+
+ ## Citation
+
+ If you find our work useful, please consider citing:
+ ```
+ @inproceedings{zhang20243DitScene,
+   author = {Qihang Zhang and Yinghao Xu and Chaoyang Wang and Hsin-Ying Lee and Gordon Wetzstein and Bolei Zhou and Ceyuan Yang},
+   title = {{3DitScene}: Editing Any Scene via Language-guided Disentangled Gaussian Splatting},
+   booktitle = {arXiv},
+   year = {2024}
+ }
+ ```
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/assets/bear.gif ADDED

Git LFS Details

  • SHA256: ef075fb5f74ea8fc690b0b68d3abd88d151cdc246ac2df27767fc4cbb24227f9
  • Pointer size: 132 Bytes
  • Size of remote file: 6.43 MB
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/assets/cherry.gif ADDED

Git LFS Details

  • SHA256: a0069d4bdd8da45627cbba25f3c49a102220aff23cc8738417a13101d42a3b25
  • Pointer size: 132 Bytes
  • Size of remote file: 7.39 MB
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/assets/teddy.png ADDED

Git LFS Details

  • SHA256: 6d73779e3f37a6e8e6171d30019a12851cdd1a69c5bdf2ff1c1b0b8ade8e1db6
  • Pointer size: 131 Bytes
  • Size of remote file: 390 kB
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/check_output.py ADDED
@@ -0,0 +1,42 @@
+ import os
+
+ # Avoid MKL/GNU OpenMP threading-layer conflicts; must be set before pandas loads MKL.
+ os.environ["MKL_THREADING_LAYER"] = "GNU"
+
+ import pandas as pd
+
+ # Load the clip list; rows whose outputs are missing are written to a second CSV.
+ csv_path = "/mnt/hdd1/wufan/datasets/MiraData/data/data_list/miradata_v1_9k_subset_shard_0.csv"
+ missing_csv_path = "/mnt/hdd1/wufan/datasets/MiraData/data/data_list/miradata_v1_9k_output_missing.csv"
+ df = pd.read_csv(csv_path)
+
+ output_path = "/mnt/hdd1/wufan/projects/3DitScene/outputs/mira_video_clips"
+
+ # Collect every row whose generation output (save/it500-val.mp4) does not exist yet.
+ missing_rows = []
+ for index, row in df.iterrows():
+     # Map e.g. "video_clips/000005007/000005007658.0.mp4" to the trial root
+     # ".../outputs/mira_video_clips/000005007/000005007658.0/gs-sds-generation".
+     rel_path = row["file_path"].replace("video_clips/", "").replace(".mp4", "/gs-sds-generation")
+     file_dir = os.path.join(output_path, rel_path)
+
+     if os.path.isdir(file_dir):
+         # Each trial directory under file_dir should contain save/it500-val.mp4;
+         # the clip counts as done as soon as one trial has produced the video.
+         done = any(
+             os.path.exists(os.path.join(file_dir, item, "save/it500-val.mp4"))
+             for item in os.listdir(file_dir)
+         )
+         if not done:
+             missing_rows.append(row)
+     else:
+         # No output directory at all: the clip has not been processed.
+         missing_rows.append(row)
+
+ # Save the rows that still need processing.
+ results_df = pd.DataFrame(missing_rows)
+ results_df.to_csv(missing_csv_path, index=False)
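The same filter can also be written without iterrows, as a boolean mask over the DataFrame. A minimal equivalent sketch under the same path layout (has_output is a name introduced here, not from the script; output_path, df, and missing_csv_path are as defined above):
```
def has_output(fp: str) -> bool:
    # True if any trial directory for this clip already contains save/it500-val.mp4.
    rel = fp.replace("video_clips/", "").replace(".mp4", "/gs-sds-generation")
    base = os.path.join(output_path, rel)
    if not os.path.isdir(base):
        return False
    return any(
        os.path.exists(os.path.join(base, trial, "save/it500-val.mp4"))
        for trial in os.listdir(base)
    )

missing = df[~df["file_path"].map(has_output)]
missing.to_csv(missing_csv_path, index=False)
```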
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/.gitignore ADDED
@@ -0,0 +1,196 @@
+ # Created by https://www.toptal.com/developers/gitignore/api/python
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ ### Python Patch ###
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+ poetry.toml
+
+ # ruff
+ .ruff_cache/
+
+ # LSP config files
+ pyrightconfig.json
+
+ # End of https://www.toptal.com/developers/gitignore/api/python
+
+ .vscode/
+ .threestudio_cache/
+ outputs/
+ outputs-gradio/
+
+ # pretrained model weights
+ *.ckpt
+ *.pt
+ *.pth
+
+ # wandb
+ wandb/
+
+ custom/*
+
+ load/tets/256_tets.npz
+
+ diff-gaussian-rasterization/
+ simple-knn/
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/.pre-commit-config.yaml ADDED
@@ -0,0 +1,34 @@
+ default_language_version:
+   python: python3
+
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v4.4.0
+     hooks:
+       - id: trailing-whitespace
+       - id: check-ast
+       - id: check-merge-conflict
+       - id: check-yaml
+       - id: end-of-file-fixer
+       - id: trailing-whitespace
+         args: [--markdown-linebreak-ext=md]
+
+   - repo: https://github.com/psf/black
+     rev: 23.3.0
+     hooks:
+       - id: black
+         language_version: python3
+
+   - repo: https://github.com/pycqa/isort
+     rev: 5.12.0
+     hooks:
+       - id: isort
+         exclude: README.md
+         args: ["--profile", "black"]
+
+   # temporarily disable static type checking
+   # - repo: https://github.com/pre-commit/mirrors-mypy
+   #   rev: v1.2.0
+   #   hooks:
+   #     - id: mypy
+   #       args: ["--ignore-missing-imports", "--scripts-are-modules", "--pretty"]
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/README.md ADDED
@@ -0,0 +1,129 @@
+ # threestudio-3dgs
+ <img src="https://github.com/DSaurus/threestudio-3dgs/assets/24589363/55874a57-cff1-4e83-a006-76585bcd3b76" width="" height="128">
+
+ <img src="https://github.com/DSaurus/threestudio-3dgs/assets/24589363/505f35e5-b160-4c12-92dc-03482404ef5e" width="" height="128">
+
+ <img src="https://github.com/DSaurus/threestudio-3dgs/assets/24589363/a1041f0d-a56f-4f7f-adc3-1e99c2d81098" width="" height="128">
+
+ <img src="https://github.com/DSaurus/threestudio-3dgs/assets/24589363/f524524e-33aa-4701-9f0d-31cba23eaead" width="" height="128">
+
+ The Gaussian Splatting extension for threestudio. This extension is written by [Ruizhi Shao](https://github.com/DSaurus) and [Youtian Lin](https://github.com/Linyou). To use it, please install [threestudio](https://github.com/threestudio-project/threestudio) first and then install this extension in the threestudio `custom` directory.
+
+ ## Advanced Gaussian Splatting Installation (Recommended)
+ ```
+ cd custom
+ git clone https://github.com/DSaurus/threestudio-3dgs.git
+ cd threestudio-3dgs
+ git clone --recursive https://github.com/ashawkey/diff-gaussian-rasterization
+ git clone https://github.com/DSaurus/simple-knn.git
+ pip install ./diff-gaussian-rasterization
+ pip install ./simple-knn
+ ```
+
+ ## Native Gaussian Splatting Installation
+ ```
+ cd custom
+ git clone https://github.com/DSaurus/threestudio-3dgs.git
+ cd threestudio-3dgs
+ git clone [email protected]:graphdeco-inria/gaussian-splatting.git --recursive
+ cd gaussian-splatting/submodules
+ python -m pip install diff-gaussian-rasterization/.
+ python -m pip install simple-knn/
+
+ # If you want to export mesh, please install pymeshlab
+ pip install pymeshlab
+ ```
+
+
+ ## Quick Start
+ ```
+ # Native Gaussian Splatting + SDS Loss
+ python launch.py --config custom/threestudio-3dgs/configs/gaussian_splatting.yaml --train --gpu 0 system.prompt_processor.prompt="a delicious hamburger"
+
+ # Advanced Gaussian Splatting with background + SDS Loss
+ python launch.py --config custom/threestudio-3dgs/configs/gaussian_splatting_background.yaml --train --gpu 0 system.prompt_processor.prompt="a delicious hamburger"
+
+ # Advanced Gaussian Splatting with background and shading + SDS Loss
+ python launch.py --config custom/threestudio-3dgs/configs/gaussian_splatting_shading.yaml --train --gpu 0 system.prompt_processor.prompt="a delicious hamburger"
+ ```
+
+ ## Gaussian Splatting + MVDream
+ Please first install the [MVDream extension](https://github.com/DSaurus/threestudio-mvdream), then you can run the following script:
+ ```
+ # Advanced Gaussian Splatting with background and shading + MVDream
+ python launch.py --config custom/threestudio-3dgs/configs/gaussian_splatting_mvdream.yaml --train --gpu 0 system.prompt_processor.prompt="an astronaut riding a horse"
+ ```
+
+ ## Gaussian Splatting + Zero-123
+ ```
+ # Advanced Gaussian Splatting + Zero-123
+ python launch.py --config custom/threestudio-3dgs/configs/gaussian_splatting_zero123.yaml --train --gpu 0 data.image_path=./load/images/anya_front_rgba.png
+ ```
+
+ ## Resume from checkpoints
+ ```
+ # Resume training from the last checkpoint; you may replace last.ckpt with any other checkpoint
+ python launch.py --config path/to/trial/dir/configs/parsed.yaml --train --gpu 0 resume=path/to/trial/dir/ckpts/last.ckpt
+ ```
+
+ ## Load from PLY
+ ```
+ # load from a Gaussian Splatting ply file
+ python launch.py --config custom/threestudio-3dgs/configs/gaussian_splatting.yaml --train --gpu 0 system.prompt_processor.prompt="a delicious hamburger" system.geometry.geometry_convert_from=path/to/point_cloud.ply
+
+ # only load point positions and colors from the ply file
+ python launch.py --config custom/threestudio-3dgs/configs/gaussian_splatting.yaml --train --gpu 0 system.prompt_processor.prompt="a delicious hamburger" system.geometry.geometry_convert_from=path/to/point_cloud.ply system.geometry.load_ply_only_vertex=true
+ ```
+
+ If you want to use Shap-E initialization, please install the [threestudio-shap-e extension](https://github.com/DSaurus/threestudio-shap-e) first.
+ ```
+ # load from shap-e initialization
+ python launch.py --config custom/threestudio-3dgs/configs/gaussian_splatting.yaml --train --gpu 0 system.prompt_processor.prompt="a delicious hamburger" system.geometry.geometry_convert_from="shap-e:a delicious hamburger"
+ ```
+
+ If you want to use LRM initialization, please install the [threestudio-lrm extension](https://github.com/Adamdad/threestudio-lrm) first.
+ ```
+ # load from lrm initialization
+ python launch.py --config custom/threestudio-3dgs/configs/gaussian_splatting.yaml --train --gpu 0 system.prompt_processor.prompt="a delicious hamburger" system.geometry.geometry_convert_from="lrm:a delicious hamburger"
+ ```
+
+ ## Export
+ You can use the following script to export the Gaussian Splatting .ply file and a mesh .obj:
+ ```
+ python launch.py --config path/to/config --export --gpu 0 system.prompt_processor.prompt="a delicious hamburger" resume=path/to/last.ckpt
+ ```
+
+ ## Citation
+ ```
+ @Article{kerbl3Dgaussians,
+   author = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
+   title = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
+   journal = {ACM Transactions on Graphics},
+   number = {4},
+   volume = {42},
+   month = {July},
+   year = {2023},
+   url = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
+ }
+ ```
+
+ ## Acknowledgement
+ Please also consider citing these works on 3D Gaussian Splatting generation; their open-source code inspired this project.
+
+ ```
+ @article{tang2023dreamgaussian,
+   title={DreamGaussian: Generative Gaussian Splatting for Efficient 3D Content Creation},
+   author={Tang, Jiaxiang and Ren, Jiawei and Zhou, Hang and Liu, Ziwei and Zeng, Gang},
+   journal={arXiv preprint arXiv:2309.16653},
+   year={2023}
+ }
+ ```
+
+ ```
+ @article{GaussianDreamer,
+   title={GaussianDreamer: Fast Generation from Text to 3D Gaussian Splatting with Point Cloud Priors},
+   author={Taoran Yi and Jiemin Fang and Guanjun Wu and Lingxi Xie and Xiaopeng Zhang and Wenyu Liu and Qi Tian and Xinggang Wang},
+   journal={arxiv:2310.08529},
+   year={2023}
+ }
+ ```
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/__init__.py ADDED
@@ -0,0 +1,25 @@
+ import threestudio
+ from packaging.version import Version
+
+ if hasattr(threestudio, "__version__") and Version(threestudio.__version__) >= Version(
+     "0.2.1"
+ ):
+     pass
+ else:
+     if hasattr(threestudio, "__version__"):
+         print(f"[INFO] threestudio version: {threestudio.__version__}")
+     raise ValueError(
+         "threestudio version must be >= 0.2.1, please update threestudio by pulling the latest version from github"
+     )
+
+
+ from .background import gaussian_mvdream_background
+ from .geometry import exporter, gaussian_base, gaussian_io
+ from .material import gaussian_material
+ from .renderer import (
+     diff_gaussian_rasterizer,
+     diff_gaussian_rasterizer_advanced,
+     diff_gaussian_rasterizer_background,
+     diff_gaussian_rasterizer_shading,
+ )
+ from .system import gaussian_mvdream, gaussian_splatting, gaussian_zero123, scene_lang
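The version gate above relies on `packaging.version.Version` implementing PEP 440 ordering rather than string comparison, which is why the guard does not compare raw version strings; a small illustration:
```
from packaging.version import Version

assert Version("0.10.0") > Version("0.2.1")  # numeric comparison per component
assert "0.10.0" < "0.2.1"                    # plain string comparison gets this wrong
```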
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/background/gaussian_mvdream_background.py ADDED
@@ -0,0 +1,72 @@
+ import random
+ from dataclasses import dataclass, field
+
+ import threestudio
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from threestudio.models.background.base import BaseBackground
+ from threestudio.models.networks import get_encoding, get_mlp
+ from threestudio.utils.ops import get_activation
+ from threestudio.utils.typing import *
+
+
+ @threestudio.register("gaussian-mvdream-neural-environment-map-background")
+ class NeuralEnvironmentMapBackground(BaseBackground):
+     @dataclass
+     class Config(BaseBackground.Config):
+         n_output_dims: int = 3
+         color_activation: str = "sigmoid"
+         dir_encoding_config: dict = field(
+             default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3}
+         )
+         mlp_network_config: dict = field(
+             default_factory=lambda: {
+                 "otype": "VanillaMLP",
+                 "activation": "ReLU",
+                 "n_neurons": 16,
+                 "n_hidden_layers": 2,
+             }
+         )
+         random_aug: bool = False
+         random_aug_prob: float = 0.5
+         eval_color: Optional[Tuple[float, float, float]] = None
+
+         # multi-view diffusion
+         share_aug_bg: bool = False
+
+     cfg: Config
+
+     def configure(self) -> None:
+         self.encoding = get_encoding(3, self.cfg.dir_encoding_config)
+         self.network = get_mlp(
+             self.encoding.n_output_dims,
+             self.cfg.n_output_dims,
+             self.cfg.mlp_network_config,
+         )
+
+     def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]:
+         if not self.training and self.cfg.eval_color is not None:
+             return torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to(
+                 dirs
+             ) * torch.as_tensor(self.cfg.eval_color).to(dirs)
+         # viewdirs must be normalized before passing to this function
+         dirs = (dirs + 1.0) / 2.0  # (-1, 1) => (0, 1)
+         dirs_embd = self.encoding(dirs.view(-1, 3))
+         color = self.network(dirs_embd).view(*dirs.shape[:-1], self.cfg.n_output_dims)
+         color = get_activation(self.cfg.color_activation)(color)
+         if (
+             self.training
+             and self.cfg.random_aug
+             and random.random() < self.cfg.random_aug_prob
+         ):
+             # use random background color with probability random_aug_prob
+             n_color = 1 if self.cfg.share_aug_bg else dirs.shape[0]
+             value = random.random() < 0.5
+             color = color * 0 + (  # prevent checking for unused parameters in DDP
+                 torch.ones(n_color, 1, 1, self.cfg.n_output_dims)
+                 .to(dirs)
+                 .expand(*dirs.shape[:-1], -1)
+                 * value
+             )
+         return color
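The `color * 0 + (...)` expression above is a small but deliberate trick: on augmented steps the network output is discarded, but multiplying it by zero keeps the background MLP inside the autograd graph so DDP does not complain about unused parameters. A self-contained sketch of the same pattern (the tiny `net` below is only illustrative, not the repo's MLP):
```
import random
import torch
import torch.nn as nn

net = nn.Linear(3, 3)                     # stand-in for the background network
dirs = torch.rand(4, 8, 8, 3)             # B x H x W x 3 view directions
color = torch.sigmoid(net(dirs))

if random.random() < 0.5:                 # random background augmentation
    value = float(random.random() < 0.5)  # shared all-black or all-white background
    # color * 0 keeps `net` in the graph; the constant supplies the actual color
    color = color * 0 + torch.full_like(color, value)

color.sum().backward()                    # net.weight.grad is populated either way
```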
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/configs/gaussian_splatting.yaml ADDED
@@ -0,0 +1,96 @@
+ name: "gs-sds-generation"
+ tag: "${rmspace:${system.prompt_processor.prompt},_}"
+ exp_root_dir: "outputs"
+ seed: 0
+
+ data_type: "random-camera-datamodule"
+ data:
+   batch_size: 4
+   width: 512
+   height: 512
+   camera_distance_range: [2.5, 2.5]
+   fovy_range: [60, 70]
+   elevation_range: [-20, 90]
+   light_sample_strategy: "dreamfusion"
+   eval_camera_distance: 2.5
+   eval_fovy_deg: 70
+   rays_d_normalize: false
+
+ system_type: "gaussian-splatting-system"
+ system:
+
+   geometry_type: "gaussian-splatting"
+   geometry:
+     position_lr: [0, 0.001, 0.00002, 1000]
+     scale_lr: 0.005
+     feature_lr: 0.01
+     opacity_lr: 0.05
+     rotation_lr: 0.005
+     densification_interval: 300
+     prune_interval: 300
+     opacity_reset_interval: 50000000
+     densify_from_iter: 500
+     densify_until_iter: ${trainer.max_steps}
+     prune_from_iter: 500
+     prune_until_iter: ${trainer.max_steps}
+     densify_grad_threshold: 0.01
+     min_opac_prune: 0.005
+     split_thresh: 0.02
+     radii2d_thresh: 1000
+
+     init_num_pts: 4096
+     pc_init_radius: 0.8
+     opacity_init: 0.2
+
+   renderer_type: "diff-gaussian-rasterizer"
+   renderer:
+     debug: false
+     invert_bg_prob: 0.5
+
+   material_type: "no-material" # unused
+   material:
+     n_output_dims: 0
+
+   background_type: "solid-color-background" # unused
+
+   prompt_processor_type: "stable-diffusion-prompt-processor"
+   prompt_processor:
+     pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
+     prompt: ???
+
+   guidance_type: "stable-diffusion-guidance"
+   guidance:
+     pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
+     guidance_scale: 100.0
+     weighting_strategy: sds
+     min_step_percent: 0.02
+     max_step_percent: [1000, 0.98, 0.5, 1001]
+
+   exporter_type: "gaussian-mesh-exporter"
+
+   loggers:
+     wandb:
+       enable: false
+       project: 'threestudio'
+       name: None
+
+   loss:
+     lambda_sds: 0.1
+     lambda_position: 1.0
+     lambda_opacity: 0.0001
+     lambda_scales: 0.0001
+     lambda_tv_loss: 1.0
+     lambda_depth_tv_loss: 1.0
+
+ trainer:
+   max_steps: 5000
+   log_every_n_steps: 1
+   num_sanity_val_steps: 0
+   val_check_interval: 100
+   enable_progress_bar: true
+   precision: 32-true
+
+ checkpoint:
+   save_last: true # save at each validation time
+   save_top_k: -1
+   every_n_train_steps: ${trainer.max_steps}
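Several values in this config (`position_lr`, `max_step_percent`, and others) are schedules of the form `[start_step, start_value, end_value, end_step]` rather than scalars; threestudio resolves them per training step via its `C` utility (imported as `from threestudio.utils.misc import C` in `geometry/gaussian_base.py`). For example, `max_step_percent: [1000, 0.98, 0.5, 1001]` effectively switches from 0.98 to 0.5 at step 1000. A minimal sketch of the linear-interpolation behavior, written here as an independent illustration rather than the library's exact code:
```
def c_schedule(value, step):
    """Resolve a scalar or a [start_step, start_value, end_value, end_step] schedule."""
    if isinstance(value, (int, float)):
        return value
    start_step, start_value, end_value, end_step = value
    if step <= start_step:
        return start_value
    if step >= end_step:
        return end_value
    t = (step - start_step) / (end_step - start_step)
    return start_value + t * (end_value - start_value)

# position_lr: [0, 0.001, 0.00002, 1000] decays from 1e-3 to 2e-5 over 1000 steps
assert c_schedule([0, 0.001, 0.00002, 1000], 0) == 0.001
assert c_schedule([0, 0.001, 0.00002, 1000], 2000) == 0.00002
```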
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/configs/gaussian_splatting_background.yaml ADDED
@@ -0,0 +1,111 @@
+ name: "gs-sds-generation-background"
+ tag: "${rmspace:${system.prompt_processor.prompt},_}"
+ exp_root_dir: "outputs"
+ seed: 0
+
+ data_type: "random-camera-datamodule"
+ data:
+   batch_size: 4
+   width: 512
+   height: 512
+   camera_distance_range: [2.5, 2.5]
+   fovy_range: [60, 70]
+   elevation_range: [-20, 90]
+   light_sample_strategy: "dreamfusion"
+   eval_camera_distance: 2.5
+   eval_fovy_deg: 70
+   rays_d_normalize: false
+
+ system_type: "gaussian-splatting-system"
+ system:
+
+   geometry_type: "gaussian-splatting"
+   geometry:
+     position_lr: [0, 0.001, 0.00002, 1000]
+     scale_lr: 0.005
+     feature_lr: 0.01
+     opacity_lr: 0.05
+     rotation_lr: 0.005
+     densification_interval: 300
+     prune_interval: 300
+     opacity_reset_interval: 50000000
+     densify_from_iter: 500
+     densify_until_iter: 10000
+     prune_from_iter: 500
+     prune_until_iter: ${trainer.max_steps}
+     densify_grad_threshold: 0.01
+     min_opac_prune: 0.005
+     split_thresh: 0.02
+     radii2d_thresh: 1000
+
+     init_num_pts: 4096
+     pc_init_radius: 0.8
+     opacity_init: 0.2
+
+   renderer_type: "diff-gaussian-rasterizer-background"
+   renderer:
+     debug: false
+
+   material_type: "no-material" # unused
+   material:
+     n_output_dims: 0
+
+   background_type: "gaussian-mvdream-neural-environment-map-background"
+   background:
+     color_activation: sigmoid
+     random_aug: true
+     random_aug_prob: 0.8
+
+   prompt_processor_type: "stable-diffusion-prompt-processor"
+   prompt_processor:
+     pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
+     prompt: ???
+
+   guidance_type: "stable-diffusion-guidance"
+   guidance:
+     pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
+     guidance_scale: 100.0
+     weighting_strategy: sds
+     min_step_percent: 0.02
+     max_step_percent: [1000, 0.98, 0.5, 1001]
+
+   exporter_type: "gaussian-mesh-exporter"
+
+   loggers:
+     wandb:
+       enable: false
+       project: 'threestudio'
+       name: None
+
+   loss:
+     lambda_sds: 0.1
+     lambda_position: 1.0
+     lambda_opacity: 0.0001
+     lambda_scales: 0.0001
+     lambda_tv_loss: 1.0
+     lambda_depth_tv_loss: 1.0
+
+   optimizer:
+     name: Adam
+     args:
+       lr: 0.01
+       betas: [0.9, 0.99]
+       eps: 1.e-15
+     params:
+       background:
+         lr: 0.001
+
+
+
+ trainer:
+   max_steps: 5000
+   log_every_n_steps: 1
+   num_sanity_val_steps: 0
+   val_check_interval: 100
+   enable_progress_bar: true
+   precision: 32-true
+
+ checkpoint:
+   save_last: true # save at each validation time
+   save_top_k: -1
+   every_n_train_steps: ${trainer.max_steps}
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/configs/gaussian_splatting_mvdream.yaml ADDED
@@ -0,0 +1,131 @@
+ name: "gs-sds-mvdream"
+ tag: "${rmspace:${system.prompt_processor.prompt},_}"
+ exp_root_dir: "outputs"
+ seed: 0
+
+ data_type: "mvdream-random-multiview-camera-datamodule"
+ data:
+   batch_size: [4,4]
+   n_view: 4
+   # rendering resolution schedule (see resolution_milestones; fixed at 256x256 here)
+   width: [256, 256]
+   height: [256, 256]
+   resolution_milestones: [1000]
+   camera_distance_range: [0.8, 1.0] # relative
+   fovy_range: [15, 60]
+   elevation_range: [0, 30]
+   camera_perturb: 0.
+   center_perturb: 0.
+   up_perturb: 0.
+   n_val_views: 4
+   eval_camera_distance: 3.0
+   eval_fovy_deg: 40.
+   rays_d_normalize: false
+
+ system_type: "gaussian-splatting-mvdream-system"
+ system:
+   geometry_type: "gaussian-splatting"
+   geometry:
+     position_lr: [0, 0.0001, 0.00001, 1500]
+     scale_lr: [0, 0.01, 0.001, 1500]
+     feature_lr: [0, 0.005, 0.001, 6000]
+     opacity_lr: 0.05
+     rotation_lr: 0.001
+     pred_normal: false
+     normal_lr: 0.005
+     densification_interval: 300
+     prune_interval: 300
+     opacity_reset_interval: 100000
+     densify_from_iter: 1500
+     densify_until_iter: ${trainer.max_steps}
+     prune_from_iter: 1500
+     prune_until_iter: ${trainer.max_steps}
+     densify_grad_threshold: 0.01
+     min_opac_prune: 0.01
+     split_thresh: 0.02
+     radii2d_thresh: 1000
+
+     sphere: False
+     color_clip: [0, 0.01, 0.02, 1500, 0.5, 4000, 1.0, 7000]
+
+     init_num_pts: 4096
+     pc_init_radius: 0.5
+     opacity_init: 0.05
+     max_num: 100000
+
+   renderer_type: "diff-gaussian-rasterizer-shading"
+   renderer:
+     debug: false
+
+   material_type: "gaussian-diffuse-with-point-light-material"
+   material:
+     ambient_only_steps: 3000
+     textureless_prob: 0.0
+     ambient_light_color: [0.9, 0.9, 0.9]
+     diffuse_light_color: [0.1, 0.1, 0.1]
+     soft_shading: true
+
+   background_type: "gaussian-mvdream-neural-environment-map-background"
+   background:
+     color_activation: sigmoid
+     random_aug: true
+     share_aug_bg: true
+     random_aug_prob: 0.95
+
+   prompt_processor_type: "stable-diffusion-prompt-processor"
+   prompt_processor:
+     pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
+     prompt: ???
+     negative_prompt: "ugly, bad anatomy, blurry, pixelated obscure, unnatural colors, poor lighting, dull, and unclear, cropped, lowres, low quality, artifacts, duplicate, morbid, mutilated, poorly drawn face, deformed, dehydrated, bad proportions"
+     front_threshold: 30.
+     back_threshold: 30.
+
+   guidance_type: "mvdream-multiview-diffusion-guidance"
+   guidance:
+     model_name: "sd-v2.1-base-4view"
+     ckpt_path: null # path to a pre-downloaded checkpoint file (null for loading from URL)
+     guidance_scale: 50.0
+     min_step_percent: [0, 0.98, 0.02, 7000] # (start_iter, start_val, end_val, end_iter)
+     max_step_percent: [0, 0.98, 0.50, 7000]
+     recon_loss: true
+     recon_std_rescale: 0.5
+
+   exporter_type: "gaussian-mesh-exporter"
+
+   loggers:
+     wandb:
+       enable: false
+       project: 'threestudio'
+       name: None
+
+   loss:
+     lambda_sds: 0.1
+     lambda_position: 1.0
+     lambda_opacity: 0.0001
+     lambda_scales: 0.0001
+     lambda_sparsity: 1.0
+     lambda_tv_loss: 0.0
+     lambda_depth_tv_loss: 1.0
+
+   optimizer:
+     name: Adam
+     args:
+       lr: 0.01
+       betas: [0.9, 0.99]
+       eps: 1.e-6
+     params:
+       background:
+         lr: 0.0001
+
+ trainer:
+   max_steps: 10000
+   log_every_n_steps: 1
+   num_sanity_val_steps: 0
+   val_check_interval: 100
+   enable_progress_bar: true
+   precision: 32-true
+
+ checkpoint:
+   save_last: true # save at each validation time
+   save_top_k: -1
+   every_n_train_steps: ${trainer.max_steps}
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/configs/gaussian_splatting_shading.yaml ADDED
@@ -0,0 +1,115 @@
+ name: "gs-sds-generation-shading"
+ tag: "${rmspace:${system.prompt_processor.prompt},_}"
+ exp_root_dir: "outputs"
+ seed: 0
+
+ data_type: "random-camera-datamodule"
+ data:
+   batch_size: 4
+   width: 512
+   height: 512
+   camera_distance_range: [2.5, 2.5]
+   fovy_range: [60, 70]
+   elevation_range: [-20, 90]
+   light_sample_strategy: "dreamfusion"
+   eval_camera_distance: 2.5
+   eval_fovy_deg: 70
+   rays_d_normalize: false
+
+ system_type: "gaussian-splatting-system"
+ system:
+
+   geometry_type: "gaussian-splatting"
+   geometry:
+     position_lr: [0, 0.001, 0.00002, 1000]
+     scale_lr: 0.005
+     feature_lr: 0.01
+     opacity_lr: 0.05
+     rotation_lr: 0.005
+     densification_interval: 300
+     prune_interval: 300
+     opacity_reset_interval: 50000000
+     densify_from_iter: 500
+     densify_until_iter: ${trainer.max_steps}
+     prune_from_iter: 500
+     prune_until_iter: ${trainer.max_steps}
+     densify_grad_threshold: 0.01
+     min_opac_prune: 0.005
+     split_thresh: 0.02
+     radii2d_thresh: 1000
+
+     init_num_pts: 4096
+     pc_init_radius: 0.8
+     opacity_init: 0.2
+
+   renderer_type: "diff-gaussian-rasterizer-shading"
+   renderer:
+     debug: false
+
+   material_type: "gaussian-diffuse-with-point-light-material"
+   material:
+     ambient_only_steps: 2000
+     textureless_prob: 0.0
+     ambient_light_color: [1.0, 1.0, 1.0]
+     diffuse_light_color: [0.0, 0.0, 0.0]
+     soft_shading: true
+
+   background_type: "gaussian-mvdream-neural-environment-map-background"
+   background:
+     color_activation: sigmoid
+     random_aug: true
+     random_aug_prob: 0.8
+
+   prompt_processor_type: "stable-diffusion-prompt-processor"
+   prompt_processor:
+     pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
+     prompt: ???
+
+   guidance_type: "stable-diffusion-guidance"
+   guidance:
+     pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
+     guidance_scale: 100.0
+     weighting_strategy: sds
+     min_step_percent: 0.02
+     max_step_percent: [2000, 0.98, 0.5, 2001]
+
+   exporter_type: "gaussian-mesh-exporter"
+
+   loggers:
+     wandb:
+       enable: false
+       project: 'threestudio'
+       name: None
+
+   loss:
+     lambda_sds: 0.1
+     lambda_position: 1.0
+     lambda_opacity: 0.0001
+     lambda_scales: 0.0001
+     lambda_tv_loss: 1.0
+     lambda_depth_tv_loss: 1.0
+
+   optimizer:
+     name: Adam
+     args:
+       lr: 0.01
+       betas: [0.9, 0.99]
+       eps: 1.e-15
+     params:
+       background:
+         lr: 0.001
+
+
+
+ trainer:
+   max_steps: 5000
+   log_every_n_steps: 1
+   num_sanity_val_steps: 0
+   val_check_interval: 100
+   enable_progress_bar: true
+   precision: 32-true
+
+ checkpoint:
+   save_last: true # save at each validation time
+   save_top_k: -1
+   every_n_train_steps: ${trainer.max_steps}
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/configs/gaussian_splatting_zero123.yaml ADDED
@@ -0,0 +1,144 @@
+ name: "gs-zero123-sai"
+ tag: "${data.random_camera.height}_${rmspace:${basename:${data.image_path}},_}"
+ exp_root_dir: "outputs"
+ seed: 0
+
+ data_type: "single-image-datamodule"
+ data: # threestudio/data/image.py -> SingleImageDataModuleConfig
+   image_path: ./load/images/hamburger_rgba.png
+   height: [128, 256, 512]
+   width: [128, 256, 512]
+   resolution_milestones: [200, 300]
+   default_elevation_deg: 5.0
+   default_azimuth_deg: 0.0
+   default_camera_distance: 3.8
+   default_fovy_deg: 20.0
+   requires_depth: ${cmaxgt0orcmaxgt0:${system.loss.lambda_depth},${system.loss.lambda_depth_rel}}
+   requires_normal: ${cmaxgt0:${system.loss.lambda_normal}}
+   random_camera: # threestudio/data/uncond.py -> RandomCameraDataModuleConfig
+     height: 256
+     width: 256
+     batch_size: 4
+     eval_height: 512
+     eval_width: 512
+     eval_batch_size: 1
+     elevation_range: [-10, 80]
+     azimuth_range: [-180, 180]
+     camera_distance_range: [3.8, 3.8]
+     fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
+     progressive_until: 0
+     camera_perturb: 0.0
+     center_perturb: 0.0
+     up_perturb: 0.0
+     light_position_perturb: 1.0
+     light_distance_range: [7.5, 10.0]
+     eval_elevation_deg: ${data.default_elevation_deg}
+     eval_camera_distance: ${data.default_camera_distance}
+     eval_fovy_deg: ${data.default_fovy_deg}
+     light_sample_strategy: "dreamfusion"
+     batch_uniform_azimuth: False
+     n_val_views: 30
+     n_test_views: 120
+
+ system_type: "gaussian-splatting-zero123-system"
+ system:
+   geometry_type: "gaussian-splatting"
+   geometry:
+     position_lr: [0, 0.001, 0.00002, 1000]
+     scale_lr: [0, 0.01, 0.001, 1000]
+     feature_lr: 0.01
+     opacity_lr: 0.05
+     rotation_lr: 0.001
+     densification_interval: 100
+     prune_interval: 100
+     opacity_reset_interval: 100000
+     densify_from_iter: 0
+     densify_until_iter: ${trainer.max_steps}
+     prune_from_iter: 0
+     prune_until_iter: ${trainer.max_steps}
+     densify_grad_threshold: 0.01
+     min_opac_prune: 0.005
+     split_thresh: 0.02
+     radii2d_thresh: 1000
+
+     sphere: False
+
+     init_num_pts: 4096
+     pc_init_radius: 0.5
+     opacity_init: 0.05
+     max_num: 500000
+
+   exporter_type: "gaussian-mesh-exporter"
+
+   renderer_type: "diff-gaussian-rasterizer-advanced"
+   renderer:
+     debug: false
+     invert_bg_prob: 1.0
+
+   material_type: "no-material" # unused
+   material:
+     n_output_dims: 0
+
+   background_type: "solid-color-background" # unused
+
+   prompt_processor_type: "dummy-prompt-processor" # Zero123 doesn't use prompts
+   prompt_processor:
+     pretrained_model_name_or_path: ""
+     prompt: ""
+
+   guidance_type: "stable-zero123-guidance"
+   guidance:
+     pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
+     pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
+     vram_O: ${not:${gt0:${system.freq.guidance_eval}}}
+     cond_image_path: ${data.image_path}
+     cond_elevation_deg: ${data.default_elevation_deg}
+     cond_azimuth_deg: ${data.default_azimuth_deg}
+     cond_camera_distance: ${data.default_camera_distance}
+     guidance_scale: 3.0
+     min_step_percent: [50, 0.7, 0.3, 200] # (start_iter, start_val, end_val, end_iter)
+     max_step_percent: [50, 0.98, 0.8, 200]
+
+   freq:
+     ref_only_steps: 0
+     guidance_eval: 0
+
+   loggers:
+     wandb:
+       enable: false
+       project: "threestudio"
+       name: None
+
+   loss:
+     lambda_sds: 0.1
+     lambda_rgb: [100, 500., 1000., 400]
+     lambda_mask: 50.
+     lambda_depth: 0. # 0.05
+     lambda_depth_rel: 0. # [0, 0, 0.05, 100]
+     lambda_normal: 0. # [0, 0, 0.05, 100]
+     lambda_normal_smooth: 0.
+     lambda_3d_normal_smooth: 0.
+
+   optimizer:
+     name: Adam
+     args:
+       lr: 0.01
+       betas: [0.9, 0.99]
+       eps: 1.e-8
+     params:
+       background:
+         lr: 0.001
+
+
+ trainer:
+   max_steps: 5000
+   log_every_n_steps: 1
+   num_sanity_val_steps: 0
+   val_check_interval: 100
+   enable_progress_bar: true
+   precision: 32
+
+ checkpoint:
+   save_last: true # save at each validation time
+   save_top_k: -1
+   every_n_train_steps: 100 # ${trainer.max_steps}
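Expressions such as `${gt0:${system.freq.guidance_eval}}`, `${not:...}`, and the `${rmspace:...}` used in the `tag:` fields are custom OmegaConf resolvers that threestudio registers at startup. The library's exact implementations are not reproduced here; the sketch below only illustrates how resolvers with these names could be registered (the lambda bodies are assumptions):
```
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("gt0", lambda x: x > 0)
OmegaConf.register_new_resolver("not", lambda x: not x)
OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub))

cfg = OmegaConf.create(
    {"guidance_eval": 0, "vram_O": "${not:${gt0:${guidance_eval}}}"}
)
print(cfg.vram_O)  # True: guidance_eval == 0, so the VRAM-saving path is enabled
```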
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/configs/scene_lang.yaml ADDED
@@ -0,0 +1,138 @@
+ name: "gs-sds-generation"
+ tag: "${rmspace:${system.prompt_processor.prompt},_}"
+ exp_root_dir: "outputs"
+ seed: 0
+ # resume: "/mnt/hdd1/wufan/projects/3DitScene/outputs/gs-sds-generation/3DitScene@20250204-120500/ckpts/last.ckpt"
+
+ data_type: "random-camera-datamodule"
+ data:
+   rotate_traj: false # WU
+   random_traj: false # WU
+   batch_size: 1
+   width: 512
+   height: 512
+   camera_distance_range: [2.5, 2.5]
+   fovy_range: [60, 60]
+   elevation_range: [0, 0] # The vertical angle of the camera relative to the object, in degrees.
+   light_sample_strategy: "dreamfusion"
+   eval_camera_distance: 2.5
+   eval_fovy_deg: 60 # The field of view (FOV) in the vertical direction, in degrees.
+   eval_elevation_deg: 0
+   rays_d_normalize: false
+   center_perturb: 0
+   up_perturb: 0
+   camera_perturb: 0
+   azimuth_range: [-15, 15] # The range of horizontal rotation angles during training
+   val_azimuth_range: [-15, 15] # The range of horizontal rotation angles during validation
+   insert_zero: true
+
+ system_type: "scene-lang-system"
+ system:
+   encoder_hidden_dims: [256, 128, 32, 3]
+   decoder_hidden_dims: [32, 128, 256, 512]
+   xyz_noise_ratio: [1000, 0.0, 0.0, 3000]
+   drop_ooi_ratio: 0.3
+   crop_with_lang: true
+   densify: false
+
+   geometry_type: "gaussian-splatting"
+   geometry:
+     ooi_bbox: [360,370,730,590]
+     geometry_convert_from: depth:assets/anime.png
+     position_lr: [0, 0.001, 0.00002, 1000]
+     scaling_lr: 0.05
+     feature_lr: 0.01
+     opacity_lr: 0.05
+     rotation_lr: 0.005
+     lang_lr: 0.0003
+     densification_interval: 300
+     prune_interval: 300
+     opacity_reset_interval: 50000000
+     densify_from_iter: 500
+     densify_until_iter: ${trainer.max_steps}
+     prune_from_iter: 500
+     prune_until_iter: ${trainer.max_steps}
+     densify_grad_threshold: 0.01
+     min_opac_prune: 0.005
+     split_thresh: 0.02
+     radii2d_thresh: 1000
+
+     init_num_pts: 4096
+     pc_init_radius: 0.8
+     opacity_init: 0.2
+
+     empty_prompt: ${system.empty_prompt}
+     prompt: ${system.prompt_processor.prompt}
+     max_scaling: 0.2
+
+   renderer_type: "diff-gaussian-rasterizer"
+   renderer:
+     debug: false
+     invert_bg_prob: 0.5
+
+   material_type: "no-material" # unused
+   material:
+     n_output_dims: 0
+
+   background_type: "solid-color-background" # unused
+
+   prompt_processor_type: "stable-diffusion-prompt-processor"
+   prompt_processor:
+     pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
+     # pretrained_model_name_or_path: "/mnt/petrelfs/zhangqihang/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1-base"
+     prompt: ???
+     empty_prompt: "empty"
+
+   guidance_type: "stable-diffusion-guidance"
+   guidance:
+     pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
+     # pretrained_model_name_or_path: "/mnt/petrelfs/zhangqihang/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1-base"
+     guidance_scale: 5.0
+     weighting_strategy: sds
+     min_step_percent: 0.02
+     max_step_percent: [0, 0.5, 0.1, 1000]
+     csd: false
+
+   # guidance_type: "stable-diffusion-vsd-guidance"
+   # guidance:
+   #   pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
+   #   pretrained_model_name_or_path_lora: "stabilityai/stable-diffusion-2-1"
+   #   guidance_scale: 7.5
+   #   min_step_percent: 0.02
+
+   exporter_type: "gaussian-mesh-exporter"
+
+   sam_clip:
+     use_mobile_sam: True
+
+   loggers:
+     wandb:
+       enable: false
+       project: '3ditscene'
+       name: "${tag}"
+
+   loss:
+     lambda_sds: 0.01
+     lambda_ref: 1000
+     lambda_depth: 0.0
+     lambda_position: 1.0
+     lambda_opacity: 0.0001
+     lambda_scales: 0.0001
+     lambda_tv_loss: 1.0
+     lambda_depth_tv_loss: 1.0
+     lambda_scaling: 0.0
+
+ trainer:
+   max_steps: 1500
+   log_every_n_steps: 1
+   num_sanity_val_steps: 110
+   val_check_interval: 500
+   enable_progress_bar: true
+   precision: 32-true
+
+ checkpoint:
+   save_last: true # save at each validation time
+   save_top_k: -1
+   every_n_train_steps: 1000
+   save_weights_only: true
+   # every_n_train_steps: ${trainer.max_steps}
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/geometry/exporter.py ADDED
@@ -0,0 +1,44 @@
+ from dataclasses import dataclass, field
+
+ import cv2
+ import numpy as np
+ import threestudio
+ import torch
+ from threestudio.models.background.base import BaseBackground
+ from threestudio.models.exporters.base import Exporter, ExporterOutput
+ from threestudio.models.geometry.base import BaseGeometry
+ from threestudio.models.materials.base import BaseMaterial
+ from threestudio.models.mesh import Mesh
+ from threestudio.utils.rasterize import NVDiffRasterizerContext
+ from threestudio.utils.typing import *
+
+
+ @threestudio.register("gaussian-mesh-exporter")
+ class MeshExporter(Exporter):
+     @dataclass
+     class Config(Exporter.Config):
+         fmt: str = "obj"
+         save_name: str = "model"
+         save_video: bool = True
+
+     cfg: Config
+
+     def configure(
+         self,
+         geometry: BaseGeometry,
+         material: BaseMaterial,
+         background: BaseBackground,
+     ) -> None:
+         super().configure(geometry, material, background)
+
+     def __call__(self) -> List[ExporterOutput]:
+         mesh: Mesh = self.geometry.extract_mesh()
+         return self.export_obj(mesh)
+
+     def export_obj(self, mesh: Mesh) -> List[ExporterOutput]:
+         params = {"mesh": mesh}
+         return [
+             ExporterOutput(
+                 save_name=f"{self.cfg.save_name}.obj", save_type="obj", params=params
+             )
+         ]
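The decorator registers this exporter under the name the configs select with `exporter_type: "gaussian-mesh-exporter"`. A rough sketch of how a registered component is looked up, assuming the construction pattern threestudio's base system uses (a config dict plus `geometry`/`material`/`background` keyword arguments, which is an assumption here; `geometry`, `material`, and `background` below are placeholders for objects taken from a trained system):
```
import threestudio

exporter_cls = threestudio.find("gaussian-mesh-exporter")   # registry lookup
exporter = exporter_cls(
    {"save_name": "model"},                                 # Config fields
    geometry=geometry, material=material, background=background,
)
outputs = exporter()  # [ExporterOutput(save_name="model.obj", save_type="obj", ...)]
```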
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/geometry/gaussian_base.py ADDED
@@ -0,0 +1,1469 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+ import math
12
+ import os
13
+ import random
14
+ import sys
15
+ import argparse
16
+ from dataclasses import dataclass, field
17
+ from datetime import datetime
18
+ from typing import NamedTuple
19
+
20
+ import numpy as np
21
+ import cv2
22
+ from PIL import Image
23
+ import threestudio
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ import torchvision
28
+ from transformers import pipeline
29
+ from plyfile import PlyData, PlyElement
30
+ from simple_knn._C import distCUDA2
31
+ import diffusers
32
+ from diffusers import StableDiffusionInpaintPipeline, AutoPipelineForInpainting
33
+ from threestudio.models.geometry.base import BaseGeometry
34
+ from threestudio.utils.misc import C
35
+ from threestudio.utils.typing import *
36
+ from segment_anything import sam_model_registry, SamPredictor
37
+ import matplotlib.pyplot as plt
38
+
39
+ from .gaussian_io import GaussianIO
40
+ import imageio
41
+
42
+ from scipy.spatial.transform import Rotation as R
43
+
44
+ REORDER_MTX = torch.tensor([
45
+ [0,0,0,1],
46
+ [1,0,0,0],
47
+ [0,1,0,0],
48
+ [0,0,1,0]
49
+ ]).cuda().float()
50
+
51
+ def build_rotation(r):
52
+ norm = torch.sqrt(
53
+ r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
54
+ )
55
+
56
+ q = r / norm[:, None]
57
+
58
+ R = torch.zeros((q.size(0), 3, 3), device="cuda")
59
+ r = q[:, 0]
60
+ x = q[:, 1]
61
+ y = q[:, 2]
62
+ z = q[:, 3]
63
+
64
+ R[:, 0, 0] = 1 - 2 * (y * y + z * z)
65
+ R[:, 0, 1] = 2 * (x * y - r * z)
66
+ R[:, 0, 2] = 2 * (x * z + r * y)
67
+ R[:, 1, 0] = 2 * (x * y + r * z)
68
+ R[:, 1, 1] = 1 - 2 * (x * x + z * z)
69
+ R[:, 1, 2] = 2 * (y * z - r * x)
70
+ R[:, 2, 0] = 2 * (x * z - r * y)
71
+ R[:, 2, 1] = 2 * (y * z + r * x)
72
+ R[:, 2, 2] = 1 - 2 * (x * x + y * y)
73
+ return R
74
+
75
+ def rotation_matrix(angle_x, angle_y, angle_z):
76
+ # Convert angles to radians
77
+ rad_x = torch.deg2rad(torch.tensor(angle_x))
78
+ rad_y = torch.deg2rad(torch.tensor(angle_y))
79
+ rad_z = torch.deg2rad(torch.tensor(angle_z))
80
+
81
+ # Compute sine and cosine of the angles
82
+ cos_x = torch.cos(rad_x)
83
+ sin_x = torch.sin(rad_x)
84
+ cos_y = torch.cos(rad_y)
85
+ sin_y = torch.sin(rad_y)
86
+ cos_z = torch.cos(rad_z)
87
+ sin_z = torch.sin(rad_z)
88
+
89
+ # Construct the rotation matrix
90
+ Rx = torch.tensor([[1, 0, 0],
91
+ [0, cos_x, -sin_x],
92
+ [0, sin_x, cos_x]])
93
+
94
+ Ry = torch.tensor([[cos_y, 0, sin_y],
95
+ [0, 1, 0],
96
+ [-sin_y, 0, cos_y]])
97
+
98
+ Rz = torch.tensor([[cos_z, -sin_z, 0],
99
+ [sin_z, cos_z, 0],
100
+ [0, 0, 1]])
101
+
102
+ # Combine the rotation matrices
103
+ rotation_matrix = Rz @ Ry @ Rx
104
+
105
+ return rotation_matrix
106
+
107
+ # from scipy.spatial import KDTree
108
+ #
109
+ # def distCUDA2(points):
110
+ # points_np = points.detach().cpu().float().numpy()
111
+ # dists, inds = KDTree(points_np).query(points_np, k=4)
112
+ # meanDists = (dists[:, 1:] ** 2).mean(1)
113
+ #
114
+ # return torch.tensor(meanDists, dtype=points.dtype, device=points.device)
115
+
116
+ sys.path.append('./utils/GeoWizard/geowizard')
117
+ from models.geowizard_pipeline import DepthNormalEstimationPipeline
118
+
119
+ C0 = 0.28209479177387814
120
+
121
+ def propagate(canvas):
122
+ H, W = canvas.shape
123
+ dx = [0, 1, 0, -1]
124
+ dy = [1, 0, -1, 0]
125
+ count = np.zeros_like(canvas)
126
+
127
+ while 1:
128
+ curr_mask = canvas > 0
129
+ if sum(sum(curr_mask)) == H * W:
130
+ break
131
+ expand_mask = (cv2.blur(curr_mask.astype(np.float32), (3, 3)) > 0)
132
+ x, y = np.where(np.logical_and(expand_mask, ~curr_mask))
133
+ old_canvas = canvas.copy()
134
+
135
+ for xx, yy in zip(x, y):
136
+ for i in range(4):
137
+ ref_x = xx + dx[i]
138
+ ref_y = yy + dy[i]
139
+ if 0<=ref_x<H and 0<=ref_y<W and old_canvas[ref_x, ref_y] != 0:
140
+ canvas[xx, yy] = old_canvas[ref_x, ref_y]
141
+ count[xx, yy] = count[ref_x, ref_y] + 1
142
+
143
+ weight = (count.max() - count) / count.max()
144
+ return canvas * weight
145
+
146
+ def save_pc(save_file, pts, color):
147
+ '''
148
+ pts: N, 3
149
+ color: N, 3
150
+ '''
151
+ if color.dtype == np.dtype('float64'):
152
+ color = (color * 255).astype(np.uint8)
153
+ with open(save_file, 'w') as f:
154
+ f.writelines((
155
+ "ply\n",
156
+ "format ascii 1.0\n",
157
+ "element vertex {}\n".format(pts.shape[0]),
158
+ "property float x\n",
159
+ "property float y\n",
160
+ "property float z\n",
161
+ "property uchar red\n",
162
+ "property uchar green\n",
163
+ "property uchar blue\n",
164
+ "end_header\n"))
165
+ for i in range(pts.shape[0]):
166
+ point = "%f %f %f %d %d %d\n" % (pts[i, 0], pts[i, 1], pts[i, 2], color[i, 0], color[i, 1], color[i, 2])
167
+ f.writelines(point)
168
+ threestudio.info(f"Saved point cloud to {save_file}.")
169
+
170
+
171
+ def RGB2SH(rgb):
172
+ return (rgb - 0.5) / C0
173
+
174
+
175
+ def SH2RGB(sh):
176
+ return sh * C0 + 0.5
177
+
178
+
179
+ def inverse_sigmoid(x):
180
+ return torch.log(x / (1 - x))
181
+
182
+
183
+ def strip_lowerdiag(L):
184
+ uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
185
+
186
+ uncertainty[:, 0] = L[:, 0, 0]
187
+ uncertainty[:, 1] = L[:, 0, 1]
188
+ uncertainty[:, 2] = L[:, 0, 2]
189
+ uncertainty[:, 3] = L[:, 1, 1]
190
+ uncertainty[:, 4] = L[:, 1, 2]
191
+ uncertainty[:, 5] = L[:, 2, 2]
192
+ return uncertainty
193
+
194
+
195
+ def strip_symmetric(sym):
196
+ return strip_lowerdiag(sym)
197
+
198
+
199
+ def gaussian_3d_coeff(xyzs, covs):
200
+ # xyzs: [N, 3]
201
+ # covs: [N, 6]
202
+ x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2]
203
+ a, b, c, d, e, f = (
204
+ covs[:, 0],
205
+ covs[:, 1],
206
+ covs[:, 2],
207
+ covs[:, 3],
208
+ covs[:, 4],
209
+ covs[:, 5],
210
+ )
211
+
212
+ # eps must be small enough !!!
213
+ inv_det = 1 / (
214
+ a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24
215
+ )
216
+ inv_a = (d * f - e**2) * inv_det
217
+ inv_b = (e * c - b * f) * inv_det
218
+ inv_c = (e * b - c * d) * inv_det
219
+ inv_d = (a * f - c**2) * inv_det
220
+ inv_e = (b * c - e * a) * inv_det
221
+ inv_f = (a * d - b**2) * inv_det
222
+
223
+ power = (
224
+ -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f)
225
+ - x * y * inv_b
226
+ - x * z * inv_c
227
+ - y * z * inv_e
228
+ )
229
+
230
+ power[power > 0] = -1e10 # abnormal values... make weights 0
231
+
232
+ return torch.exp(power)
233
+
234
+
235
+ def build_rotation(r):
236
+ norm = torch.sqrt(
237
+ r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
238
+ )
239
+
240
+ q = r / norm[:, None]
241
+
242
+ R = torch.zeros((q.size(0), 3, 3), device="cuda")
243
+
244
+ r = q[:, 0]
245
+ x = q[:, 1]
246
+ y = q[:, 2]
247
+ z = q[:, 3]
248
+
249
+ R[:, 0, 0] = 1 - 2 * (y * y + z * z)
250
+ R[:, 0, 1] = 2 * (x * y - r * z)
251
+ R[:, 0, 2] = 2 * (x * z + r * y)
252
+ R[:, 1, 0] = 2 * (x * y + r * z)
253
+ R[:, 1, 1] = 1 - 2 * (x * x + z * z)
254
+ R[:, 1, 2] = 2 * (y * z - r * x)
255
+ R[:, 2, 0] = 2 * (x * z - r * y)
256
+ R[:, 2, 1] = 2 * (y * z + r * x)
257
+ R[:, 2, 2] = 1 - 2 * (x * x + y * y)
258
+ return R
259
+
260
+
261
+ def build_scaling_rotation(s, r):
262
+ L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
263
+ R = build_rotation(r)
264
+
265
+ L[:, 0, 0] = s[:, 0]
266
+ L[:, 1, 1] = s[:, 1]
267
+ L[:, 2, 2] = s[:, 2]
268
+
269
+ L = R @ L
270
+ return L
271
+
272
+
273
+ def safe_state(silent):
274
+ old_f = sys.stdout
275
+
276
+ class F:
277
+ def __init__(self, silent):
278
+ self.silent = silent
279
+
280
+ def write(self, x):
281
+ if not self.silent:
282
+ if x.endswith("\n"):
283
+ old_f.write(
284
+ x.replace(
285
+ "\n",
286
+ " [{}]\n".format(
287
+ str(datetime.now().strftime("%d/%m %H:%M:%S"))
288
+ ),
289
+ )
290
+ )
291
+ else:
292
+ old_f.write(x)
293
+
294
+ def flush(self):
295
+ old_f.flush()
296
+
297
+ sys.stdout = F(silent)
298
+
299
+ random.seed(0)
300
+ np.random.seed(0)
301
+ torch.manual_seed(0)
302
+ torch.cuda.set_device(torch.device("cuda:0"))
303
+
304
+
305
+ class BasicPointCloud(NamedTuple):
306
+ points: np.array
307
+ colors: np.array
308
+ normals: np.array
309
+
310
+
311
+ class Camera(NamedTuple):
312
+ FoVx: torch.Tensor
313
+ FoVy: torch.Tensor
314
+ camera_center: torch.Tensor
315
+ image_width: int
316
+ image_height: int
317
+ world_view_transform: torch.Tensor
318
+ full_proj_transform: torch.Tensor
319
+
320
+ def fill_mask(mask):
321
+ mask = np.array(mask)
322
+ canvas = np.zeros_like(mask)
323
+ H, W = mask.shape
324
+ for i in range(H):
325
+ for p in range(0, W):
326
+ if mask[i, p]:
327
+ canvas[i, p] = 1
328
+ else:
329
+ break
330
+ for p in range(W-1, 0, -1):
331
+ if mask[i, p]:
332
+ canvas[i, p] = 1
333
+ else:
334
+ break
335
+
336
+ for i in range(W):
337
+ for p in range(0, H):
338
+ if mask[p, i]:
339
+ canvas[p, i] = 1
340
+ else:
341
+ break
342
+ for p in range(H-1, 0, -1):
343
+ if mask[p, i]:
344
+ canvas[p, i] = 1
345
+ else:
346
+ break
347
+ mask = np.logical_and(mask, canvas)
348
+ return Image.fromarray(mask)
349
+
350
+ def parse_wh(wh):
351
+ try:
352
+ W, H = wh
353
+ except:
354
+ W = H = wh
355
+ return W, H
356
+
357
+ @threestudio.register("gaussian-splatting")
358
+ class GaussianBaseModel(BaseGeometry, GaussianIO):
359
+ @dataclass
360
+ class Config(BaseGeometry.Config):
361
+ max_num: int = 500000
362
+ sh_degree: int = 0
363
+ position_lr: Any = 0.001
364
+ # scale_lr: Any = 0.003
365
+ feature_lr: Any = 0.01
366
+ opacity_lr: Any = 0.05
367
+ scaling_lr: Any = 0.005
368
+ rotation_lr: Any = 0.005
369
+ pred_normal: bool = False
370
+ normal_lr: Any = 0.001
371
+ lang_lr: float = 0.005
372
+
373
+ densification_interval: int = 50
374
+ prune_interval: int = 50
375
+ opacity_reset_interval: int = 100000
376
+ densify_from_iter: int = 100
377
+ prune_from_iter: int = 100
378
+ densify_until_iter: int = 2000
379
+ prune_until_iter: int = 2000
380
+ densify_grad_threshold: Any = 0.01
381
+ min_opac_prune: Any = 0.005
382
+ split_thresh: Any = 0.02
383
+ radii2d_thresh: Any = 1000
384
+
385
+ sphere: bool = False
386
+ prune_big_points: bool = False
387
+ color_clip: Any = 2.0
388
+
389
+ geometry_convert_from: str = ""
390
+ load_ply_only_vertex: bool = False
391
+ init_num_pts: int = 100
392
+ pc_init_radius: float = 0.8
393
+ opacity_init: float = 0.1
394
+
395
+ img_resolution: Any = 512
396
+
397
+ shap_e_guidance_config: dict = field(default_factory=dict)
398
+
399
+ max_scaling: float = 100
400
+ sam_ckpt_path: str = "ckpts/sam_vit_h_4b8939.pth"
401
+ ooi_bbox: Any = None
402
+
403
+ prompt: Any = None
404
+ empty_prompt: Any = None
405
+ lang_beta_1: float = 0.9
406
+ lang_beta_2: float = 0.999
407
+
408
+ inference_only: bool = False
409
+ pc_max_resolution: int = 512
410
+
411
+ use_sdxl_for_inpaint: bool = False
412
+
413
+ cfg: Config
414
+
415
+ def setup_functions(self):
416
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
417
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
418
+ actual_covariance = L @ L.transpose(1, 2)
419
+ symm = strip_symmetric(actual_covariance)
420
+ return symm
421
+
422
+ self.scaling_activation = torch.exp
423
+ self.scaling_inverse_activation = torch.log
424
+
425
+ self.covariance_activation = build_covariance_from_scaling_rotation
426
+
427
+ self.opacity_activation = torch.sigmoid
428
+ self.inverse_opacity_activation = inverse_sigmoid
429
+
430
+ self.rotation_activation = torch.nn.functional.normalize
431
+ self.color_clip = C(self.cfg.color_clip, 0, 0)
432
+
433
+ self.fixed_xyz = None
434
+ self.fixed_rot = None
435
+
436
+ if not self.cfg.inference_only:
437
+ sam = sam_model_registry["vit_h"](checkpoint=self.cfg.sam_ckpt_path).to('cuda')
438
+ self.predictor = SamPredictor(sam)
439
+
440
+ def project_pc(self, c2w, H=256, W=None):
441
+ if W is None:
442
+ W = H
443
+ B = c2w.shape[0]
444
+
445
+ mask = torch.zeros([B, H, W], device='cuda')
446
+ depth_canvas = torch.zeros([B, H, W], device='cuda')
447
+
448
+ # for pc in [self.bg_point_cloud, self.point_cloud]:
449
+ pc_cam = torch.einsum('bxy,ny->bnx', torch.linalg.inv(c2w), self.point_cloud)
450
+ depth = -1 * pc_cam[..., 2].view(pc_cam.shape[0], -1)
451
+ pc_cam = (pc_cam / pc_cam[..., 2:3])[..., :3]
452
+ pc_2d = torch.einsum('xy,bny->bnx', self.proj_mtx, pc_cam).clamp(0, 1)
453
+ pc_2d[..., 0] = pc_2d[..., 0] * (W-1)
454
+ pc_2d[..., 1] = pc_2d[..., 1] * (H-1)
455
+ pc_2d = pc_2d.long()
456
+ for i in range(pc_2d.shape[0]):
457
+ x = (W - pc_2d[i, :, 0]).clamp(0, W-1)
458
+ y = (pc_2d[i, :, 1]).clamp(0, H-1)
459
+ unique_id = x * H + y
460
+ map_2d = np.zeros((W+1)*(H+1)) + 1e8
461
+ np.minimum.at(map_2d, unique_id.cpu(), depth[i].cpu())
462
+ map_2d[map_2d==1e8] = 0
463
+ positive_unique_id = np.where(map_2d>0)[0]
464
+ x, y = positive_unique_id // H, positive_unique_id % H
465
+ mask[i, y, x] = 1.0
466
+ depth_canvas[i, y, x] = torch.tensor(map_2d[positive_unique_id], device='cuda', dtype=torch.float)
467
+ # depth_canvas[i, y, x] = depth[i]
468
+
469
+ # pc_cam = torch.einsum('bxy,hwy->bhwx', torch.linalg.inv(c2w), self.point_cloud)
470
+ # depth = -1 * pc_cam[..., 2].view(pc_cam.shape[0], -1)
471
+ # pc_cam = (pc_cam / pc_cam[..., 2:3])[..., :3]
472
+ # pc_2d = torch.einsum('xy,bhwy->bhwx', self.proj_mtx, pc_cam).clamp(0, 1)
473
+ # pc_2d[..., 0] = pc_2d[..., 0] * (W-1)
474
+ # pc_2d[..., 1] = pc_2d[..., 1] * (H-1)
475
+ # pc_2d = (pc_2d.long()).view(pc_2d.shape[0], -1, pc_2d.shape[-1])
476
+
477
+
478
+ # mask = self.blur_kernel(mask) > 0
479
+ mask = torchvision.transforms.functional.gaussian_blur(mask, 3) > 0
480
+ # mask = mask > 0
481
+ return mask, depth_canvas
482
+
483
+ def img2pc_inpaint(self, img, c2w=None, gt_depth=None, mask=None, proj_func=None):
484
+ W, H = parse_wh(self.cfg.img_resolution)
485
+ if max(W, H) > self.cfg.pc_max_resolution:
486
+ W, H = int(W / max(W, H) * self.cfg.pc_max_resolution), int(H / max(W, H) * self.cfg.pc_max_resolution)
487
+
488
+ with torch.no_grad():
489
+ self.geowizard_pipe.to('cuda')
490
+ depth = self.geowizard_pipe(
491
+ img,
492
+ denoising_steps = 25,
493
+ ensemble_size = 3,
494
+ processing_res = 768,
495
+ match_input_res = False,
496
+ domain = 'outdoor',
497
+ color_map = 'Spectral',
498
+ gt_depth = gt_depth, mask = mask,
499
+ show_progress_bar = True)['depth_np']
500
+ self.geowizard_pipe.to('cpu')
501
+ ret_depth = depth.copy()
502
+ depth = torch.from_numpy(depth)[None]
503
+ depth = torch.nn.functional.interpolate(depth[None], size=(H, W), mode='bilinear', align_corners=True).squeeze()
504
+
505
+ depth = depth.cpu().numpy()
506
+ if proj_func is None:
507
+ depth = depth * 20 + 5
508
+ else:
509
+ depth = proj_func(depth)
510
+
511
+ depth = depth * -1
512
+ x, y = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
513
+ x = x / float(W-1)
514
+ y = y / float(H-1)
515
+ xyz = np.stack((x, y, np.ones_like(x)), 0).transpose(1, 2, 0)
516
+ xyz[..., 0] = 1 - xyz[..., 0]
517
+
518
+ fov = 60 / 180 * np.pi
519
+ proj_mtx = np.array([
520
+ [1 / (2 * np.tan(fov/2)), 0, 1/2],
521
+ [0, 1 / (2 * np.tan(fov/2)), 1/2],
522
+ [0, 0, 1],
523
+ ])
524
+ self.proj_mtx = torch.from_numpy(proj_mtx).cuda().float()
525
+ if c2w is None:
526
+ c2w = np.array([0.0000, 0.0000, 1.0000, 2.5000, 1.0000, 0.0000, -0.0000, 0.0000, -0.0000, 1.0000, -0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]).reshape(4, 4)
527
+ else:
528
+ c2w = c2w[0].cpu().numpy()
529
+ xyz = np.einsum('ab,hwb->hwa', np.linalg.inv(proj_mtx), xyz)
530
+ xyz = xyz * depth[..., None]
531
+ xyz = np.concatenate([xyz, np.ones_like(x)[..., None]], 2)
532
+ xyz = np.einsum('ab,hwb->hwa', c2w, xyz)
533
+ return xyz, ret_depth
534
+
535
+ def inpaint(self, img, mask, prompt):
536
+ # inpaint using base pipe
537
+ N = 512
538
+ img = img.convert("RGB").resize((N, N))
539
+ mask = mask.convert("RGB").resize((N, N))
540
+ self.base_inpainting_pipe.to("cuda")
541
+ img = self.base_inpainting_pipe(prompt=prompt, image=img, mask_image=mask, guidance_scale=7.5).images[0]
542
+ self.base_inpainting_pipe.to("cpu")
543
+ torch.cuda.empty_cache()
544
+
545
+ if self.cfg.use_sdxl_for_inpaint:
546
+ # inpaint using sdxl pipe
547
+ N = 1024
548
+ img = img.convert("RGB").resize((N, N))
549
+ mask = mask.convert("RGB").resize((N, N))
550
+ self.sdxl_inpainting_pipe.to("cuda")
551
+ img = self.sdxl_inpainting_pipe(prompt=prompt, image=img, mask_image=mask, guidance_scale=7.5, num_inference_steps=20, strength=0.99).images[0]
552
+ self.sdxl_inpainting_pipe.to("cpu")
553
+
554
+ return img
555
+
556
+ def configure(self) -> None:
557
+ super().configure()
558
+ self.active_sh_degree = 0
559
+ self.max_sh_degree = self.cfg.sh_degree
560
+ self._xyz = torch.empty(0)
561
+ self._features_dc = torch.empty(0)
562
+ self._features_rest = torch.empty(0)
563
+ self._scaling = torch.empty(0)
564
+ self._rotation = torch.empty(0)
565
+ self._opacity = torch.empty(0)
566
+ self._opacity_mask = None
567
+ self.max_radii2D = torch.empty(0)
568
+ self.xyz_gradient_accum = torch.empty(0)
569
+ self.denom = torch.empty(0)
570
+ self.noise_ratio = 0.0
571
+ if self.cfg.pred_normal:
572
+ self._normal = torch.empty(0)
573
+ self.optimizer = None
574
+ self.setup_functions()
575
+ self.save_path = None
576
+ self.fixed_xyz = None
577
+ self.fixed_rot = None
578
+
579
+ if self.cfg.inference_only:
580
+ return
581
+ # setup GeoWizard
582
+ geowizard_checkpoint_path = 'lemonaddie/geowizard'
583
+ self.geowizard_pipe = DepthNormalEstimationPipeline.from_pretrained(
584
+ geowizard_checkpoint_path, torch_dtype=torch.float32)
585
+
586
+ self.base_inpainting_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16)
587
+ # self.base_inpainting_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, safety_checker=None)
588
+ if self.cfg.use_sdxl_for_inpaint:
589
+ self.sdxl_inpainting_pipe = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16")
590
+ self.sdxl_inpainting_pipe.scheduler = diffusers.EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
591
+
592
+ if self.cfg.geometry_convert_from.startswith("depth:"):
593
+ # estimate depth
594
+ W, H = parse_wh(self.cfg.img_resolution)
595
+ if max(W, H) > self.cfg.pc_max_resolution:
596
+ W, H = int(W / max(W, H) * self.cfg.pc_max_resolution), int(H / max(W, H) * self.cfg.pc_max_resolution)
597
+ img = self.cfg.geometry_convert_from[len("depth:"):]
598
+ raw_img = img = Image.open(img).convert("RGB")
599
+ img = img.resize((W, H))
600
+
601
+ bg_xyz, bg_color = [], []
602
+
603
+ with torch.no_grad():
604
+ self.predictor.set_image(np.array(raw_img))
605
+ self.ooi_masks = []
606
+ total_inp_ooi_masks = None
607
+ total_ooi_masks = []
608
+ for i in range(len(self.cfg.ooi_bbox) // 4):
609
+ bbox = np.array(self.cfg.ooi_bbox[4*i:4*i+4])
610
+ masks, _, _ = self.predictor.predict(
611
+ point_coords=None,
612
+ point_labels=None,
613
+ box=bbox[None, :],
614
+ multimask_output=False,
615
+ )
616
+ # plt.imshow(masks[0])
617
+ # plt.savefig(os.path.join(self.save_path, f'mask_{i}.png'))
618
+ ooi_masks = np.array(Image.fromarray(masks[0]).resize((W, H), Image.NEAREST))
619
+ ooi_masks = (cv2.blur(ooi_masks.astype(np.float32), (5, 5)) > 0)
620
+ inp_ooi_masks = (cv2.blur(ooi_masks.astype(np.float32), (7, 7)) > 0)
621
+ if i == 0:
622
+ total_inp_ooi_masks = inp_ooi_masks
623
+ else:
624
+ total_inp_ooi_masks += inp_ooi_masks
625
+ total_ooi_masks.append(ooi_masks)
626
+
627
+ total_inp_ooi_masks = total_inp_ooi_masks > 0
628
+ original_wh = parse_wh(self.cfg.img_resolution)
629
+ bg_image = self.inpaint(img=img, mask=Image.fromarray(total_inp_ooi_masks), prompt=self.cfg.empty_prompt).resize((original_wh))
630
+ self.bg_image = np.array(bg_image)
631
+ self.bg_image_mask = np.array(Image.fromarray(total_inp_ooi_masks).resize((original_wh)))
632
+
633
+ xyz, depth = self.img2pc_inpaint(img)
634
+ self.point_cloud = torch.from_numpy(xyz).cuda().float().reshape(-1, 4)
635
+
636
+ for ooi_masks in total_ooi_masks:
637
+ transit_masks = np.logical_and(cv2.blur(ooi_masks.astype(np.float32), (3, 3)) > 0, ~ooi_masks)
638
+ depth_tensor = torch.from_numpy(depth)[None, None].cuda() * 2 - 1
639
+ self.ooi_masks.append(torch.tensor(ooi_masks.reshape(-1).astype(np.uint8), device='cuda').float().bool())
640
+ ooi_masks = cv2.blur(ooi_masks.astype(np.float32), (9, 9)) > 0
641
+ mask = torch.from_numpy(ooi_masks.astype(np.float32))[None, None].cuda()
642
+ bg_xyz_pc, _ = self.img2pc_inpaint(bg_image, gt_depth=depth_tensor, mask=1-mask)
643
+
644
+ bg_xyz.append(bg_xyz_pc[ooi_masks])
645
+ bg_color.append(np.array(bg_image.resize((W, H)))[ooi_masks] / 255)
646
+
647
+ # xyz = xyz[..., :3].reshape(-1, 3)
648
+ xyz = xyz.reshape(-1, 4)
649
+ color = np.array(img).reshape(-1, 3) / 255
650
+ bg_xyz = np.concatenate(bg_xyz, 0)
651
+ additional_pts_num = bg_xyz.shape[0]
652
+ xyz = np.concatenate([xyz, bg_xyz], 0)
653
+ self.point_cloud = torch.from_numpy(xyz).cuda().float()
654
+
655
+ color = np.concatenate([color, np.concatenate(bg_color, 0)], 0)
656
+ for i in range(len(self.ooi_masks)):
657
+ self.register_buffer(f"ooi_masks_{i}", torch.cat([self.ooi_masks[i], torch.zeros([additional_pts_num], device='cuda').bool()]) )
658
+ self.ooi_masks[i] = getattr(self, f"ooi_masks_{i}")
659
+ self.register_buffer(f"_delete_mask", torch.ones_like(self.ooi_masks[0].float()))
660
+
661
+ # project to 3D space
662
+ xyz = xyz[:, :3]
663
+ color = color
664
+ pcd = BasicPointCloud(
665
+ points=xyz, colors=color, normals=np.zeros((xyz.shape[0], 3))
666
+ )
667
+ self.create_from_pcd(pcd, 10)
668
+ self.training_setup()
669
+
670
+ elif self.cfg.geometry_convert_from.startswith("shap-e:"):
671
+ shap_e_guidance = threestudio.find("shap-e-guidance")(
672
+ self.cfg.shap_e_guidance_config
673
+ )
674
+ prompt = self.cfg.geometry_convert_from[len("shap-e:") :]
675
+ xyz, color = shap_e_guidance(prompt)
676
+
677
+ pcd = BasicPointCloud(
678
+ points=xyz, colors=color, normals=np.zeros((xyz.shape[0], 3))
679
+ )
680
+ self.create_from_pcd(pcd, 10)
681
+ self.training_setup()
682
+
683
+ # Support Initialization from OpenLRM, Please see https://github.com/Adamdad/threestudio-lrm
684
+ elif self.cfg.geometry_convert_from.startswith("lrm:"):
685
+ lrm_guidance = threestudio.find("lrm-guidance")(
686
+ self.cfg.shap_e_guidance_config
687
+ )
688
+ prompt = self.cfg.geometry_convert_from[len("lrm:") :]
689
+ xyz, color = lrm_guidance(prompt)
690
+
691
+ pcd = BasicPointCloud(
692
+ points=xyz, colors=color, normals=np.zeros((xyz.shape[0], 3))
693
+ )
694
+ self.create_from_pcd(pcd, 10)
695
+ self.training_setup()
696
+
697
+ elif os.path.exists(self.cfg.geometry_convert_from):
698
+ threestudio.info(
699
+ "Loading point cloud from %s" % self.cfg.geometry_convert_from
700
+ )
701
+ if self.cfg.geometry_convert_from.endswith(".ckpt"):
702
+ ckpt_dict = torch.load(self.cfg.geometry_convert_from)
703
+ num_pts = ckpt_dict["state_dict"]["geometry._xyz"].shape[0]
704
+ pcd = BasicPointCloud(
705
+ points=np.zeros((num_pts, 3)),
706
+ colors=np.zeros((num_pts, 3)),
707
+ normals=np.zeros((num_pts, 3)),
708
+ )
709
+ self.create_from_pcd(pcd, 10)
710
+ self.training_setup()
711
+ new_ckpt_dict = {}
712
+ for key in self.state_dict():
713
+ if ckpt_dict["state_dict"].__contains__("geometry." + key):
714
+ new_ckpt_dict[key] = ckpt_dict["state_dict"]["geometry." + key]
715
+ else:
716
+ new_ckpt_dict[key] = self.state_dict()[key]
717
+ self.load_state_dict(new_ckpt_dict)
718
+ elif self.cfg.geometry_convert_from.endswith(".ply"):
719
+ if self.cfg.load_ply_only_vertex:
720
+ plydata = PlyData.read(self.cfg.geometry_convert_from)
721
+ vertices = plydata["vertex"]
722
+ positions = np.vstack(
723
+ [vertices["x"], vertices["y"], vertices["z"]]
724
+ ).T
725
+ if vertices.__contains__("red"):
726
+ colors = (
727
+ np.vstack(
728
+ [vertices["red"], vertices["green"], vertices["blue"]]
729
+ ).T
730
+ / 255.0
731
+ )
732
+ else:
733
+ shs = np.random.random((positions.shape[0], 3)) / 255.0
734
+ C0 = 0.28209479177387814
735
+ colors = shs * C0 + 0.5
736
+ normals = np.zeros_like(positions)
737
+ pcd = BasicPointCloud(
738
+ points=positions, colors=colors, normals=normals
739
+ )
740
+ self.create_from_pcd(pcd, 10)
741
+ else:
742
+ self.load_ply(self.cfg.geometry_convert_from)
743
+ self.training_setup()
744
+ else:
745
+ threestudio.info("Geometry not found, initilization with random points")
746
+ num_pts = self.cfg.init_num_pts
747
+ phis = np.random.random((num_pts,)) * 2 * np.pi
748
+ costheta = np.random.random((num_pts,)) * 2 - 1
749
+ thetas = np.arccos(costheta)
750
+ mu = np.random.random((num_pts,))
751
+ radius = self.cfg.pc_init_radius * np.cbrt(mu)
752
+ x = radius * np.sin(thetas) * np.cos(phis)
753
+ y = radius * np.sin(thetas) * np.sin(phis)
754
+ z = radius * np.cos(thetas)
755
+ xyz = np.stack((x, y, z), axis=1)
756
+
757
+ shs = np.random.random((num_pts, 3)) / 255.0
758
+ C0 = 0.28209479177387814
759
+ color = shs * C0 + 0.5
760
+ pcd = BasicPointCloud(
761
+ points=xyz, colors=color, normals=np.zeros((num_pts, 3))
762
+ )
763
+
764
+ self.create_from_pcd(pcd, 10)
765
+ self.training_setup()
766
+
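+    # Lift an inpainted novel view into 3D: fill the disocclusion mask,
+    # inpaint the RGB there, estimate depth and align it to the rendered
+    # depth, then unproject the masked pixels and merge them into the
+    # current Gaussian set.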
+    def add_pc_from_novel_view(self, rgb, mask, depth, c2w, save_path=None):
+        W, H = parse_wh(self.cfg.img_resolution)
+        if max(W, H) > self.cfg.pc_max_resolution:
+            W, H = int(W / max(W, H) * self.cfg.pc_max_resolution), int(H / max(W, H) * self.cfg.pc_max_resolution)
+        # depth estimation -> add points.
+        mask = fill_mask(mask)
+        blur_mask = Image.fromarray(cv2.blur(np.array(mask).astype(np.float32), (7, 7)) > 0)
+        res = self.inpaint(img=rgb, mask=blur_mask, prompt=self.side_prompt)
+
+        self.geowizard_pipe.to('cuda')
+        depth_unaligned = self.geowizard_pipe(
+            res,
+            denoising_steps=25,
+            ensemble_size=3,
+            processing_res=768,
+            match_input_res=False,
+            domain='outdoor',
+            color_map='Spectral',
+            gt_depth=None, mask=None,
+            show_progress_bar=True)['depth_np']
+        self.geowizard_pipe.to('cpu')
+        prev_depth = depth_unaligned[~np.array(mask.resize((768, 768)))]
+        # inpaint the depth map
+        inpaint_mask = np.logical_and(~np.array(mask), depth[0].cpu().numpy().astype(np.uint8) == 0).astype(np.uint8)
+        l, r = depth[depth > 0].min().item(), depth.max().item()
+        depth = (depth - l) / (r - l) * 255
+        depth = cv2.inpaint(depth[0].cpu().numpy().astype(np.uint8), inpaint_mask, 3, cv2.INPAINT_TELEA)
+        depth = torch.tensor(depth)[None].cuda().float() / 255
+        reproj_func = lambda x: (x - prev_depth.min().item()) / (prev_depth.max().item() - prev_depth.min().item()) * (r - l) + l
+        depth = depth * (prev_depth.max() - prev_depth.min()) + prev_depth.min()
+        depth_tensor = torch.nn.functional.interpolate(depth[None].cuda(), 768, mode='nearest') * 2 - 1
+
+        _masks = cv2.blur(np.array(mask.resize((768, 768))).astype(float), (20, 20)) > 0
+        mask_tensor = torch.from_numpy(_masks.astype(np.float32))[None, None].cuda()
+        bg_xyz_pc, _ = self.img2pc_inpaint(res, gt_depth=depth_tensor, mask=1 - mask_tensor, proj_func=reproj_func, c2w=c2w)
+
+        mask = np.array(Image.fromarray(_masks).resize((W, H)))
+        new_xyz = bg_xyz_pc[mask][:, :3]
+        res = res.resize((W, H))
+        new_color = np.array(res)[mask] / 255
+        pcd = BasicPointCloud(points=new_xyz, colors=new_color, normals=np.zeros((new_xyz.shape[0], 3)))
+        self.merge_from_pcd(pcd, 10)
+
+        original_wh = parse_wh(self.cfg.img_resolution)
+        return res.resize(original_wh), Image.fromarray(_masks).resize(original_wh)
+
+    @property
+    def get_scaling(self):
+        if self.cfg.sphere:
+            return self.scaling_activation(
+                torch.mean(self._scaling, dim=-1).unsqueeze(-1).repeat(1, 3)
+            ).clip(0, self.cfg.max_scaling)
+        return self.scaling_activation(self._scaling).clip(0, self.cfg.max_scaling)
+
+    @property
+    def get_rotation(self):
+        return self.rotation_activation(self._rotation)
+
+    @property
+    def get_language_feature(self):
+        return self._language_feature
+
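+    # Positions with optional per-object jitter: when noise_ratio > 0, each
+    # object mask receives a shared random offset (used to visually perturb
+    # individual objects).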
+    @property
+    def get_xyz(self):
+        ret = self._xyz
+        if self.noise_ratio > 0.0:
+            offset = torch.zeros_like(ret)
+            for idx in range(len(self.ooi_masks)):
+                ooi_masks = getattr(self, f"ooi_masks_{idx}")
+                offset[ooi_masks] = torch.rand(3, device='cuda') * self.noise_ratio
+            ret = ret + offset
+        return ret
+
+    @property
+    def get_features(self):
+        features_dc = self._features_dc
+        features_dc = features_dc.clip(-self.color_clip, self.color_clip)
+        features_rest = self._features_rest
+        return torch.cat((features_dc, features_rest), dim=1)
+
+    @property
+    def get_opacity(self):
+        if self._opacity_mask is None:
+            ret = self.opacity_activation(self._opacity)
+        else:
+            ret = self.opacity_activation(self._opacity) * self._opacity_mask.unsqueeze(-1)
+
+        if self._delete_mask is None:
+            return ret
+        else:
+            return ret * self._delete_mask.unsqueeze(-1)
+
+    @property
+    def get_normal(self):
+        if self.cfg.pred_normal:
+            return self._normal
+        else:
+            raise ValueError("Normal is not predicted")
+
+    def recover_xyzrot(self):
+        self._xyz = torch.nn.Parameter(self.fixed_xyz)
+        self._rotation = torch.nn.Parameter(self.fixed_rot)
+
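+    # Rotates the first object-of-interest around its own centroid by a random
+    # angle; scipy quaternions come back as (x, y, z, w), so REORDER_MTX
+    # permutes them to the (w, x, y, z) convention used by build_rotation.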
+    def random_rotate(self, rotate_aug_scale, apply_rotate):
+        if self.fixed_xyz is None:
+            self.fixed_xyz = self.get_xyz.data
+            self.fixed_rot = self.get_rotation.data
+
+        if apply_rotate:
+            ooi_mask = self.ooi_masks_0.view(-1).byte().to(device='cuda').float()
+
+            rotate = random.randint(-rotate_aug_scale, rotate_aug_scale)
+            rot_matrix = rotation_matrix(0, 0, rotate).cuda()
+            prev_xyz = self.fixed_xyz.clone()
+            ooi_xyz = prev_xyz[ooi_mask.bool()]
+            mean = ooi_xyz.mean(0)
+            ooi_xyz = ooi_xyz - mean
+            after_xyz = torch.einsum('ab,nb->na', rot_matrix, ooi_xyz) + mean
+            prev_xyz[ooi_mask.bool()] = after_xyz
+            self._xyz = torch.nn.Parameter(prev_xyz)
+
+            prev_rotation = self.fixed_rot.clone()
+            prev_rotation_mtx = build_rotation(prev_rotation)
+            after_rotation_mtx = torch.einsum('ab,nbc->nac', rot_matrix, prev_rotation_mtx)
+            after_rotation = torch.from_numpy(R.from_matrix(after_rotation_mtx.detach().cpu()).as_quat()).cuda().float()
+            after_rotation = torch.einsum('ab,nb->na', REORDER_MTX, after_rotation)
+            prev_rotation[ooi_mask.bool()] = after_rotation[ooi_mask.bool()]
+            self._rotation = torch.nn.Parameter(prev_rotation)
+        else:
+            self.recover_xyzrot()
+
+    def get_covariance(self, scaling_modifier=1):
+        return self.covariance_activation(
+            self.get_scaling, scaling_modifier, self._rotation
+        )
+
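+    # One isotropic Gaussian per input point: distCUDA2 returns the mean
+    # squared distance to the three nearest neighbours, so its square root
+    # sets a log-scale that roughly covers the cloud without holes.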
+    def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):
+        self.spatial_lr_scale = spatial_lr_scale
+        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
+        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
+        features = (
+            torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2))
+            .float()
+            .cuda()
+        )
+        features[:, :3, 0] = fused_color
+        features[:, 3:, 1:] = 0.0
+
+        threestudio.info(
+            f"Number of points at initialisation: {fused_point_cloud.shape[0]}"
+        )
+
+        dist2 = torch.clamp_min(
+            distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()),
+            0.0000001,
+        )
+        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
+        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
+        rots[:, 0] = 1
+
+        opacities = inverse_sigmoid(
+            self.cfg.opacity_init
+            * torch.ones(
+                (fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"
+            )
+        )
+
+        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
+        self._features_dc = nn.Parameter(
+            features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True)
+        )
+        self._features_rest = nn.Parameter(
+            features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)
+        )
+        self._scaling = nn.Parameter(scales.requires_grad_(True))
+        self._rotation = nn.Parameter(rots.requires_grad_(True))
+        self._opacity = nn.Parameter(opacities.requires_grad_(True))
+        if self.cfg.pred_normal:
+            normals = torch.zeros((fused_point_cloud.shape[0], 3), device="cuda")
+            self._normal = nn.Parameter(normals.requires_grad_(True))
+        self.max_radii2D = torch.zeros((self._xyz.shape[0]), device="cuda")
+
+        self.fused_point_cloud = fused_point_cloud.cpu().clone().detach()
+        self.features = features.cpu().clone().detach()
+        self.scales = scales.cpu().clone().detach()
+        self.rots = rots.cpu().clone().detach()
+        self.opacities = opacities.cpu().clone().detach()
+
+        language_feature = torch.zeros((self._xyz.shape[0], 3), device="cuda")
+        self._language_feature = torch.nn.Parameter(language_feature.requires_grad_(True))
+
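+    # Same initialization as create_from_pcd, but the new Gaussians are
+    # appended through densification_postfix so optimizer state stays
+    # consistent; existing object masks are padded with False for new points.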
+    def merge_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):
+        self.spatial_lr_scale = spatial_lr_scale
+        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
+        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
+        features = (
+            torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2))
+            .float()
+            .cuda()
+        )
+        features[:, :3, 0] = fused_color
+        features[:, 3:, 1:] = 0.0
+
+        threestudio.info(
+            f"Number of points at merging: {fused_point_cloud.shape[0]}"
+        )
+
+        dist2 = torch.clamp_min(
+            distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()),
+            0.0000001,
+        )
+        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
+        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
+        rots[:, 0] = 1
+
+        opacities = inverse_sigmoid(
+            self.cfg.opacity_init
+            * torch.ones(
+                (fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"
+            )
+        )
+        self.densification_postfix(
+            fused_point_cloud,
+            features[:, :, 0:1].transpose(1, 2).contiguous(),
+            features[:, :, 1:].transpose(1, 2).contiguous(),
+            opacities,
+            scales,
+            rots,
+            None,
+            torch.zeros((fused_point_cloud.shape[0], 3), device="cuda"),
+        )
+
+        for idx in range(len(self.ooi_masks)):
+            # self.ooi_masks[idx] = torch.cat([self.ooi_masks[idx], torch.ones([fused_point_cloud.shape[0]], device='cuda') > 0])
+            self.register_buffer(f"ooi_masks_{idx}", torch.cat([getattr(self, f"ooi_masks_{idx}"), torch.zeros([fused_point_cloud.shape[0]], device='cuda').bool()]))
+            self.ooi_masks[idx] = getattr(self, f"ooi_masks_{idx}")
+        self.register_buffer(f"_delete_mask", torch.ones_like(self.ooi_masks[0].float()))
+
+        # self._xyz = torch.nn.Parameter(torch.cat([self._xyz, fused_point_cloud], 0), requires_grad=True)
+        # self._features_dc = torch.nn.Parameter(torch.cat([self._features_dc, features[:, :, 0:1].transpose(1, 2).contiguous()], 0), requires_grad=True)
+        # self._features_rest = torch.nn.Parameter(torch.cat([self._features_rest, features[:, :, 1:].transpose(1, 2).contiguous()], 0), requires_grad=True)
+        # self._scaling = torch.nn.Parameter(torch.cat([self._scaling, scales], 0), requires_grad=True)
+        # self._rotation = torch.nn.Parameter(torch.cat([self._rotation, rots], 0), requires_grad=True)
+        # self._opacity = torch.nn.Parameter(torch.cat([self._opacity, opacities], 0), requires_grad=True)
+
+        # if self.cfg.pred_normal:
+        #     normals = torch.zeros((fused_point_cloud.shape[0], 3), device="cuda")
+        #     self._normal = nn.Parameter(normals.requires_grad_(True))
+        # self.max_radii2D = torch.zeros((self._xyz.shape[0]), device="cuda")
+
+        # self.fused_point_cloud = fused_point_cloud.cpu().clone().detach()
+        # self.features = features.cpu().clone().detach()
+        # self.scales = scales.cpu().clone().detach()
+        # self.rots = rots.cpu().clone().detach()
+        # self.opacities = opacities.cpu().clone().detach()
+
+        # language_feature = torch.zeros((fused_point_cloud.shape[0], 3), device="cuda")
+        # self._language_feature = torch.nn.Parameter(torch.cat([self._language_feature, language_feature], 0), requires_grad=True)
+        # self.training_setup()
+
+
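+    # Language-feature stage: only the per-Gaussian language features are
+    # optimized; all geometry and appearance tensors are frozen until
+    # after_lang() re-enables them.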
+    def lang_training_setup(self):
+        training_args = self.cfg
+        l = [
+            {'params': [self._language_feature], 'lr': C(training_args.lang_lr, 0, 0)},
+        ]
+        self._xyz.requires_grad_(False)
+        self._features_dc.requires_grad_(False)
+        self._features_rest.requires_grad_(False)
+        self._scaling.requires_grad_(False)
+        self._rotation.requires_grad_(False)
+        self._opacity.requires_grad_(False)
+        self._language_feature.requires_grad_(True)
+        # self.lang_optimizer = torch.optim.SGD(l, lr=0.0)
+        self.lang_optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15, betas=(self.cfg.lang_beta_1, self.cfg.lang_beta_2))
+        self.optimize_params = ["lang"]
+        self.optimize_list = l
+
+    def after_lang(self):
+        self._xyz.requires_grad_(True)
+        self._features_dc.requires_grad_(True)
+        self._features_rest.requires_grad_(True)
+        self._scaling.requires_grad_(True)
+        self._rotation.requires_grad_(True)
+        self._opacity.requires_grad_(True)
+        self._language_feature.requires_grad_(False)
+
+    def training_setup(self):
+        self._xyz.requires_grad_(True)
+        self._features_dc.requires_grad_(True)
+        self._features_rest.requires_grad_(True)
+        self._scaling.requires_grad_(True)
+        self._rotation.requires_grad_(True)
+        self._opacity.requires_grad_(True)
+        self._language_feature.requires_grad_(False)
+        training_args = self.cfg
+        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
+        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
+
+        l = [
+            {
+                "params": [self._xyz],
+                "lr": C(training_args.position_lr, 0, 0),
+                "name": "xyz",
+            },
+            {
+                "params": [self._features_dc],
+                "lr": C(training_args.feature_lr, 0, 0),
+                "name": "f_dc",
+            },
+            {
+                "params": [self._features_rest],
+                "lr": C(training_args.feature_lr, 0, 0) / 20.0,
+                "name": "f_rest",
+            },
+            {
+                "params": [self._opacity],
+                "lr": C(training_args.opacity_lr, 0, 0),
+                "name": "opacity",
+            },
+            {
+                "params": [self._scaling],
+                "lr": C(training_args.scaling_lr, 0, 0),
+                "name": "scaling",
+            },
+            {
+                "params": [self._rotation],
+                "lr": C(training_args.rotation_lr, 0, 0),
+                "name": "rotation",
+            },
+            {'params': [self._language_feature], 'lr': C(training_args.lang_lr, 0, 0), "name": "language_feature"},
+        ]
+        if self.cfg.pred_normal:
+            l.append(
+                {
+                    "params": [self._normal],
+                    "lr": C(training_args.normal_lr, 0, 0),
+                    "name": "normal",
+                },
+            )
+
+        self.optimize_params = [
+            "xyz",
+            "f_dc",
+            "f_rest",
+            "opacity",
+            "scaling",
+            "rotation",
+            "language_feature",
+        ]
+        self.optimize_list = l
+        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
+        self.lang_optimizer = None
+
+    def merge_optimizer(self, net_optimizer):
+        l = self.optimize_list
+        for param in net_optimizer.param_groups:
+            l.append(
+                {
+                    "params": param["params"],
+                    "lr": param["lr"],
+                }
+            )
+        self.optimizer = torch.optim.Adam(l, lr=0.0)
+        return self.optimizer
+
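+    # Per-step schedules: each group re-reads its learning rate from the
+    # config through threestudio's C(...) helper with exponential
+    # interpolation; color_clip is annealed the same way.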
+    def update_learning_rate(self, iteration):
+        """Learning rate scheduling per step"""
+        for param_group in self.optimizer.param_groups:
+            if not ("name" in param_group):
+                continue
+            if param_group["name"] == "xyz":
+                param_group["lr"] = C(
+                    self.cfg.position_lr, 0, iteration, interpolation="exp"
+                )
+            if param_group["name"] == "scaling":
+                param_group["lr"] = C(
+                    self.cfg.scaling_lr, 0, iteration, interpolation="exp"
+                )
+            if param_group["name"] == "f_dc":
+                param_group["lr"] = C(
+                    self.cfg.feature_lr, 0, iteration, interpolation="exp"
+                )
+            if param_group["name"] == "f_rest":
+                param_group["lr"] = (
+                    C(self.cfg.feature_lr, 0, iteration, interpolation="exp") / 20.0
+                )
+            if param_group["name"] == "opacity":
+                param_group["lr"] = C(
+                    self.cfg.opacity_lr, 0, iteration, interpolation="exp"
+                )
+            if param_group["name"] == "rotation":
+                param_group["lr"] = C(
+                    self.cfg.rotation_lr, 0, iteration, interpolation="exp"
+                )
+            if param_group["name"] == "normal":
+                param_group["lr"] = C(
+                    self.cfg.normal_lr, 0, iteration, interpolation="exp"
+                )
+        if self.lang_optimizer is not None:
+            for param_group in self.lang_optimizer.param_groups:
+                if not ("name" in param_group):
+                    continue
+                if param_group["name"] == "language_feature":
+                    param_group["lr"] = C(
+                        self.cfg.lang_lr, 0, iteration, interpolation="exp"
+                    )
+        self.color_clip = C(self.cfg.color_clip, 0, iteration)
+
+    def reset_opacity(self):
+        # opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity) * 0.01))
+        opacities_new = inverse_sigmoid(self.get_opacity * 0.9)
+        optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
+        self._opacity = optimizable_tensors["opacity"]
+
+    def to(self, device="cpu"):
+        self._xyz = self._xyz.to(device)
+        self._features_dc = self._features_dc.to(device)
+        self._features_rest = self._features_rest.to(device)
+        self._opacity = self._opacity.to(device)
+        self._scaling = self._scaling.to(device)
+        self._rotation = self._rotation.to(device)
+        self._normal = self._normal.to(device)
+        self._language_feature = self._language_feature.to(device)
+
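+    # The three helpers below keep Adam's per-parameter state (exp_avg,
+    # exp_avg_sq) aligned whenever tensors are replaced, pruned, or extended.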
+    def replace_tensor_to_optimizer(self, tensor, name):
+        optimizable_tensors = {}
+        for group in self.optimizer.param_groups:
+            if ("name" in group) and group["name"] == name:
+                stored_state = self.optimizer.state.get(group["params"][0], None)
+                stored_state["exp_avg"] = torch.zeros_like(tensor)
+                stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
+
+                del self.optimizer.state[group["params"][0]]
+                group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
+                self.optimizer.state[group["params"][0]] = stored_state
+
+                optimizable_tensors[group["name"]] = group["params"][0]
+        return optimizable_tensors
+
+    def _prune_optimizer(self, mask):
+        optimizable_tensors = {}
+        for group in self.optimizer.param_groups:
+            if ("name" in group) and (group["name"] in self.optimize_params):
+                stored_state = self.optimizer.state.get(group["params"][0], None)
+                if stored_state is not None:
+                    stored_state["exp_avg"] = stored_state["exp_avg"][mask]
+                    stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
+
+                    del self.optimizer.state[group["params"][0]]
+                    group["params"][0] = nn.Parameter(
+                        (group["params"][0][mask].requires_grad_(True))
+                    )
+                    self.optimizer.state[group["params"][0]] = stored_state
+
+                    optimizable_tensors[group["name"]] = group["params"][0]
+                else:
+                    group["params"][0] = nn.Parameter(
+                        group["params"][0][mask].requires_grad_(True)
+                    )
+                    optimizable_tensors[group["name"]] = group["params"][0]
+        return optimizable_tensors
+
+    def prune_points(self, mask):
+        valid_points_mask = ~mask
+        optimizable_tensors = self._prune_optimizer(valid_points_mask)
+
+        self._xyz = optimizable_tensors["xyz"]
+        self._features_dc = optimizable_tensors["f_dc"]
+        self._features_rest = optimizable_tensors["f_rest"]
+        self._opacity = optimizable_tensors["opacity"]
+        self._scaling = optimizable_tensors["scaling"]
+        self._rotation = optimizable_tensors["rotation"]
+        self._language_feature = optimizable_tensors["language_feature"]
+        if self.cfg.pred_normal:
+            self._normal = optimizable_tensors["normal"]
+
+        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
+
+        self.denom = self.denom[valid_points_mask]
+        self.max_radii2D = self.max_radii2D[valid_points_mask]
+
+    def cat_tensors_to_optimizer(self, tensors_dict):
+        optimizable_tensors = {}
+        for group in self.optimizer.param_groups:
+            if ("name" in group) and (group["name"] in self.optimize_params):
+                extension_tensor = tensors_dict[group["name"]]
+                stored_state = self.optimizer.state.get(group["params"][0], None)
+                if stored_state is not None:
+                    stored_state["exp_avg"] = torch.cat(
+                        (stored_state["exp_avg"], torch.zeros_like(extension_tensor)),
+                        dim=0,
+                    )
+                    stored_state["exp_avg_sq"] = torch.cat(
+                        (
+                            stored_state["exp_avg_sq"],
+                            torch.zeros_like(extension_tensor),
+                        ),
+                        dim=0,
+                    )
+
+                    del self.optimizer.state[group["params"][0]]
+                    group["params"][0] = nn.Parameter(
+                        torch.cat(
+                            (group["params"][0], extension_tensor), dim=0
+                        ).requires_grad_(True)
+                    )
+                    self.optimizer.state[group["params"][0]] = stored_state
+
+                    optimizable_tensors[group["name"]] = group["params"][0]
+                else:
+                    group["params"][0] = nn.Parameter(
+                        torch.cat(
+                            (group["params"][0], extension_tensor), dim=0
+                        ).requires_grad_(True)
+                    )
+                    optimizable_tensors[group["name"]] = group["params"][0]
+
+        return optimizable_tensors
+
+    def densification_postfix(
+        self,
+        new_xyz,
+        new_features_dc,
+        new_features_rest,
+        new_opacities,
+        new_scaling,
+        new_rotation,
+        new_normal=None,
+        new_language_feature=None,
+    ):
+        d = {
+            "xyz": new_xyz,
+            "f_dc": new_features_dc,
+            "f_rest": new_features_rest,
+            "opacity": new_opacities,
+            "scaling": new_scaling,
+            "rotation": new_rotation,
+            "language_feature": new_language_feature,
+        }
+        if self.cfg.pred_normal:
+            d.update({"normal": new_normal})
+
+        optimizable_tensors = self.cat_tensors_to_optimizer(d)
+        self._xyz = optimizable_tensors["xyz"]
+        self._features_dc = optimizable_tensors["f_dc"]
+        self._features_rest = optimizable_tensors["f_rest"]
+        self._opacity = optimizable_tensors["opacity"]
+        self._scaling = optimizable_tensors["scaling"]
+        self._rotation = optimizable_tensors["rotation"]
+        self._language_feature = optimizable_tensors["language_feature"]
+        if self.cfg.pred_normal:
+            self._normal = optimizable_tensors["normal"]
+
+        self.xyz_gradient_accum = torch.zeros((self._xyz.shape[0], 1), device="cuda")
+        self.denom = torch.zeros((self._xyz.shape[0], 1), device="cuda")
+        self.max_radii2D = torch.zeros((self._xyz.shape[0]), device="cuda")
+
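+    # Split: large, high-gradient Gaussians are replaced by N children sampled
+    # from the parent's own distribution, with scales divided by 0.8 * N.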
+    def densify_and_split(self, grads, grad_threshold, N=2):
+        n_init_points = self._xyz.shape[0]
+        # Extract points that satisfy the gradient condition
+        padded_grad = torch.zeros((n_init_points), device="cuda")
+        padded_grad[: grads.shape[0]] = grads.squeeze()
+        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
+        selected_pts_mask = torch.logical_and(
+            selected_pts_mask,
+            torch.norm(self.get_scaling, dim=1) > self.cfg.split_thresh,
+        )
+
+        # sample N children per selected point; dividing the std by N keeps them close
+        stds = self.get_scaling[selected_pts_mask].repeat(N, 1) / N
+        means = torch.zeros((stds.size(0), 3), device="cuda")
+        samples = torch.normal(mean=means, std=stds)
+        rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1)
+        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self._xyz[
+            selected_pts_mask
+        ].repeat(N, 1)
+        new_scaling = self.scaling_inverse_activation(
+            self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8 * N)
+        )
+        new_rotation = self._rotation[selected_pts_mask].repeat(N, 1)
+        new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1)
+        new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1)
+        new_opacity = self._opacity[selected_pts_mask].repeat(N, 1)
+        new_language_feature = self._language_feature[selected_pts_mask].repeat(N, 1)
+        if self.cfg.pred_normal:
+            new_normal = self._normal[selected_pts_mask].repeat(N, 1)
+        else:
+            new_normal = None
+
+        self.densification_postfix(
+            new_xyz,
+            new_features_dc,
+            new_features_rest,
+            new_opacity,
+            new_scaling,
+            new_rotation,
+            new_normal,
+            new_language_feature,
+        )
+
+        prune_filter = torch.cat(
+            (
+                selected_pts_mask,
+                torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool),
+            )
+        )
+        self.prune_points(prune_filter)
+
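+    # Clone: small, high-gradient Gaussians are duplicated in place and left
+    # to drift apart under the position gradient.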
+    def densify_and_clone(self, grads, grad_threshold):
+        # Extract points that satisfy the gradient condition
+        selected_pts_mask = torch.where(
+            torch.norm(grads, dim=-1) >= grad_threshold, True, False
+        )
+        selected_pts_mask = torch.logical_and(
+            selected_pts_mask,
+            torch.norm(self.get_scaling, dim=1) <= self.cfg.split_thresh,
+        )
+
+        new_xyz = self._xyz[selected_pts_mask]
+        new_features_dc = self._features_dc[selected_pts_mask]
+        new_features_rest = self._features_rest[selected_pts_mask]
+        new_opacities = self._opacity[selected_pts_mask]
+        new_scaling = self._scaling[selected_pts_mask]
+        new_rotation = self._rotation[selected_pts_mask]
+        new_language_feature = self._language_feature[selected_pts_mask]
+        if self.cfg.pred_normal:
+            new_normal = self._normal[selected_pts_mask]
+        else:
+            new_normal = None
+
+        self.densification_postfix(
+            new_xyz,
+            new_features_dc,
+            new_features_rest,
+            new_opacities,
+            new_scaling,
+            new_rotation,
+            new_normal,
+            new_language_feature,
+        )
+
+    def densify(self, max_grad):
+        grads = self.xyz_gradient_accum / self.denom
+        grads[grads.isnan()] = 0.0
+
+        self.densify_and_clone(grads, max_grad)
+        self.densify_and_split(grads, max_grad)
+
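+    # Prune nearly transparent Gaussians and, optionally, those whose 2D
+    # radius is far above average (likely floaters).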
+    def prune(self, min_opacity, max_screen_size):
+        prune_mask = (self.get_opacity < min_opacity).squeeze()
+        if self.cfg.prune_big_points:
+            big_points_vs = self.max_radii2D > (torch.mean(self.max_radii2D) * 3)
+            prune_mask = torch.logical_or(prune_mask, big_points_vs)
+        self.prune_points(prune_mask)
+
+        torch.cuda.empty_cache()
+
+    def add_densification_stats(self, viewspace_point_tensor, update_filter):
+        self.xyz_gradient_accum[update_filter] += torch.norm(
+            viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True
+        )
+        self.denom[update_filter] += 1
+
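+    # Per-step bookkeeping: cap the total count by random pruning, accumulate
+    # view-space gradient stats, then run the prune/densify/opacity-reset
+    # schedule from the config.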
+    @torch.no_grad()
+    def update_states(
+        self,
+        iteration,
+        visibility_filter,
+        radii,
+        viewspace_point_tensor,
+    ):
+        if self._xyz.shape[0] >= self.cfg.max_num + 100:
+            prune_mask = torch.randperm(self._xyz.shape[0]).to(self._xyz.device)
+            prune_mask = prune_mask > self.cfg.max_num
+            self.prune_points(prune_mask)
+            return
+        # Keep track of max radii in image-space for pruning
+        # loop over batch
+        bs = len(viewspace_point_tensor)
+        for i in range(bs):
+            radii_i = radii[i]
+            visibility_filter_i = visibility_filter[i]
+            viewspace_point_tensor_i = viewspace_point_tensor[i]
+            self.max_radii2D = torch.max(self.max_radii2D, radii_i.float())
+
+            self.add_densification_stats(viewspace_point_tensor_i, visibility_filter_i)
+
+        if (
+            iteration > self.cfg.prune_from_iter
+            and iteration < self.cfg.prune_until_iter
+            and iteration % self.cfg.prune_interval == 0
+        ):
+            self.prune(self.cfg.min_opac_prune, self.cfg.radii2d_thresh)
+            if iteration % self.cfg.opacity_reset_interval == 0:
+                self.reset_opacity()
+
+        if (
+            iteration > self.cfg.densify_from_iter
+            and iteration < self.cfg.densify_until_iter
+            and iteration % self.cfg.densification_interval == 0
+        ):
+            self.densify(self.cfg.densify_grad_threshold)
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/geometry/gaussian_base.py.bak ADDED
@@ -0,0 +1,1492 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact [email protected]
+#
+import math
+import os
+import random
+import sys
+import argparse
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import NamedTuple
+
+import numpy as np
+import cv2
+from PIL import Image
+import threestudio
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+from transformers import pipeline
+from plyfile import PlyData, PlyElement
+from simple_knn._C import distCUDA2
+import diffusers
+from diffusers import StableDiffusionInpaintPipeline, AutoPipelineForInpainting
+from threestudio.models.geometry.base import BaseGeometry
+from threestudio.utils.misc import C
+from threestudio.utils.typing import *
+from segment_anything import sam_model_registry, SamPredictor
+import matplotlib.pyplot as plt
+
+from .gaussian_io import GaussianIO
+import imageio
+
+from scipy.spatial.transform import Rotation as R
+
+ REORDER_MTX = torch.tensor([
45
+ [0,0,0,1],
46
+ [1,0,0,0],
47
+ [0,1,0,0],
48
+ [0,0,1,0]
49
+ ]).cuda().float()
50
+
51
+ def build_rotation(r):
52
+ norm = torch.sqrt(
53
+ r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
54
+ )
55
+
56
+ q = r / norm[:, None]
57
+
58
+ R = torch.zeros((q.size(0), 3, 3), device="cuda")
59
+ r = q[:, 0]
60
+ x = q[:, 1]
61
+ y = q[:, 2]
62
+ z = q[:, 3]
63
+
64
+ R[:, 0, 0] = 1 - 2 * (y * y + z * z)
65
+ R[:, 0, 1] = 2 * (x * y - r * z)
66
+ R[:, 0, 2] = 2 * (x * z + r * y)
67
+ R[:, 1, 0] = 2 * (x * y + r * z)
68
+ R[:, 1, 1] = 1 - 2 * (x * x + z * z)
69
+ R[:, 1, 2] = 2 * (y * z - r * x)
70
+ R[:, 2, 0] = 2 * (x * z - r * y)
71
+ R[:, 2, 1] = 2 * (y * z + r * x)
72
+ R[:, 2, 2] = 1 - 2 * (x * x + y * y)
73
+ return R
74
+
75
+ def rotation_matrix(angle_x, angle_y, angle_z):
76
+ # Convert angles to radians
77
+ rad_x = torch.deg2rad(torch.tensor(angle_x))
78
+ rad_y = torch.deg2rad(torch.tensor(angle_y))
79
+ rad_z = torch.deg2rad(torch.tensor(angle_z))
80
+
81
+ # Compute sine and cosine of the angles
82
+ cos_x = torch.cos(rad_x)
83
+ sin_x = torch.sin(rad_x)
84
+ cos_y = torch.cos(rad_y)
85
+ sin_y = torch.sin(rad_y)
86
+ cos_z = torch.cos(rad_z)
87
+ sin_z = torch.sin(rad_z)
88
+
89
+ # Construct the rotation matrix
90
+ Rx = torch.tensor([[1, 0, 0],
91
+ [0, cos_x, -sin_x],
92
+ [0, sin_x, cos_x]])
93
+
94
+ Ry = torch.tensor([[cos_y, 0, sin_y],
95
+ [0, 1, 0],
96
+ [-sin_y, 0, cos_y]])
97
+
98
+ Rz = torch.tensor([[cos_z, -sin_z, 0],
99
+ [sin_z, cos_z, 0],
100
+ [0, 0, 1]])
101
+
102
+ # Combine the rotation matrices
103
+ rotation_matrix = Rz @ Ry @ Rx
104
+
105
+ return rotation_matrix
106
+
107
+ # from scipy.spatial import KDTree
108
+ #
109
+ # def distCUDA2(points):
110
+ # points_np = points.detach().cpu().float().numpy()
111
+ # dists, inds = KDTree(points_np).query(points_np, k=4)
112
+ # meanDists = (dists[:, 1:] ** 2).mean(1)
113
+ #
114
+ # return torch.tensor(meanDists, dtype=points.dtype, device=points.device)
115
+
116
+ sys.path.append('./GeoWizard/geowizard')
117
+ from models.geowizard_pipeline import DepthNormalEstimationPipeline
118
+
119
+ C0 = 0.28209479177387814
120
+
121
+ def propagate(canvas):
122
+ H, W = canvas.shape
123
+ dx = [0, 1, 0, -1]
124
+ dy = [1, 0, -1, 0]
125
+ count = np.zeros_like(canvas)
126
+
127
+ while 1:
128
+ curr_mask = canvas > 0
129
+ if sum(sum(curr_mask)) == H * W:
130
+ break
131
+ expand_mask = (cv2.blur(curr_mask.astype(np.float32), (3, 3)) > 0)
132
+ x, y = np.where(np.logical_and(expand_mask, ~curr_mask))
133
+ old_canvas = canvas.copy()
134
+
135
+ for xx, yy in zip(x, y):
136
+ for i in range(4):
137
+ ref_x = xx + dx[i]
138
+ ref_y = yy + dy[i]
139
+ if 0<=ref_x<H and 0<=ref_y<W and old_canvas[ref_x, ref_y] != 0:
140
+ canvas[xx, yy] = old_canvas[ref_x, ref_y]
141
+ count[xx, yy] = count[ref_x, ref_y] + 1
142
+
143
+ weight = (count.max() - count) / count.max()
144
+ return canvas * weight
145
+
146
+ def save_pc(save_file, pts, color):
147
+ '''
148
+ pts: N, 3
149
+ color: N, 3
150
+ '''
151
+ if color.dtype == np.dtype('float64'):
152
+ color = (color * 255).astype(np.uint8)
153
+ with open(save_file, 'w') as f:
154
+ f.writelines((
155
+ "ply\n",
156
+ "format ascii 1.0\n",
157
+ "element vertex {}\n".format(pts.shape[0]),
158
+ "property float x\n",
159
+ "property float y\n",
160
+ "property float z\n",
161
+ "property uchar red\n",
162
+ "property uchar green\n",
163
+ "property uchar blue\n",
164
+ "end_header\n"))
165
+ for i in range(pts.shape[0]):
166
+ point = "%f %f %f %d %d %d\n" % (pts[i, 0], pts[i, 1], pts[i, 2], color[i, 0], color[i, 1], color[i, 2])
167
+ f.writelines(point)
168
+ threestudio.info(f"Saved point cloud to {save_file}.")
169
+
170
+
171
+ def RGB2SH(rgb):
172
+ return (rgb - 0.5) / C0
173
+
174
+
175
+ def SH2RGB(sh):
176
+ return sh * C0 + 0.5
177
+
178
+
179
+ def inverse_sigmoid(x):
180
+ return torch.log(x / (1 - x))
181
+
182
+
183
+ def strip_lowerdiag(L):
184
+ uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
185
+
186
+ uncertainty[:, 0] = L[:, 0, 0]
187
+ uncertainty[:, 1] = L[:, 0, 1]
188
+ uncertainty[:, 2] = L[:, 0, 2]
189
+ uncertainty[:, 3] = L[:, 1, 1]
190
+ uncertainty[:, 4] = L[:, 1, 2]
191
+ uncertainty[:, 5] = L[:, 2, 2]
192
+ return uncertainty
193
+
194
+
195
+ def strip_symmetric(sym):
196
+ return strip_lowerdiag(sym)
197
+
198
+
199
+ def gaussian_3d_coeff(xyzs, covs):
200
+ # xyzs: [N, 3]
201
+ # covs: [N, 6]
202
+ x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2]
203
+ a, b, c, d, e, f = (
204
+ covs[:, 0],
205
+ covs[:, 1],
206
+ covs[:, 2],
207
+ covs[:, 3],
208
+ covs[:, 4],
209
+ covs[:, 5],
210
+ )
211
+
212
+ # eps must be small enough !!!
213
+ inv_det = 1 / (
214
+ a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24
215
+ )
216
+ inv_a = (d * f - e**2) * inv_det
217
+ inv_b = (e * c - b * f) * inv_det
218
+ inv_c = (e * b - c * d) * inv_det
219
+ inv_d = (a * f - c**2) * inv_det
220
+ inv_e = (b * c - e * a) * inv_det
221
+ inv_f = (a * d - b**2) * inv_det
222
+
223
+ power = (
224
+ -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f)
225
+ - x * y * inv_b
226
+ - x * z * inv_c
227
+ - y * z * inv_e
228
+ )
229
+
230
+ power[power > 0] = -1e10 # abnormal values... make weights 0
231
+
232
+ return torch.exp(power)
233
+
234
+
235
+ def build_rotation(r):
236
+ norm = torch.sqrt(
237
+ r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
238
+ )
239
+
240
+ q = r / norm[:, None]
241
+
242
+ R = torch.zeros((q.size(0), 3, 3), device="cuda")
243
+
244
+ r = q[:, 0]
245
+ x = q[:, 1]
246
+ y = q[:, 2]
247
+ z = q[:, 3]
248
+
249
+ R[:, 0, 0] = 1 - 2 * (y * y + z * z)
250
+ R[:, 0, 1] = 2 * (x * y - r * z)
251
+ R[:, 0, 2] = 2 * (x * z + r * y)
252
+ R[:, 1, 0] = 2 * (x * y + r * z)
253
+ R[:, 1, 1] = 1 - 2 * (x * x + z * z)
254
+ R[:, 1, 2] = 2 * (y * z - r * x)
255
+ R[:, 2, 0] = 2 * (x * z - r * y)
256
+ R[:, 2, 1] = 2 * (y * z + r * x)
257
+ R[:, 2, 2] = 1 - 2 * (x * x + y * y)
258
+ return R
259
+
260
+
261
+ def build_scaling_rotation(s, r):
262
+ L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
263
+ R = build_rotation(r)
264
+
265
+ L[:, 0, 0] = s[:, 0]
266
+ L[:, 1, 1] = s[:, 1]
267
+ L[:, 2, 2] = s[:, 2]
268
+
269
+ L = R @ L
270
+ return L
271
+
272
+
273
+ def safe_state(silent):
274
+ old_f = sys.stdout
275
+
276
+ class F:
277
+ def __init__(self, silent):
278
+ self.silent = silent
279
+
280
+ def write(self, x):
281
+ if not self.silent:
282
+ if x.endswith("\n"):
283
+ old_f.write(
284
+ x.replace(
285
+ "\n",
286
+ " [{}]\n".format(
287
+ str(datetime.now().strftime("%d/%m %H:%M:%S"))
288
+ ),
289
+ )
290
+ )
291
+ else:
292
+ old_f.write(x)
293
+
294
+ def flush(self):
295
+ old_f.flush()
296
+
297
+ sys.stdout = F(silent)
298
+
299
+ random.seed(0)
300
+ np.random.seed(0)
301
+ torch.manual_seed(0)
302
+ torch.cuda.set_device(torch.device("cuda:0"))
303
+
304
+
305
+ class BasicPointCloud(NamedTuple):
306
+ points: np.array
307
+ colors: np.array
308
+ normals: np.array
309
+
310
+
311
+ class Camera(NamedTuple):
312
+ FoVx: torch.Tensor
313
+ FoVy: torch.Tensor
314
+ camera_center: torch.Tensor
315
+ image_width: int
316
+ image_height: int
317
+ world_view_transform: torch.Tensor
318
+ full_proj_transform: torch.Tensor
319
+
320
+ def fill_mask(mask):
321
+ mask = np.array(mask)
322
+ canvas = np.zeros_like(mask)
323
+ H, W = mask.shape
324
+ for i in range(H):
325
+ for p in range(0, W):
326
+ if mask[i, p]:
327
+ canvas[i, p] = 1
328
+ else:
329
+ break
330
+ for p in range(W-1, 0, -1):
331
+ if mask[i, p]:
332
+ canvas[i, p] = 1
333
+ else:
334
+ break
335
+
336
+ for i in range(W):
337
+ for p in range(0, H):
338
+ if mask[p, i]:
339
+ canvas[p, i] = 1
340
+ else:
341
+ break
342
+ for p in range(H-1, 0, -1):
343
+ if mask[p, i]:
344
+ canvas[p, i] = 1
345
+ else:
346
+ break
347
+ mask = np.logical_and(mask, canvas)
348
+ return Image.fromarray(mask)
349
+
350
+ def parse_wh(wh):
351
+ try:
352
+ W, H = wh
353
+ except:
354
+ H = W = wh
355
+ return W, H
356
+
357
+ @threestudio.register("gaussian-splatting")
358
+ class GaussianBaseModel(BaseGeometry, GaussianIO):
359
+ @dataclass
360
+ class Config(BaseGeometry.Config):
361
+ max_num: int = 500000
362
+ sh_degree: int = 0
363
+ position_lr: Any = 0.001
364
+ # scale_lr: Any = 0.003
365
+ feature_lr: Any = 0.01
366
+ opacity_lr: Any = 0.05
367
+ scaling_lr: Any = 0.005
368
+ rotation_lr: Any = 0.005
369
+ pred_normal: bool = False
370
+ normal_lr: Any = 0.001
371
+ lang_lr: float = 0.005
372
+
373
+ densification_interval: int = 50
374
+ prune_interval: int = 50
375
+ opacity_reset_interval: int = 100000
376
+ densify_from_iter: int = 100
377
+ prune_from_iter: int = 100
378
+ densify_until_iter: int = 2000
379
+ prune_until_iter: int = 2000
380
+ densify_grad_threshold: Any = 0.01
381
+ min_opac_prune: Any = 0.005
382
+ split_thresh: Any = 0.02
383
+ radii2d_thresh: Any = 1000
384
+
385
+ sphere: bool = False
386
+ prune_big_points: bool = False
387
+ color_clip: Any = 2.0
388
+
389
+ geometry_convert_from: str = ""
390
+ load_ply_only_vertex: bool = False
391
+ init_num_pts: int = 100
392
+ pc_init_radius: float = 0.8
393
+ opacity_init: float = 0.1
394
+
395
+ img_resolution: Any = 512
396
+
397
+ shap_e_guidance_config: dict = field(default_factory=dict)
398
+
399
+ max_scaling: float = 100
400
+ sam_ckpt_path: str = "ckpts/sam_vit_h_4b8939.pth"
401
+ ooi_bbox: Any = None
402
+
403
+ prompt: Any = None
404
+ empty_prompt: Any = None
405
+ novel_view_gradual: bool = False
406
+ lang_beta_1: float = 0.9
407
+ lang_beta_2: float = 0.999
408
+
409
+ inference_only: bool = False
410
+
411
+ cfg: Config
412
+
413
+ def setup_functions(self):
414
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
415
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
416
+ actual_covariance = L @ L.transpose(1, 2)
417
+ symm = strip_symmetric(actual_covariance)
418
+ return symm
419
+
420
+ self.scaling_activation = torch.exp
421
+ self.scaling_inverse_activation = torch.log
422
+
423
+ self.covariance_activation = build_covariance_from_scaling_rotation
424
+
425
+ self.opacity_activation = torch.sigmoid
426
+ self.inverse_opacity_activation = inverse_sigmoid
427
+
428
+ self.rotation_activation = torch.nn.functional.normalize
429
+ self.color_clip = C(self.cfg.color_clip, 0, 0)
430
+
431
+ self.fixed_xyz = None
432
+ self.fixed_rot = None
433
+
434
+ if not self.cfg.inference_only:
435
+ sam = sam_model_registry["vit_h"](checkpoint=self.cfg.sam_ckpt_path).to('cuda')
436
+ self.predictor = SamPredictor(sam)
437
+
438
+ def project_pc(self, c2w, H=None, W=None):
439
+ W, H = parse_wh(self.cfg.img_resolution)
440
+ # if W is None:
441
+ # W = H
442
+ assert self.point_cloud is not None
443
+ pc_cam = torch.einsum('bxy,hwy->bhwx', torch.linalg.inv(c2w), self.point_cloud)
444
+ depth = -1 * pc_cam[..., 2].view(pc_cam.shape[0], -1)
445
+ pc_cam = (pc_cam / pc_cam[..., 2:3])[..., :3]
446
+ pc_2d = torch.einsum('xy,bhwy->bhwx', self.proj_mtx, pc_cam).clamp(0, 1)
447
+ pc_2d[..., 0] = pc_2d[..., 0] * (W-1)
448
+ pc_2d[..., 1] = pc_2d[..., 1] * (H-1)
449
+ pc_2d = (pc_2d.long()).view(pc_2d.shape[0], -1, pc_2d.shape[-1])
450
+
451
+ mask = torch.zeros([pc_2d.shape[0], H, W], device='cuda')
452
+ depth_canvas = torch.zeros([pc_2d.shape[0], H, W], device='cuda')
453
+ for i in range(pc_2d.shape[0]):
454
+ x = (W - pc_2d[i, :, 0]).clamp(0, W-1)
455
+ y = (pc_2d[i, :, 1]).clamp(0, H-1)
456
+ mask[i, y, x] = 1.0
457
+ depth_canvas[i, y, x] = depth[i]
458
+
459
+ mask = torchvision.transforms.functional.gaussian_blur(mask, 3) > 0
460
+ return mask, depth_canvas
461
+
462
+ def img2pc_inpaint(self, img, c2w=None, gt_depth=None, mask=None, proj_func=None):
463
+ W, H = parse_wh(self.cfg.img_resolution)
464
+ with torch.no_grad():
465
+ depth = self.geowizard_pipe(
466
+ img,
467
+ denoising_steps = 25,
468
+ ensemble_size = 3,
469
+ processing_res = 768,
470
+ match_input_res = False,
471
+ domain = 'outdoor',
472
+ color_map = 'Spectral',
473
+ gt_depth = gt_depth, mask = mask,
474
+ show_progress_bar = True)['depth_np']
475
+ ret_depth = depth.copy()
476
+ depth = torch.from_numpy(depth)[None]
477
+ depth = torch.nn.functional.interpolate(depth[None], size=(H, W), mode='bilinear', align_corners=True).squeeze()
478
+
479
+ depth = depth.cpu().numpy()
480
+ if proj_func is None:
481
+ depth = depth * 20 + 5
482
+ else:
483
+ depth = proj_func(depth)
484
+
485
+ depth = depth * -1
486
+ x, y = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
487
+ x = x / float(W-1)
488
+ y = y / float(H-1)
489
+ xyz = np.stack((x, y, np.ones_like(x)), 0).transpose(1, 2, 0)
490
+ xyz[..., 0] = 1 - xyz[..., 0]
491
+
492
+ fov = 60 / 180 * np.pi
493
+ proj_mtx = np.array([
494
+ [1 / (2 * np.tan(fov/2)), 0, 1/2],
495
+ [0, 1 / (2 * np.tan(fov/2)), 1/2],
496
+ [0, 0, 1],
497
+ ])
498
+ self.proj_mtx = torch.from_numpy(proj_mtx).cuda().float()
499
+ if c2w is None:
500
+ c2w = np.array([0.0000, -0.3420, 0.9397, 2.3492, 1.0000, 0.0000, -0.0000, 0.0000, -0.0000, 0.9397, 0.3420, 0.8551, 0.0000, 0.0000, 0.0000, 1.0000]).reshape(4, 4)
501
+ c2w = np.array([0.0000, 0.0000, 1.0000, 2.5000, 1.0000, 0.0000, -0.0000, 0.0000, -0.0000, 1.0000, -0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]).reshape(4, 4)
502
+ else:
503
+ c2w = c2w[0].cpu().numpy()
504
+ xyz = np.einsum('ab,hwb->hwa', np.linalg.inv(proj_mtx), xyz)
505
+ xyz = xyz * depth[..., None]
506
+ xyz = np.concatenate([xyz, np.ones_like(x)[..., None]], 2)
507
+ xyz = np.einsum('ab,hwb->hwa', c2w, xyz)
508
+ return xyz, ret_depth
509
+
510
+ def img2pc(self, img, transit_mask=None, fg_transit_l=None, fg_transit_r=None, c2w=None, fg_depth=None):
511
+ H, W = parse_hw(self.cfg.img_resolution)
512
+ with torch.no_grad():
513
+ depth = self.geowizard_pipe(
514
+ img,
515
+ denoising_steps = 25,
516
+ ensemble_size = 3,
517
+ processing_res = 768,
518
+ match_input_res = True,
519
+ domain = 'outdoor',
520
+ color_map = 'Spectral',
521
+ show_progress_bar = True)['depth_np']
522
+ depth = torch.from_numpy(depth)[None]
523
+ depth = torch.nn.functional.interpolate(depth[None], size=(W, H), mode='bilinear', align_corners=True).squeeze()
524
+
525
+
526
+ depth = depth.cpu().numpy()
527
+ if fg_depth is None:
528
+ if fg_transit_l is None:
529
+ l, r = np.quantile(depth, 0.05), np.quantile(depth, 0.95)
530
+ depth = (depth - l) / (r - l) * 20 * 0.9 + 2 + 5 / 90
531
+ ret_depth = depth.copy()
532
+ else:
533
+ transit_l, transit_r = depth[transit_mask].min(), depth[transit_mask].max()
534
+ depth = (depth - transit_l) / (transit_r - transit_l) * (fg_transit_r - fg_transit_l) + fg_transit_l
535
+ ret_depth = depth
536
+ else:
537
+ delta = fg_depth[0] - depth
538
+ delta[~transit_mask] = 0
539
+ delta = propagate(delta)
540
+ depth = depth + delta
541
+ ret_depth = depth.copy()
542
+ depth = depth * -1
543
+ x, y = np.meshgrid(np.arange(H, dtype=np.float32), np.arange(W, dtype=np.float32), indexing='xy')
544
+ x = x / float(H-1)
545
+ y = y / float(W-1)
546
+ xyz = np.stack((x, y, np.ones_like(x)), 0).transpose(1, 2, 0)
547
+ xyz[..., 0] = 1 - xyz[..., 0]
548
+
549
+ fov = 60 / 180 * np.pi
550
+ proj_mtx = np.array([
551
+ [1 / (2 * np.tan(fov/2)), 0, 1/2],
552
+ [0, 1 / (2 * np.tan(fov/2)), 1/2],
553
+ [0, 0, 1],
554
+ ])
555
+ self.proj_mtx = torch.from_numpy(proj_mtx).cuda().float()
556
+ if c2w is None:
557
+ c2w = np.array([0.0000, -0.3420, 0.9397, 2.3492, 1.0000, 0.0000, -0.0000, 0.0000, -0.0000, 0.9397, 0.3420, 0.8551, 0.0000, 0.0000, 0.0000, 1.0000]).reshape(4, 4)
558
+ c2w = np.array([0.0000, 0.0000, 1.0000, 2.5000, 1.0000, 0.0000, -0.0000, 0.0000, -0.0000, 1.0000, -0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]).reshape(4, 4)
559
+ else:
560
+ c2w = c2w[0].cpu().numpy()
561
+ xyz = np.einsum('ab,hwb->hwa', np.linalg.inv(proj_mtx), xyz)
562
+ xyz = xyz * depth[..., None]
563
+ xyz = np.concatenate([xyz, np.ones_like(x)[..., None]], 2)
564
+ xyz = np.einsum('ab,hwb->hwa', c2w, xyz)
565
+ return xyz, ret_depth
566
+
567
+ def inpaint(self, img, mask, prompt):
568
+ # inpaint using base pipe
569
+ N = 512
570
+ img = img.convert("RGB").resize((N, N))
571
+ mask = mask.convert("RGB").resize((N, N))
572
+ self.base_inpainting_pipe.to("cuda")
573
+ img = self.base_inpainting_pipe(prompt=prompt, image=img, mask_image=mask, guidance_scale=7.5).images[0]
574
+ self.base_inpainting_pipe.to("cpu")
575
+ torch.cuda.empty_cache()
576
+
577
+ # inpaint using sdxl pipe
578
+ N = 1024
579
+ img = img.convert("RGB").resize((N, N))
580
+ mask = mask.convert("RGB").resize((N, N))
581
+ self.sdxl_inpainting_pipe.to("cuda")
582
+ img = self.sdxl_inpainting_pipe(prompt=prompt, image=img, mask_image=mask, guidance_scale=7.5, num_inference_steps=20, strength=0.99).images[0]
583
+ self.sdxl_inpainting_pipe.to("cpu")
584
+
585
+ return img
586
+
587
+ def configure(self) -> None:
588
+ super().configure()
589
+ self.active_sh_degree = 0
590
+ self.max_sh_degree = self.cfg.sh_degree
591
+ self._xyz = torch.empty(0)
592
+ self._features_dc = torch.empty(0)
593
+ self._features_rest = torch.empty(0)
594
+ self._scaling = torch.empty(0)
595
+ self._rotation = torch.empty(0)
596
+ self._opacity = torch.empty(0)
597
+ self._opacity_mask = None
598
+ self.max_radii2D = torch.empty(0)
599
+ self.xyz_gradient_accum = torch.empty(0)
600
+ self.denom = torch.empty(0)
601
+ self.noise_ratio = 0.0
602
+ if self.cfg.pred_normal:
603
+ self._normal = torch.empty(0)
604
+ self.optimizer = None
605
+ self.setup_functions()
606
+ self.save_path = None
607
+ self.fixed_xyz = None
608
+ self.fixed_rot = None
609
+
610
+ if self.cfg.inference_only:
611
+ return
612
+ # setup GeoWizard
613
+ geowizard_checkpoint_path = 'lemonaddie/geowizard'
614
+ self.geowizard_pipe = DepthNormalEstimationPipeline.from_pretrained(
615
+ geowizard_checkpoint_path, torch_dtype=torch.float32).to(torch.device("cuda"))
616
+
617
+ self.base_inpainting_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16)
618
+ # self.base_inpainting_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, safety_checker=None)
619
+ self.sdxl_inpainting_pipe = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16")
620
+ self.sdxl_inpainting_pipe.scheduler = diffusers.EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
621
+
622
+ if self.cfg.geometry_convert_from.startswith("depth:"):
623
+ # estimate depth
624
+ W, H = parse_wh(self.cfg.img_resolution)
625
+ mask_H, mask_W = H, W
626
+ if max(H, W) > 1024:
627
+ mask_H, mask_W = int(H / max(H, W) * 1024), int(W / max(H, W) * 1024)
628
+ img = self.cfg.geometry_convert_from[len("depth:"):]
629
+ raw_img = img = Image.open(img).convert("RGB")
630
+ img = img.resize((W, H))
631
+
632
+ bg_xyz, bg_color = [], []
633
+
634
+ with torch.no_grad():
635
+ self.predictor.set_image(np.array(raw_img))
636
+ self.ooi_masks = []
637
+ total_inp_ooi_masks = None
638
+ total_ooi_masks = []
639
+ for i in range(len(self.cfg.ooi_bbox) // 4):
640
+ bbox = np.array(self.cfg.ooi_bbox[4*i:4*i+4])
641
+ masks, _, _ = self.predictor.predict(
642
+ point_coords=None,
643
+ point_labels=None,
644
+ box=bbox[None, :],
645
+ multimask_output=False,
646
+ )
647
+ # plt.imshow(masks[0])
648
+ # plt.savefig(os.path.join(self.save_path, f'mask_{i}.png'))
649
+ ooi_masks = np.array(Image.fromarray(masks[0]).resize((W, H), Image.NEAREST))
650
+ ooi_masks = (cv2.blur(ooi_masks.astype(np.float32), (5, 5)) > 0)
651
+ inp_ooi_masks = (cv2.blur(ooi_masks.astype(np.float32), (7, 7)) > 0)
652
+ if i == 0:
653
+ total_inp_ooi_masks = inp_ooi_masks
654
+ else:
655
+ total_inp_ooi_masks += inp_ooi_masks
656
+ total_ooi_masks.append(ooi_masks)
657
+
658
+ total_inp_ooi_masks = total_inp_ooi_masks > 0
659
+ bg_image = self.inpaint(img=img, mask=Image.fromarray(total_inp_ooi_masks), prompt=self.cfg.empty_prompt).resize((W, H))
660
+ self.bg_image = np.array(bg_image)
661
+ self.bg_image_mask = np.array(Image.fromarray(total_inp_ooi_masks).resize((W, H)))
662
+
663
+ xyz, depth = self.img2pc_inpaint(img)
664
+ self.point_cloud = torch.from_numpy(xyz).cuda().float()
665
+
666
+ for ooi_masks in total_ooi_masks:
667
+ transit_masks = np.logical_and(cv2.blur(ooi_masks.astype(np.float32), (3, 3)) > 0, ~ooi_masks)
668
+ depth_tensor = torch.from_numpy(depth)[None, None].cuda() * 2 - 1
669
+ self.ooi_masks.append(torch.tensor(ooi_masks.reshape(-1).astype(np.uint8), device='cuda').float().bool())
670
+ ooi_masks = cv2.blur(ooi_masks.astype(np.float32), (9, 9)) > 0
671
+ mask = torch.from_numpy(ooi_masks.astype(np.float32))[None, None].cuda()
672
+ bg_xyz_pc, _ = self.img2pc_inpaint(bg_image, gt_depth=depth_tensor, mask=1-mask)
673
+
674
+ bg_xyz.append(bg_xyz_pc[ooi_masks][:, :3])
675
+ bg_color.append(np.array(bg_image)[ooi_masks] / 255)
676
+
677
+ xyz = xyz[..., :3].reshape(-1, 3)
678
+ color = np.array(img).reshape(-1, 3) / 255
679
+ additional_pts_num = sum([len(each) for each in bg_xyz])
680
+ xyz = np.concatenate([xyz, np.concatenate(bg_xyz, 0)], 0)
681
+ color = np.concatenate([color, np.concatenate(bg_color, 0)], 0)
682
+ for i in range(len(self.ooi_masks)):
683
+ self.register_buffer(f"ooi_masks_{i}", torch.cat([self.ooi_masks[i], torch.zeros([additional_pts_num], device='cuda').bool()]) )
684
+ self.ooi_masks[i] = getattr(self, f"ooi_masks_{i}")
685
+ self.register_buffer(f"_delete_mask", torch.ones_like(self.ooi_masks[0].float()))
686
+
687
+ # project to 3D space
688
+ xyz = xyz
689
+ color = color
690
+ pcd = BasicPointCloud(
691
+ points=xyz, colors=color, normals=np.zeros((xyz.shape[0], 3))
692
+ )
693
+ self.create_from_pcd(pcd, 10)
694
+ self.training_setup()
695
+
696
+ elif self.cfg.geometry_convert_from.startswith("shap-e:"):
697
+ shap_e_guidance = threestudio.find("shap-e-guidance")(
698
+ self.cfg.shap_e_guidance_config
699
+ )
700
+ prompt = self.cfg.geometry_convert_from[len("shap-e:") :]
701
+ xyz, color = shap_e_guidance(prompt)
702
+
703
+ pcd = BasicPointCloud(
704
+ points=xyz, colors=color, normals=np.zeros((xyz.shape[0], 3))
705
+ )
706
+ self.create_from_pcd(pcd, 10)
707
+ self.training_setup()
708
+
709
+ # Support Initialization from OpenLRM, Please see https://github.com/Adamdad/threestudio-lrm
710
+ elif self.cfg.geometry_convert_from.startswith("lrm:"):
711
+ lrm_guidance = threestudio.find("lrm-guidance")(
712
+ self.cfg.shap_e_guidance_config
713
+ )
714
+ prompt = self.cfg.geometry_convert_from[len("lrm:") :]
715
+ xyz, color = lrm_guidance(prompt)
716
+
717
+ pcd = BasicPointCloud(
718
+ points=xyz, colors=color, normals=np.zeros((xyz.shape[0], 3))
719
+ )
720
+ self.create_from_pcd(pcd, 10)
721
+ self.training_setup()
722
+
723
+ elif os.path.exists(self.cfg.geometry_convert_from):
724
+ threestudio.info(
725
+ "Loading point cloud from %s" % self.cfg.geometry_convert_from
726
+ )
727
+ if self.cfg.geometry_convert_from.endswith(".ckpt"):
728
+ ckpt_dict = torch.load(self.cfg.geometry_convert_from)
729
+ num_pts = ckpt_dict["state_dict"]["geometry._xyz"].shape[0]
730
+ pcd = BasicPointCloud(
731
+ points=np.zeros((num_pts, 3)),
732
+ colors=np.zeros((num_pts, 3)),
733
+ normals=np.zeros((num_pts, 3)),
734
+ )
735
+ self.create_from_pcd(pcd, 10)
736
+ self.training_setup()
737
+ new_ckpt_dict = {}
738
+ for key in self.state_dict():
739
+ if ckpt_dict["state_dict"].__contains__("geometry." + key):
740
+ new_ckpt_dict[key] = ckpt_dict["state_dict"]["geometry." + key]
741
+ else:
742
+ new_ckpt_dict[key] = self.state_dict()[key]
743
+ self.load_state_dict(new_ckpt_dict)
744
+ elif self.cfg.geometry_convert_from.endswith(".ply"):
745
+ if self.cfg.load_ply_only_vertex:
746
+ plydata = PlyData.read(self.cfg.geometry_convert_from)
747
+ vertices = plydata["vertex"]
748
+ positions = np.vstack(
749
+ [vertices["x"], vertices["y"], vertices["z"]]
750
+ ).T
751
+ if vertices.__contains__("red"):
752
+ colors = (
753
+ np.vstack(
754
+ [vertices["red"], vertices["green"], vertices["blue"]]
755
+ ).T
756
+ / 255.0
757
+ )
758
+ else:
759
+ shs = np.random.random((positions.shape[0], 3)) / 255.0
760
+ C0 = 0.28209479177387814
761
+ colors = shs * C0 + 0.5
762
+ normals = np.zeros_like(positions)
763
+ pcd = BasicPointCloud(
764
+ points=positions, colors=colors, normals=normals
765
+ )
766
+ self.create_from_pcd(pcd, 10)
767
+ else:
768
+ self.load_ply(self.cfg.geometry_convert_from)
769
+ self.training_setup()
770
+ else:
771
+ threestudio.info("Geometry not found, initilization with random points")
772
+ num_pts = self.cfg.init_num_pts
773
+ phis = np.random.random((num_pts,)) * 2 * np.pi
774
+ costheta = np.random.random((num_pts,)) * 2 - 1
775
+ thetas = np.arccos(costheta)
776
+ mu = np.random.random((num_pts,))
777
+ radius = self.cfg.pc_init_radius * np.cbrt(mu)
778
+ x = radius * np.sin(thetas) * np.cos(phis)
779
+ y = radius * np.sin(thetas) * np.sin(phis)
780
+ z = radius * np.cos(thetas)
781
+ xyz = np.stack((x, y, z), axis=1)
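+ # Sampling math above: for points uniform in a ball of radius R, invert the CDFs of
+ # the spherical factors: phi ~ U(0, 2*pi), cos(theta) ~ U(-1, 1) (hence arccos), and
+ # r = R * cbrt(u), since the radial CDF is (r/R)^3.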
782
+
783
+ shs = np.random.random((num_pts, 3)) / 255.0
784
+ C0 = 0.28209479177387814
785
+ color = shs * C0 + 0.5
786
+ pcd = BasicPointCloud(
787
+ points=xyz, colors=color, normals=np.zeros((num_pts, 3))
788
+ )
789
+
790
+ self.create_from_pcd(pcd, 10)
791
+ self.training_setup()
792
+
793
+ def add_pc_from_novel_view(self, rgb, mask, depth, c2w, save_path=None):
794
+ W, H = parse_wh(self.cfg.img_resolution)
795
+ # depth estimation -> add points.
796
+ mask = fill_mask(mask)
797
+ mask_array = np.array(mask)
798
+ blur_mask = Image.fromarray(cv2.blur(np.array(mask).astype(np.float32), (7, 7)) > 0)
799
+ res = self.inpaint(img=rgb, mask=blur_mask, prompt=self.side_prompt)
800
+
801
+ depth_unaligned = self.geowizard_pipe(
+ res,
+ denoising_steps=25,
+ ensemble_size=3,
+ processing_res=768,
+ match_input_res=False,
+ domain='outdoor',
+ color_map='Spectral',
+ gt_depth=None,
+ mask=None,
+ show_progress_bar=True,
+ )['depth_np']
811
+ prev_depth = depth_unaligned[~np.array(mask.resize((768,768)))]
812
+ # inpaint the depth map
813
+ depth_array = depth[0].cpu().numpy().astype(np.uint8)
814
+ inpaint_mask = (~mask_array & (depth_array == 0)).astype(np.uint8)
815
+ # inpaint_mask = np.logical_and(~np.array(mask.resize((512, 512), Image.NEAREST)) , depth[0].cpu().numpy().astype(np.uint8)==0 ).astype(np.uint8)
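+ # The rendered depth has holes where no Gaussians project; the steps below fill them
+ # with OpenCV's Telea inpainting, rescale the result into the range of the GeoWizard
+ # prediction on unmasked pixels, and use reproj_func to map aligned depths back to
+ # the renderer's range [l, r] during re-projection.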
816
+ l, r = depth[depth>0].min().item(), depth.max().item()
817
+ depth = (depth - l) / (r - l) * 255  # (unused: overwritten by the inpainted depth below)
818
+ depth = cv2.inpaint(depth_array, inpaint_mask, 3, cv2.INPAINT_TELEA)
819
+ depth = torch.tensor(depth)[None].cuda().float() / 255
820
+ reproj_func = lambda x: (x - prev_depth.min().item()) / (prev_depth.max().item() - prev_depth.min().item()) * (r-l) + l
821
+ depth = depth * (prev_depth.max() - prev_depth.min()) + prev_depth.min()
822
+ depth_tensor = torch.nn.functional.interpolate(depth[None].cuda(), 768, mode='nearest') * 2 - 1
823
+
824
+ _masks = cv2.blur(np.array(mask.resize((768, 768))).astype(float), (20, 20)) > 0
825
+ mask_tensor = torch.from_numpy(_masks.astype(np.float32))[None, None].cuda()
826
+ bg_xyz_pc, _ = self.img2pc_inpaint(res, gt_depth=depth_tensor, mask=1-mask_tensor, proj_func=reproj_func, c2w=c2w)
827
+
828
+ new_xyz = bg_xyz_pc[mask_array][:, :3]
829
+ res = res.resize((W, H))
830
+ new_color = np.array(res)[mask_array] / 255
831
+ pcd = BasicPointCloud(points=new_xyz, colors=new_color, normals=np.zeros((new_xyz.shape[0], 3)))
832
+ self.merge_from_pcd(pcd, 10)
833
+
834
+ save_pc(save_path, new_xyz, new_color)
835
+ return res, mask
836
+
837
+ @property
838
+ def get_scaling(self):
839
+ if self.cfg.sphere:
840
+ return self.scaling_activation(
841
+ torch.mean(self._scaling, dim=-1).unsqueeze(-1).repeat(1, 3)
842
+ ).clip(0, self.cfg.max_scaling)
843
+ return self.scaling_activation(self._scaling).clip(0, self.cfg.max_scaling)
844
+
845
+ @property
846
+ def get_rotation(self):
847
+ return self.rotation_activation(self._rotation)
848
+
849
+ @property
850
+ def get_language_feature(self):
851
+ return self._language_feature
852
+
853
+ @property
854
+ def get_xyz(self):
855
+ ret = self._xyz
+ if self.noise_ratio > 0.0:
+ offset = torch.zeros_like(ret)
+ for idx in range(len(self.ooi_masks)):
+ ooi_masks = getattr(self, f"ooi_masks_{idx}")
+ offset[ooi_masks] = torch.rand(3, device='cuda') * self.noise_ratio
+ # jitter the object-of-interest Gaussians by the sampled offset
+ ret = ret + offset
+ return ret
862
+
863
+ @property
864
+ def get_features(self):
865
+ features_dc = self._features_dc
866
+ features_dc = features_dc.clip(-self.color_clip, self.color_clip)
867
+ features_rest = self._features_rest
868
+ return torch.cat((features_dc, features_rest), dim=1)
869
+
870
+ @property
871
+ def get_opacity(self):
872
+ if self._opacity_mask is None:
873
+ ret = self.opacity_activation(self._opacity)
874
+ else:
875
+ ret = self.opacity_activation(self._opacity) * self._opacity_mask.unsqueeze(-1)
876
+
877
+ if self._delete_mask is None:
878
+ return ret
879
+ else:
880
+ return ret * self._delete_mask.unsqueeze(-1)
881
+
882
+ @property
883
+ def get_normal(self):
884
+ if self.cfg.pred_normal:
885
+ return self._normal
886
+ else:
887
+ raise ValueError("Normal is not predicted")
888
+
889
+ def recover_xyzrot(self):
890
+ self._xyz = torch.nn.Parameter(self.fixed_xyz)
891
+ self._rotation = torch.nn.Parameter(self.fixed_rot)
892
+
893
+ def random_rotate(self, rotate_aug_scale, apply_rotate):
894
+ if self.fixed_xyz is None:
895
+ self.fixed_xyz = self.get_xyz.data
896
+ self.fixed_rot = self.get_rotation.data
897
+
898
+ if apply_rotate:
899
+ ooi_mask = self.ooi_masks_0.view(-1).byte().to(device='cuda').float()
900
+
901
+ rotate = random.randint(-rotate_aug_scale, rotate_aug_scale)
902
+ rot_matrix = rotation_matrix(0, 0, rotate).cuda()
903
+ prev_xyz = self.fixed_xyz.clone()
904
+ ooi_xyz = prev_xyz[ooi_mask.bool()]
905
+ mean = ooi_xyz.mean(0)
906
+ ooi_xyz = ooi_xyz - mean
907
+ after_xyz = torch.einsum('ab,nb->na', rot_matrix, ooi_xyz) + mean
908
+ prev_xyz[ooi_mask.bool()] = after_xyz
909
+ self._xyz = torch.nn.Parameter(prev_xyz)
910
+
911
+ prev_rotation = self.fixed_rot.clone()
912
+ prev_rotation_mtx = build_rotation(prev_rotation)
913
+ after_rotation_mtx = torch.einsum('ab,nbc->nac', rot_matrix, prev_rotation_mtx)
914
+ after_rotation = torch.from_numpy(R.from_matrix(after_rotation_mtx.detach().cpu()).as_quat()).cuda().float()
915
+ after_rotation = torch.einsum('ab,nb->na', REORDER_MTX, after_rotation)
916
+ prev_rotation[ooi_mask.bool()] = after_rotation[ooi_mask.bool()]
917
+ self._rotation = torch.nn.Parameter(prev_rotation)
918
+ else:
919
+ self.recover_xyzrot()
920
+
921
+ def get_covariance(self, scaling_modifier=1):
922
+ return self.covariance_activation(
923
+ self.get_scaling, scaling_modifier, self._rotation
924
+ )
925
+
926
+ def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):
927
+ self.spatial_lr_scale = spatial_lr_scale
928
+ fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
929
+ fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
930
+ features = (
931
+ torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2))
932
+ .float()
933
+ .cuda()
934
+ )
935
+ features[:, :3, 0] = fused_color
936
+ features[:, 3:, 1:] = 0.0
937
+
938
+ threestudio.info(
939
+ f"Number of points at initialisation:{fused_point_cloud.shape[0]}"
940
+ )
941
+
942
+ dist2 = torch.clamp_min(
943
+ distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()),
944
+ 0.0000001,
945
+ )
946
+ scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
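+ # distCUDA2 returns the mean squared distance to each point's nearest neighbors;
+ # taking sqrt then log stores an isotropic per-point scale in the log domain that
+ # the exponential scaling activation expects.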
947
+ rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
948
+ rots[:, 0] = 1
949
+
950
+ opacities = inverse_sigmoid(
951
+ self.cfg.opacity_init
952
+ * torch.ones(
953
+ (fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"
954
+ )
955
+ )
956
+
957
+ self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
958
+ self._features_dc = nn.Parameter(
959
+ features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True)
960
+ )
961
+ self._features_rest = nn.Parameter(
962
+ features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)
963
+ )
964
+ self._scaling = nn.Parameter(scales.requires_grad_(True))
965
+ self._rotation = nn.Parameter(rots.requires_grad_(True))
966
+ self._opacity = nn.Parameter(opacities.requires_grad_(True))
967
+ if self.cfg.pred_normal:
968
+ normals = torch.zeros((fused_point_cloud.shape[0], 3), device="cuda")
969
+ self._normal = nn.Parameter(normals.requires_grad_(True))
970
+ self.max_radii2D = torch.zeros((self._xyz.shape[0]), device="cuda")
971
+
972
+ self.fused_point_cloud = fused_point_cloud.cpu().clone().detach()
973
+ self.features = features.cpu().clone().detach()
974
+ self.scales = scales.cpu().clone().detach()
975
+ self.rots = rots.cpu().clone().detach()
976
+ self.opacities = opacities.cpu().clone().detach()
977
+
978
+ language_feature = torch.zeros((self._xyz.shape[0], 3), device="cuda")
979
+ self._language_feature = torch.nn.Parameter(language_feature.requires_grad_(True))
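+ # Note on RGB2SH above: for the DC band, SH = (rgb - 0.5) / C0 with
+ # C0 = 0.28209479177387814 (the Y_00 constant), i.e. the inverse of the
+ # `shs * C0 + 0.5` mapping used when initializing random colors.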
980
+
981
+ def merge_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):
982
+ self.spatial_lr_scale = spatial_lr_scale
983
+ fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
984
+ fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
985
+ features = (
986
+ torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2))
987
+ .float()
988
+ .cuda()
989
+ )
990
+ features[:, :3, 0] = fused_color
991
+ features[:, 3:, 1:] = 0.0
992
+
993
+ threestudio.info(
994
+ f"Number of points at merging:{fused_point_cloud.shape[0]}"
995
+ )
996
+
997
+ dist2 = torch.clamp_min(
998
+ distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()),
999
+ 0.0000001,
1000
+ )
1001
+ scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
1002
+ rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
1003
+ rots[:, 0] = 1
1004
+
1005
+ opacities = inverse_sigmoid(
1006
+ self.cfg.opacity_init
1007
+ * torch.ones(
1008
+ (fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"
1009
+ )
1010
+ )
1011
+ self.densification_postfix(
1012
+ fused_point_cloud,
1013
+ features[:, :, 0:1].transpose(1, 2).contiguous(),
1014
+ features[:, :, 1:].transpose(1, 2).contiguous(),
1015
+ opacities,
1016
+ scales,
1017
+ rots,
1018
+ None,
1019
+ torch.zeros((fused_point_cloud.shape[0], 3), device="cuda")
1020
+ )
1021
+
1022
+ for idx in range(len(self.ooi_masks)):
1023
+ # self.ooi_masks[idx] = torch.cat([self.ooi_masks[idx], torch.ones([fused_point_cloud.shape[0]], device='cuda') > 0])
1024
+ self.register_buffer(f"ooi_masks_{idx}", torch.cat([getattr(self, f"ooi_masks_{idx}"), torch.zeros([fused_point_cloud.shape[0]], device='cuda').bool()]) )
1025
+ self.ooi_masks[idx] = getattr(self, f"ooi_masks_{idx}")
1026
+ self.register_buffer(f"_delete_mask", torch.ones_like(self.ooi_masks[0].float()))
1027
+
1028
+ # self._xyz = torch.nn.Parameter(torch.cat([self._xyz, fused_point_cloud],0),requires_grad=True)
1029
+ # self._features_dc = torch.nn.Parameter(torch.cat([self._features_dc, features[:, :, 0:1].transpose(1, 2).contiguous()],0),requires_grad=True)
1030
+ # self._features_rest = torch.nn.Parameter(torch.cat([self._features_rest, features[:, :, 1:].transpose(1, 2).contiguous()],0),requires_grad=True)
1031
+ # self._scaling = torch.nn.Parameter(torch.cat([self._scaling, scales],0),requires_grad=True)
1032
+ # self._rotation = torch.nn.Parameter(torch.cat([self._rotation, rots],0),requires_grad=True)
1033
+ # self._opacity = torch.nn.Parameter(torch.cat([self._opacity, opacities],0),requires_grad=True)
1034
+
1035
+ # if self.cfg.pred_normal:
1036
+ # normals = torch.zeros((fused_point_cloud.shape[0], 3), device="cuda")
1037
+ # self._normal = nn.Parameter(normals.requires_grad_(True))
1038
+ # self.max_radii2D = torch.zeros((self._xyz.shape[0]), device="cuda")
1039
+
1040
+ # self.fused_point_cloud = fused_point_cloud.cpu().clone().detach()
1041
+ # self.features = features.cpu().clone().detach()
1042
+ # self.scales = scales.cpu().clone().detach()
1043
+ # self.rots = rots.cpu().clone().detach()
1044
+ # self.opacities = opacities.cpu().clone().detach()
1045
+
1046
+ # language_feature = torch.zeros((fused_point_cloud.shape[0], 3), device="cuda")
1047
+ # self._language_feature = torch.nn.Parameter(torch.cat([self._language_feature, language_feature], 0), requires_grad=True)
1048
+ # self.training_setup()
1049
+
1050
+
1051
+ def lang_training_setup(self):
1052
+ training_args = self.cfg
1053
+ l = [
1054
+ {'params': [self._language_feature], 'lr': C(training_args.lang_lr, 0, 0)},
1055
+ ]
1056
+ self._xyz.requires_grad_(False)
1057
+ self._features_dc.requires_grad_(False)
1058
+ self._features_rest.requires_grad_(False)
1059
+ self._scaling.requires_grad_(False)
1060
+ self._rotation.requires_grad_(False)
1061
+ self._opacity.requires_grad_(False)
1062
+ self._language_feature.requires_grad_(True)
1063
+ # self.lang_optimizer = torch.optim.SGD(l, lr=0.0)
1064
+ self.lang_optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15, betas=(self.cfg.lang_beta_1, self.cfg.lang_beta_2))
1065
+ self.optimize_params = ["lang"]
1066
+ self.optimize_list = l
1067
+
1068
+ def after_lang(self):
1069
+ self._xyz.requires_grad_(True)
1070
+ self._features_dc.requires_grad_(True)
1071
+ self._features_rest.requires_grad_(True)
1072
+ self._scaling.requires_grad_(True)
1073
+ self._rotation.requires_grad_(True)
1074
+ self._opacity.requires_grad_(True)
1075
+ self._language_feature.requires_grad_(False)
1076
+
1077
+ def training_setup(self):
1078
+ self._xyz.requires_grad_(True)
1079
+ self._features_dc.requires_grad_(True)
1080
+ self._features_rest.requires_grad_(True)
1081
+ self._scaling.requires_grad_(True)
1082
+ self._rotation.requires_grad_(True)
1083
+ self._opacity.requires_grad_(True)
1084
+ self._language_feature.requires_grad_(False)
1085
+ training_args = self.cfg
1086
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
1087
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
1088
+
1089
+ l = [
1090
+ {
1091
+ "params": [self._xyz],
1092
+ "lr": C(training_args.position_lr, 0, 0),
1093
+ "name": "xyz",
1094
+ },
1095
+ {
1096
+ "params": [self._features_dc],
1097
+ "lr": C(training_args.feature_lr, 0, 0),
1098
+ "name": "f_dc",
1099
+ },
1100
+ {
1101
+ "params": [self._features_rest],
1102
+ "lr": C(training_args.feature_lr, 0, 0) / 20.0,
1103
+ "name": "f_rest",
1104
+ },
1105
+ {
1106
+ "params": [self._opacity],
1107
+ "lr": C(training_args.opacity_lr, 0, 0),
1108
+ "name": "opacity",
1109
+ },
1110
+ {
1111
+ "params": [self._scaling],
1112
+ "lr": C(training_args.scaling_lr, 0, 0),
1113
+ "name": "scaling",
1114
+ },
1115
+ {
1116
+ "params": [self._rotation],
1117
+ "lr": C(training_args.rotation_lr, 0, 0),
1118
+ "name": "rotation",
1119
+ },
1120
+ {'params': [self._language_feature], 'lr': C(training_args.lang_lr, 0, 0), "name": "language_feature"},
1121
+ ]
1122
+ if self.cfg.pred_normal:
1123
+ l.append(
1124
+ {
1125
+ "params": [self._normal],
1126
+ "lr": C(training_args.normal_lr, 0, 0),
1127
+ "name": "normal",
1128
+ },
1129
+ )
1130
+
1131
+ self.optimize_params = [
1132
+ "xyz",
1133
+ "f_dc",
1134
+ "f_rest",
1135
+ "opacity",
1136
+ "scaling",
1137
+ "rotation",
1138
+ "language_feature"
1139
+ ]
1140
+ self.optimize_list = l
1141
+ self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
1142
+ self.lang_optimizer = None
1143
+
1144
+ def merge_optimizer(self, net_optimizer):
1145
+ l = self.optimize_list
1146
+ for param in net_optimizer.param_groups:
1147
+ l.append(
1148
+ {
1149
+ "params": param["params"],
1150
+ "lr": param["lr"],
1151
+ }
1152
+ )
1153
+ self.optimizer = torch.optim.Adam(l, lr=0.0)
1154
+ return self.optimizer
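+ # C(...) is threestudio's scheduled-value helper: it resolves a (possibly scheduled)
+ # config entry at the given (epoch, step); with interpolation="exp" the per-group
+ # learning rates below decay exponentially over training.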
1155
+
1156
+ def update_learning_rate(self, iteration):
1157
+ """Learning rate scheduling per step"""
1158
+ for param_group in self.optimizer.param_groups:
1159
+ if not ("name" in param_group):
1160
+ continue
1161
+ if param_group["name"] == "xyz":
1162
+ param_group["lr"] = C(
1163
+ self.cfg.position_lr, 0, iteration, interpolation="exp"
1164
+ )
1165
+ if param_group["name"] == "scaling":
1166
+ param_group["lr"] = C(
1167
+ self.cfg.scaling_lr, 0, iteration, interpolation="exp"
1168
+ )
1169
+ if param_group["name"] == "f_dc":
1170
+ param_group["lr"] = C(
1171
+ self.cfg.feature_lr, 0, iteration, interpolation="exp"
1172
+ )
1173
+ if param_group["name"] == "f_rest":
1174
+ param_group["lr"] = (
1175
+ C(self.cfg.feature_lr, 0, iteration, interpolation="exp") / 20.0
1176
+ )
1177
+ if param_group["name"] == "opacity":
1178
+ param_group["lr"] = C(
1179
+ self.cfg.opacity_lr, 0, iteration, interpolation="exp"
1180
+ )
1181
+ if param_group["name"] == "rotation":
1182
+ param_group["lr"] = C(
1183
+ self.cfg.rotation_lr, 0, iteration, interpolation="exp"
1184
+ )
1185
+ if param_group["name"] == "normal":
1186
+ param_group["lr"] = C(
1187
+ self.cfg.normal_lr, 0, iteration, interpolation="exp"
1188
+ )
1189
+ if self.lang_optimizer is not None:
1190
+ for param_group in self.lang_optimizer.param_groups:
1191
+ if not ("name" in param_group):
1192
+ continue
1193
+ if param_group["name"] == "language_feature":
1194
+ param_group["lr"] = C(
1195
+ self.cfg.lang_lr, 0, iteration, interpolation="exp"
1196
+ )
1197
+ self.color_clip = C(self.cfg.color_clip, 0, iteration)
1198
+
1199
+ def reset_opacity(self):
1200
+ # opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
1201
+ opacities_new = inverse_sigmoid(self.get_opacity * 0.9)
1202
+ optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
1203
+ self._opacity = optimizable_tensors["opacity"]
1204
+
1205
+ def to(self, device="cpu"):
1206
+ self._xyz = self._xyz.to(device)
1207
+ self._features_dc = self._features_dc.to(device)
1208
+ self._features_rest = self._features_rest.to(device)
1209
+ self._opacity = self._opacity.to(device)
1210
+ self._scaling = self._scaling.to(device)
1211
+ self._rotation = self._rotation.to(device)
1212
+ self._normal = self._normal.to(device)
1213
+ self._language_feature = self._language_feature.to(device)
1214
+
1215
+ def replace_tensor_to_optimizer(self, tensor, name):
1216
+ optimizable_tensors = {}
1217
+ for group in self.optimizer.param_groups:
1218
+ if ("name" in group) and group["name"] == name:
1219
+ stored_state = self.optimizer.state.get(group["params"][0], None)
1220
+ stored_state["exp_avg"] = torch.zeros_like(tensor)
1221
+ stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
1222
+
1223
+ del self.optimizer.state[group["params"][0]]
1224
+ group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
1225
+ self.optimizer.state[group["params"][0]] = stored_state
1226
+
1227
+ optimizable_tensors[group["name"]] = group["params"][0]
1228
+ return optimizable_tensors
1229
+
1230
+ def _prune_optimizer(self, mask):
1231
+ optimizable_tensors = {}
1232
+ for group in self.optimizer.param_groups:
1233
+ if ("name" in group) and (group["name"] in self.optimize_params):
1234
+ stored_state = self.optimizer.state.get(group["params"][0], None)
1235
+ if stored_state is not None:
1236
+ stored_state["exp_avg"] = stored_state["exp_avg"][mask]
1237
+ stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
1238
+
1239
+ del self.optimizer.state[group["params"][0]]
1240
+ group["params"][0] = nn.Parameter(
1241
+ (group["params"][0][mask].requires_grad_(True))
1242
+ )
1243
+ self.optimizer.state[group["params"][0]] = stored_state
1244
+
1245
+ optimizable_tensors[group["name"]] = group["params"][0]
1246
+ else:
1247
+ group["params"][0] = nn.Parameter(
1248
+ group["params"][0][mask].requires_grad_(True)
1249
+ )
1250
+ optimizable_tensors[group["name"]] = group["params"][0]
1251
+ return optimizable_tensors
1252
+
1253
+ def prune_points(self, mask):
1254
+ valid_points_mask = ~mask
1255
+ optimizable_tensors = self._prune_optimizer(valid_points_mask)
1256
+
1257
+ self._xyz = optimizable_tensors["xyz"]
1258
+ self._features_dc = optimizable_tensors["f_dc"]
1259
+ self._features_rest = optimizable_tensors["f_rest"]
1260
+ self._opacity = optimizable_tensors["opacity"]
1261
+ self._scaling = optimizable_tensors["scaling"]
1262
+ self._rotation = optimizable_tensors["rotation"]
1263
+ self._language_feature = optimizable_tensors["language_feature"]
1264
+ if self.cfg.pred_normal:
1265
+ self._normal = optimizable_tensors["normal"]
1266
+
1267
+ self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
1268
+
1269
+ self.denom = self.denom[valid_points_mask]
1270
+ self.max_radii2D = self.max_radii2D[valid_points_mask]
1271
+
1272
+ def cat_tensors_to_optimizer(self, tensors_dict):
1273
+ optimizable_tensors = {}
1274
+ for group in self.optimizer.param_groups:
1275
+ if ("name" in group) and (group["name"] in self.optimize_params):
1276
+ extension_tensor = tensors_dict[group["name"]]
1277
+ stored_state = self.optimizer.state.get(group["params"][0], None)
1278
+ if stored_state is not None:
1279
+ stored_state["exp_avg"] = torch.cat(
1280
+ (stored_state["exp_avg"], torch.zeros_like(extension_tensor)),
1281
+ dim=0,
1282
+ )
1283
+ stored_state["exp_avg_sq"] = torch.cat(
1284
+ (
1285
+ stored_state["exp_avg_sq"],
1286
+ torch.zeros_like(extension_tensor),
1287
+ ),
1288
+ dim=0,
1289
+ )
1290
+
1291
+ del self.optimizer.state[group["params"][0]]
1292
+ group["params"][0] = nn.Parameter(
1293
+ torch.cat(
1294
+ (group["params"][0], extension_tensor), dim=0
1295
+ ).requires_grad_(True)
1296
+ )
1297
+ self.optimizer.state[group["params"][0]] = stored_state
1298
+
1299
+ optimizable_tensors[group["name"]] = group["params"][0]
1300
+ else:
1301
+ group["params"][0] = nn.Parameter(
1302
+ torch.cat(
1303
+ (group["params"][0], extension_tensor), dim=0
1304
+ ).requires_grad_(True)
1305
+ )
1306
+ optimizable_tensors[group["name"]] = group["params"][0]
1307
+
1308
+ return optimizable_tensors
1309
+
1310
+ def densification_postfix(
1311
+ self,
1312
+ new_xyz,
1313
+ new_features_dc,
1314
+ new_features_rest,
1315
+ new_opacities,
1316
+ new_scaling,
1317
+ new_rotation,
1318
+ new_normal=None,
1319
+ new_language_feature=None
1320
+ ):
1321
+ d = {
1322
+ "xyz": new_xyz,
1323
+ "f_dc": new_features_dc,
1324
+ "f_rest": new_features_rest,
1325
+ "opacity": new_opacities,
1326
+ "scaling": new_scaling,
1327
+ "rotation": new_rotation,
1328
+ "language_feature": new_language_feature,
1329
+ }
1330
+ if self.cfg.pred_normal:
1331
+ d.update({"normal": new_normal})
1332
+
1333
+ optimizable_tensors = self.cat_tensors_to_optimizer(d)
1334
+ self._xyz = optimizable_tensors["xyz"]
1335
+ self._features_dc = optimizable_tensors["f_dc"]
1336
+ self._features_rest = optimizable_tensors["f_rest"]
1337
+ self._opacity = optimizable_tensors["opacity"]
1338
+ self._scaling = optimizable_tensors["scaling"]
1339
+ self._rotation = optimizable_tensors["rotation"]
1340
+ self._language_feature = optimizable_tensors["language_feature"]
1341
+ if self.cfg.pred_normal:
1342
+ self._normal = optimizable_tensors["normal"]
1343
+
1344
+ self.xyz_gradient_accum = torch.zeros((self._xyz.shape[0], 1), device="cuda")
1345
+ self.denom = torch.zeros((self._xyz.shape[0], 1), device="cuda")
1346
+ self.max_radii2D = torch.zeros((self._xyz.shape[0]), device="cuda")
1347
+
1348
+ def densify_and_split(self, grads, grad_threshold, N=2):
1349
+ n_init_points = self._xyz.shape[0]
1350
+ # Extract points that satisfy the gradient condition
1351
+ padded_grad = torch.zeros((n_init_points), device="cuda")
1352
+ padded_grad[: grads.shape[0]] = grads.squeeze()
1353
+ selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
1354
+ selected_pts_mask = torch.logical_and(
1355
+ selected_pts_mask,
1356
+ torch.norm(self.get_scaling, dim=1) > self.cfg.split_thresh,
1357
+ )
1358
+
1359
+ # divide N to enhance robustness
1360
+ stds = self.get_scaling[selected_pts_mask].repeat(N, 1) / N
1361
+ means = torch.zeros((stds.size(0), 3), device="cuda")
1362
+ samples = torch.normal(mean=means, std=stds)
1363
+ rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1)
1364
+ new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self._xyz[
1365
+ selected_pts_mask
1366
+ ].repeat(N, 1)
1367
+ new_scaling = self.scaling_inverse_activation(
1368
+ self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8 * N)
1369
+ )
1370
+ new_rotation = self._rotation[selected_pts_mask].repeat(N, 1)
1371
+ new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1)
1372
+ new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1)
1373
+ new_opacity = self._opacity[selected_pts_mask].repeat(N, 1)
1374
+ new_language_feature = self._language_feature[selected_pts_mask].repeat(N,1)
1375
+ if self.cfg.pred_normal:
1376
+ new_normal = self._normal[selected_pts_mask].repeat(N, 1)
1377
+ else:
1378
+ new_normal = None
1379
+
1380
+ self.densification_postfix(
1381
+ new_xyz,
1382
+ new_features_dc,
1383
+ new_features_rest,
1384
+ new_opacity,
1385
+ new_scaling,
1386
+ new_rotation,
1387
+ new_normal,
1388
+ new_language_feature
1389
+ )
1390
+
1391
+ prune_filter = torch.cat(
1392
+ (
1393
+ selected_pts_mask,
1394
+ torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool),
1395
+ )
1396
+ )
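+ # prune_filter marks the original (pre-split) Gaussians for removal; the trailing
+ # zeros cover the N * num_selected replacement points just appended by
+ # densification_postfix, which are kept.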
1397
+ self.prune_points(prune_filter)
1398
+
1399
+ def densify_and_clone(self, grads, grad_threshold):
1400
+ # Extract points that satisfy the gradient condition
1401
+ selected_pts_mask = torch.where(
1402
+ torch.norm(grads, dim=-1) >= grad_threshold, True, False
1403
+ )
1404
+ selected_pts_mask = torch.logical_and(
1405
+ selected_pts_mask,
1406
+ torch.norm(self.get_scaling, dim=1) <= self.cfg.split_thresh,
1407
+ )
1408
+
1409
+ new_xyz = self._xyz[selected_pts_mask]
1410
+ new_features_dc = self._features_dc[selected_pts_mask]
1411
+ new_features_rest = self._features_rest[selected_pts_mask]
1412
+ new_opacities = self._opacity[selected_pts_mask]
1413
+ new_scaling = self._scaling[selected_pts_mask]
1414
+ new_rotation = self._rotation[selected_pts_mask]
1415
+ new_language_feature = self._language_feature[selected_pts_mask]
1416
+ if self.cfg.pred_normal:
1417
+ new_normal = self._normal[selected_pts_mask]
1418
+ else:
1419
+ new_normal = None
1420
+
1421
+ self.densification_postfix(
1422
+ new_xyz,
1423
+ new_features_dc,
1424
+ new_features_rest,
1425
+ new_opacities,
1426
+ new_scaling,
1427
+ new_rotation,
1428
+ new_normal,
1429
+ new_language_feature
1430
+ )
1431
+
1432
+ def densify(self, max_grad):
1433
+ grads = self.xyz_gradient_accum / self.denom
1434
+ grads[grads.isnan()] = 0.0
1435
+
1436
+ self.densify_and_clone(grads, max_grad)
1437
+ self.densify_and_split(grads, max_grad)
1438
+
1439
+ def prune(self, min_opacity, max_screen_size):
1440
+ prune_mask = (self.get_opacity < min_opacity).squeeze()
1441
+ if self.cfg.prune_big_points:
1442
+ big_points_vs = self.max_radii2D > (torch.mean(self.max_radii2D) * 3)
1443
+ prune_mask = torch.logical_or(prune_mask, big_points_vs)
1444
+ self.prune_points(prune_mask)
1445
+
1446
+ torch.cuda.empty_cache()
1447
+
1448
+ def add_densification_stats(self, viewspace_point_tensor, update_filter):
1449
+ self.xyz_gradient_accum[update_filter] += torch.norm(
1450
+ viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True
1451
+ )
1452
+ self.denom[update_filter] += 1
1453
+
1454
+ @torch.no_grad()
1455
+ def update_states(
1456
+ self,
1457
+ iteration,
1458
+ visibility_filter,
1459
+ radii,
1460
+ viewspace_point_tensor,
1461
+ ):
1462
+ if self._xyz.shape[0] >= self.cfg.max_num + 100:
1463
+ prune_mask = torch.randperm(self._xyz.shape[0]).to(self._xyz.device)
1464
+ prune_mask = prune_mask > self.cfg.max_num
1465
+ self.prune_points(prune_mask)
1466
+ return
1467
+ # Keep track of max radii in image-space for pruning
1468
+ # loop over batch
1469
+ bs = len(viewspace_point_tensor)
1470
+ for i in range(bs):
1471
+ radii_i = radii[i]
1472
+ visibility_filter_i = visibility_filter[i]
1473
+ viewspace_point_tensor_i = viewspace_point_tensor[i]
1474
+ self.max_radii2D = torch.max(self.max_radii2D, radii_i.float())
1475
+
1476
+ self.add_densification_stats(viewspace_point_tensor_i, visibility_filter_i)
1477
+
1478
+ if (
1479
+ iteration > self.cfg.prune_from_iter
1480
+ and iteration < self.cfg.prune_until_iter
1481
+ and iteration % self.cfg.prune_interval == 0
1482
+ ):
1483
+ self.prune(self.cfg.min_opac_prune, self.cfg.radii2d_thresh)
1484
+ if iteration % self.cfg.opacity_reset_interval == 0:
1485
+ self.reset_opacity()
1486
+
1487
+ if (
1488
+ iteration > self.cfg.densify_from_iter
1489
+ and iteration < self.cfg.densify_until_iter
1490
+ and iteration % self.cfg.densification_interval == 0
1491
+ ):
1492
+ self.densify(self.cfg.densify_grad_threshold)
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/geometry/gaussian_dynamic.py ADDED
@@ -0,0 +1,77 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+ import math
12
+ import os
13
+ import random
14
+ import sys
15
+ from dataclasses import dataclass, field
16
+ from datetime import datetime
17
+ from typing import NamedTuple
18
+
19
+ import numpy as np
20
+ import threestudio
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from plyfile import PlyData, PlyElement
25
+ from simple_knn._C import distCUDA2
26
+ from threestudio.models.geometry.base import BaseGeometry
27
+ from threestudio.utils.misc import C
28
+ from threestudio.utils.typing import *
29
+
30
+ from .gaussian_base import GaussianBaseModel
31
+
32
+
33
+ @threestudio.register("gaussian-splatting-dynamic")
34
+ class GaussianDynamicModel(GaussianBaseModel):
35
+ @dataclass
36
+ class Config(GaussianBaseModel.Config):
37
+ flow: bool = True
38
+ num_frames: int = 10
39
+ delta_pos_lr: float = 0.001
40
+ delta_rot_lr: float = 0.0001
41
+
42
+ cfg: Config
43
+
44
+ def configure(self) -> None:
45
+ super().configure()
46
+ self._delta_xyz = torch.empty(0)
47
+ self._delta_rot = torch.empty(0)
48
+ self.time_index = 0
49
+
50
+ def training_setup(self):
51
+ super().training_setup()
52
+ l = self.optimize_list
53
+ training_args = self.cfg
54
+ l.append(
+ {
+ "params": [self._delta_xyz],
+ "lr": C(training_args.delta_pos_lr, 0, 0),
+ "name": "delta_xyz",
+ },
+ )
+ l.append(
+ {
+ "params": [self._delta_rot],
+ "lr": C(training_args.delta_rot_lr, 0, 0),
+ "name": "delta_rot",
+ },
+ )
68
+
69
+ @property
70
+ def get_rotation(self):
71
+ return self.rotation_activation(
72
+ self._rotation + self._delta_rot[self.time_index]
73
+ )
74
+
75
+ @property
76
+ def get_xyz(self):
77
+ return self._xyz + self._delta_xyz[self.time_index]
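+ # Usage sketch (assumed workflow): once _delta_xyz / _delta_rot are allocated per
+ # frame, set `model.time_index = t` before rendering frame t so the properties above
+ # offset positions and rotations for that timestep.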
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/geometry/gaussian_io.py ADDED
@@ -0,0 +1,327 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+ import os
12
+ import random
13
+ import sys
14
+ from dataclasses import dataclass, field
15
+ from datetime import datetime
16
+ from typing import NamedTuple
17
+
18
+ import mcubes
19
+ import numpy as np
20
+ import threestudio
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from plyfile import PlyData, PlyElement
25
+ from simple_knn._C import distCUDA2
26
+ from threestudio.models.geometry.base import BaseGeometry
27
+ from threestudio.models.mesh import Mesh
28
+ from threestudio.utils.typing import *
29
+ from tqdm import tqdm
30
+
31
+ from .mesh_utils import *
32
+
33
+
34
+ def gaussian_3d_coeff(xyzs, covs):
35
+ # xyzs: [N, 3]
36
+ # covs: [N, 6]
37
+ x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2]
38
+ a, b, c, d, e, f = (
39
+ covs[:, 0],
40
+ covs[:, 1],
41
+ covs[:, 2],
42
+ covs[:, 3],
43
+ covs[:, 4],
44
+ covs[:, 5],
45
+ )
46
+
47
+ # eps must be small enough !!!
48
+ inv_det = 1 / (
49
+ a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24
50
+ )
51
+ inv_a = (d * f - e**2) * inv_det
52
+ inv_b = (e * c - b * f) * inv_det
53
+ inv_c = (e * b - c * d) * inv_det
54
+ inv_d = (a * f - c**2) * inv_det
55
+ inv_e = (b * c - e * a) * inv_det
56
+ inv_f = (a * d - b**2) * inv_det
57
+
58
+ power = (
59
+ -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f)
60
+ - x * y * inv_b
61
+ - x * z * inv_c
62
+ - y * z * inv_e
63
+ )
64
+
65
+ power[power > 0] = -1e10 # abnormal values... make weights 0
66
+
67
+ return torch.exp(power)
68
+
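+ # gaussian_3d_coeff evaluates the unnormalized density exp(-0.5 * d^T Sigma^{-1} d)
+ # for each offset d = xyzs[i] against the packed symmetric covariance
+ # covs[i] = (a, b, c, d, e, f) = (S00, S01, S02, S11, S12, S22), inverting Sigma in
+ # closed form via its cofactors. Minimal sketch, assuming unit isotropic covariances:
+ #   w = gaussian_3d_coeff(torch.zeros(4, 3), torch.tensor([[1., 0, 0, 1, 0, 1]] * 4))
+ #   # -> tensor of ones, since all offsets are zero.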
69
+
70
+ @threestudio.register("gaussian-splatting-io")
71
+ class GaussianIO:
72
+ def construct_list_of_attributes(self):
73
+ l = ["x", "y", "z", "nx", "ny", "nz"]
74
+ # All channels except the 3 DC
75
+ for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]):
76
+ l.append("f_dc_{}".format(i))
77
+ for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]):
78
+ l.append("f_rest_{}".format(i))
79
+ l.append("opacity")
80
+ for i in range(self._scaling.shape[1]):
81
+ l.append("scale_{}".format(i))
82
+ for i in range(self._rotation.shape[1]):
83
+ l.append("rot_{}".format(i))
84
+ return l
85
+
86
+ def save_ply(self, path):
87
+ xyz = self._xyz.detach().cpu().numpy()
88
+ normals = np.zeros_like(xyz)
89
+ f_dc = (
90
+ self._features_dc.detach()
91
+ .transpose(1, 2)
92
+ .flatten(start_dim=1)
93
+ .contiguous()
94
+ .cpu()
95
+ .numpy()
96
+ )
97
+ f_rest = (
98
+ self._features_rest.detach()
99
+ .transpose(1, 2)
100
+ .flatten(start_dim=1)
101
+ .contiguous()
102
+ .cpu()
103
+ .numpy()
104
+ )
105
+ opacities = self._opacity.detach().cpu().numpy()
106
+ scale = self._scaling.detach().cpu().numpy()
107
+ rotation = self._rotation.detach().cpu().numpy()
108
+
109
+ dtype_full = [
110
+ (attribute, "f4") for attribute in self.construct_list_of_attributes()
111
+ ]
112
+
113
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
114
+ attributes = np.concatenate(
115
+ (xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1
116
+ )
117
+ elements[:] = list(map(tuple, attributes))
118
+ el = PlyElement.describe(elements, "vertex")
119
+ PlyData([el]).write(path)
120
+
121
+ def load_ply(self, path):
122
+ plydata = PlyData.read(path)
123
+
124
+ xyz = np.stack(
125
+ (
126
+ np.asarray(plydata.elements[0]["x"]),
127
+ np.asarray(plydata.elements[0]["y"]),
128
+ np.asarray(plydata.elements[0]["z"]),
129
+ ),
130
+ axis=1,
131
+ )
132
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
133
+
134
+ features_dc = np.zeros((xyz.shape[0], 3, 1))
135
+ features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
136
+ features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
137
+ features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
138
+
139
+ if self.max_sh_degree > 0:
140
+ extra_f_names = [
141
+ p.name
142
+ for p in plydata.elements[0].properties
143
+ if p.name.startswith("f_rest_")
144
+ ]
145
+ extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1]))
146
+ assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3
147
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
148
+ for idx, attr_name in enumerate(extra_f_names):
149
+ features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
150
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
151
+ features_extra = features_extra.reshape(
152
+ (features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)
153
+ )
154
+
155
+ scale_names = [
156
+ p.name
157
+ for p in plydata.elements[0].properties
158
+ if p.name.startswith("scale_")
159
+ ]
160
+ scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1]))
161
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
162
+ for idx, attr_name in enumerate(scale_names):
163
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
164
+
165
+ rot_names = [
166
+ p.name for p in plydata.elements[0].properties if p.name.startswith("rot")
167
+ ]
168
+ rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
169
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
170
+ for idx, attr_name in enumerate(rot_names):
171
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
172
+
173
+ self._xyz = nn.Parameter(
174
+ torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)
175
+ )
176
+ self._features_dc = nn.Parameter(
177
+ torch.tensor(features_dc, dtype=torch.float, device="cuda")
178
+ .transpose(1, 2)
179
+ .contiguous()
180
+ .requires_grad_(True)
181
+ )
182
+ if self.max_sh_degree > 0:
183
+ self._features_rest = nn.Parameter(
184
+ torch.tensor(features_extra, dtype=torch.float, device="cuda")
185
+ .transpose(1, 2)
186
+ .contiguous()
187
+ .requires_grad_(True)
188
+ )
189
+ else:
190
+ self._features_rest = nn.Parameter(
191
+ torch.tensor(features_dc, dtype=torch.float, device="cuda")[:, :, 1:]
192
+ .transpose(1, 2)
193
+ .contiguous()
194
+ .requires_grad_(True)
195
+ )
196
+ self._opacity = nn.Parameter(
197
+ torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(
198
+ True
199
+ )
200
+ )
201
+ self._scaling = nn.Parameter(
202
+ torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)
203
+ )
204
+ self._rotation = nn.Parameter(
205
+ torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)
206
+ )
207
+ self.max_radii2D = torch.zeros((self._xyz.shape[0]), device="cuda")
208
+ self.active_sh_degree = self.max_sh_degree
209
+
210
+ @torch.no_grad()
211
+ def extract_fields(self, resolution=128, num_blocks=16, relax_ratio=1.5):
212
+ # resolution: resolution of field
213
+
214
+ block_size = 2 / num_blocks
215
+
216
+ assert resolution % num_blocks == 0  # each axis must split evenly into blocks
217
+ split_size = resolution // num_blocks
218
+
219
+ opacities = self.get_opacity
220
+
221
+ # pre-filter low opacity gaussians to save computation
222
+ mask = (opacities > 0.005).squeeze(1)
223
+
224
+ opacities = opacities[mask]
225
+ xyzs = self.get_xyz[mask]
226
+ stds = self.get_scaling[mask]
227
+
228
+ # normalize to ~ [-1, 1]
229
+ mn, mx = xyzs.amin(0), xyzs.amax(0)
230
+ self.center = (mn + mx) / 2
231
+ self.scale = 1.8 / (mx - mn).amax().item()
232
+
233
+ xyzs = (xyzs - self.center) * self.scale
234
+ stds = stds * self.scale
235
+
236
+ covs = self.covariance_activation(stds, 1, self._rotation[mask])
237
+
238
+ # tile
239
+ device = opacities.device
240
+ occ = torch.zeros([resolution] * 3, dtype=torch.float32, device=device)
241
+
242
+ X = torch.linspace(-1, 1, resolution).split(split_size)
243
+ Y = torch.linspace(-1, 1, resolution).split(split_size)
244
+ Z = torch.linspace(-1, 1, resolution).split(split_size)
245
+
246
+ # loop over blocks (assumes the max Gaussian extent is smaller than relax_ratio * block_size)
247
+ for xi, xs in tqdm(enumerate(X)):
248
+ for yi, ys in enumerate(Y):
249
+ for zi, zs in enumerate(Z):
250
+ xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing="ij")  # explicit indexing avoids the torch deprecation warning
251
+ # sample points [M, 3]
252
+ pts = torch.cat(
253
+ [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
254
+ dim=-1,
255
+ ).to(device)
256
+ # in-tile gaussians mask
257
+ vmin, vmax = pts.amin(0), pts.amax(0)
258
+ vmin -= block_size * relax_ratio
259
+ vmax += block_size * relax_ratio
260
+ mask = (xyzs < vmax).all(-1) & (xyzs > vmin).all(-1)
261
+ # if hit no gaussian, continue to next block
262
+ if not mask.any():
263
+ continue
264
+ mask_xyzs = xyzs[mask] # [L, 3]
265
+ mask_covs = covs[mask] # [L, 6]
266
+ mask_opas = opacities[mask].view(1, -1) # [L, 1] --> [1, L]
267
+
268
+ # query per point-gaussian pair.
269
+ g_pts = pts.unsqueeze(1).repeat(
270
+ 1, mask_covs.shape[0], 1
271
+ ) - mask_xyzs.unsqueeze(
272
+ 0
273
+ ) # [M, L, 3]
274
+ g_covs = mask_covs.unsqueeze(0).repeat(
275
+ pts.shape[0], 1, 1
276
+ ) # [M, L, 6]
277
+
278
+ # batch on gaussian to avoid OOM
279
+ batch_g = 1024
280
+ val = 0
281
+ for start in range(0, g_covs.shape[1], batch_g):
282
+ end = min(start + batch_g, g_covs.shape[1])
283
+ w = gaussian_3d_coeff(
284
+ g_pts[:, start:end].reshape(-1, 3),
285
+ g_covs[:, start:end].reshape(-1, 6),
286
+ ).reshape(
287
+ pts.shape[0], -1
288
+ ) # [M, l]
289
+ val += (mask_opas[:, start:end] * w).sum(-1)
290
+
291
+ # kiui.lo(val, mask_opas, w)
292
+
293
+ occ[
294
+ xi * split_size : xi * split_size + len(xs),
295
+ yi * split_size : yi * split_size + len(ys),
296
+ zi * split_size : zi * split_size + len(zs),
297
+ ] = val.reshape(len(xs), len(ys), len(zs))
298
+
299
+ # kiui.lo(occ, verbose=1)
300
+
301
+ return occ
302
+
303
+ def extract_mesh(self, density_thresh=0.8, resolution=128, decimate_target=1e5):
304
+ occ = self.extract_fields(resolution).detach().cpu().numpy()
305
+
306
+ vertices, triangles = mcubes.marching_cubes(occ, density_thresh)
307
+ vertices = vertices / (resolution - 1.0) * 2 - 1
308
+
309
+ # transform back to the original space
310
+ vertices = vertices / self.scale + self.center.detach().cpu().numpy()
311
+
312
+ vertices, triangles = clean_mesh(
313
+ vertices, triangles, remesh=True, remesh_size=0.015
314
+ )
315
+ if decimate_target > 0 and triangles.shape[0] > decimate_target:
316
+ vertices, triangles = decimate_mesh(vertices, triangles, decimate_target)
317
+
318
+ v = torch.from_numpy(vertices.astype(np.float32)).contiguous().cuda()
319
+ f = torch.from_numpy(triangles.astype(np.int32)).contiguous().cuda()
320
+
321
+ threestudio.info(
322
+ f"marching cubes result: {v.shape} ({v.min().item()}-{v.max().item()}), {f.shape}"
323
+ )
324
+
325
+ mesh = Mesh(v_pos=v, t_pos_idx=f)
326
+
327
+ return mesh
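+ # Usage sketch (hypothetical, assumes a trained model instance `gs`):
+ #   mesh = gs.extract_mesh(density_thresh=0.8, resolution=128)
+ # extract_fields accumulates opacity-weighted densities on a grid, then marching
+ # cubes plus cleaning/decimation turn the level set into a triangle mesh.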
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/geometry/mesh_utils.py ADDED
@@ -0,0 +1,150 @@
1
+ import numpy as np
2
+ import threestudio
3
+
4
+
5
+ def poisson_mesh_reconstruction(points, normals=None):
6
+ # points/normals: [N, 3] np.ndarray
7
+
8
+ import open3d as o3d
9
+
10
+ pcd = o3d.geometry.PointCloud()
11
+ pcd.points = o3d.utility.Vector3dVector(points)
12
+
13
+ # outlier removal
14
+ pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)
15
+
16
+ # normals
17
+ if normals is None:
18
+ pcd.estimate_normals()
19
+ else:
20
+ pcd.normals = o3d.utility.Vector3dVector(normals[ind])
21
+
22
+ # visualize
23
+ o3d.visualization.draw_geometries([pcd], point_show_normal=False)
24
+
25
+ mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
26
+ pcd, depth=9
27
+ )
28
+ vertices_to_remove = densities < np.quantile(densities, 0.1)
29
+ mesh.remove_vertices_by_mask(vertices_to_remove)
30
+
31
+ # visualize
32
+ o3d.visualization.draw_geometries([mesh])
33
+
34
+ vertices = np.asarray(mesh.vertices)
35
+ triangles = np.asarray(mesh.triangles)
36
+
37
+ print(
38
+ f"[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}"
39
+ )
40
+
41
+ return vertices, triangles
42
+
43
+
44
+ def decimate_mesh(
45
+ verts, faces, target, backend="pymeshlab", remesh=False, optimalplacement=True
46
+ ):
47
+ # optimalplacement: default is True, but for flat mesh must turn False to prevent spike artifect.
48
+
49
+ _ori_vert_shape = verts.shape
50
+ _ori_face_shape = faces.shape
51
+
52
+ if backend == "pyfqmr":
53
+ import pyfqmr
54
+
55
+ solver = pyfqmr.Simplify()
56
+ solver.setMesh(verts, faces)
57
+ solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)
58
+ verts, faces, normals = solver.getMesh()
59
+ else:
60
+ import pymeshlab as pml
61
+
62
+ m = pml.Mesh(verts, faces)
63
+ ms = pml.MeshSet()
64
+ ms.add_mesh(m, "mesh") # will copy!
65
+
66
+ # filters
67
+ # ms.meshing_decimation_clustering(threshold=pml.PercentageValue(1))
68
+ ms.meshing_decimation_quadric_edge_collapse(
69
+ targetfacenum=int(target), optimalplacement=optimalplacement
70
+ )
71
+
72
+ if remesh:
73
+ # ms.apply_coord_taubin_smoothing()
74
+ ms.meshing_isotropic_explicit_remeshing(
75
+ iterations=3, targetlen=pml.PercentageValue(1)
76
+ )
77
+
78
+ # extract mesh
79
+ m = ms.current_mesh()
80
+ verts = m.vertex_matrix()
81
+ faces = m.face_matrix()
82
+
83
+ print(
84
+ f"[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}"
85
+ )
86
+
87
+ return verts, faces
88
+
89
+
90
+ def clean_mesh(
91
+ verts,
92
+ faces,
93
+ v_pct=1,
94
+ min_f=64,
95
+ min_d=20,
96
+ repair=True,
97
+ remesh=True,
98
+ remesh_size=0.01,
99
+ ):
100
+ # verts: [N, 3]
101
+ # faces: [N, 3]
102
+ import pymeshlab as pml
103
+
104
+ _ori_vert_shape = verts.shape
105
+ _ori_face_shape = faces.shape
106
+
107
+ m = pml.Mesh(verts, faces)
108
+ ms = pml.MeshSet()
109
+ ms.add_mesh(m, "mesh") # will copy!
110
+
111
+ # filters
112
+ ms.meshing_remove_unreferenced_vertices() # verts not refed by any faces
113
+
114
+ if v_pct > 0:
115
+ ms.meshing_merge_close_vertices(
116
+ threshold=pml.PercentageValue(v_pct)
117
+ ) # 1/10000 of bounding box diagonal
118
+
119
+ ms.meshing_remove_duplicate_faces() # faces defined by the same verts
120
+ ms.meshing_remove_null_faces() # faces with area == 0
121
+
122
+ if min_d > 0:
123
+ ms.meshing_remove_connected_component_by_diameter(
124
+ mincomponentdiag=pml.PercentageValue(min_d)
125
+ )
126
+
127
+ if min_f > 0:
128
+ ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)
129
+
130
+ if repair:
131
+ # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)
132
+ ms.meshing_repair_non_manifold_edges(method=0)
133
+ ms.meshing_repair_non_manifold_vertices(vertdispratio=0)
134
+
135
+ if remesh:
136
+ # ms.apply_coord_taubin_smoothing()
137
+ ms.meshing_isotropic_explicit_remeshing(
138
+ iterations=3, targetlen=pml.PureValue(remesh_size)
139
+ )
140
+
141
+ # extract mesh
142
+ m = ms.current_mesh()
143
+ verts = m.vertex_matrix()
144
+ faces = m.face_matrix()
145
+
146
+ print(
147
+ f"[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}"
148
+ )
149
+
150
+ return verts, faces
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/material/gaussian_material.py ADDED
@@ -0,0 +1,116 @@
+ import random
2
+ from dataclasses import dataclass, field
3
+
4
+ import threestudio
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from threestudio.models.materials.base import BaseMaterial
9
+ from threestudio.utils.ops import dot, get_activation
10
+ from threestudio.utils.typing import *
11
+
12
+
13
+ @threestudio.register("gaussian-diffuse-with-point-light-material")
14
+ class GaussianDiffuseWithPointLightMaterial(BaseMaterial):
15
+ @dataclass
16
+ class Config(BaseMaterial.Config):
17
+ ambient_light_color: Tuple[float, float, float] = (0.1, 0.1, 0.1)
18
+ diffuse_light_color: Tuple[float, float, float] = (0.9, 0.9, 0.9)
19
+ ambient_only_steps: int = 1000
20
+ diffuse_prob: float = 0.75
21
+ textureless_prob: float = 0.5
22
+ soft_shading: bool = False
23
+
24
+ cfg: Config
25
+
26
+ def configure(self) -> None:
27
+ self.requires_normal = True
28
+
29
+ self.ambient_light_color: Float[Tensor, "3"]
30
+ self.register_buffer(
31
+ "ambient_light_color",
32
+ torch.as_tensor(self.cfg.ambient_light_color, dtype=torch.float32),
33
+ )
34
+ self.diffuse_light_color: Float[Tensor, "3"]
35
+ self.register_buffer(
36
+ "diffuse_light_color",
37
+ torch.as_tensor(self.cfg.diffuse_light_color, dtype=torch.float32),
38
+ )
39
+ self.ambient_only = False
40
+
41
+ def forward(
42
+ self,
43
+ positions: Float[Tensor, "B ... 3"],
44
+ shading_normal: Float[Tensor, "B ... 3"],
45
+ light_positions: Float[Tensor, "B ... 3"],
46
+ albedo: Float[Tensor, "B ... 3"],
47
+ ambient_ratio: Optional[float] = None,
48
+ shading: Optional[str] = None,
49
+ **kwargs,
50
+ ) -> Float[Tensor, "B ... 3"]:
51
+ if ambient_ratio is not None:
52
+ # if ambient ratio is specified, use it
53
+ diffuse_light_color = (1 - ambient_ratio) * torch.ones_like(
54
+ self.diffuse_light_color
55
+ )
56
+ ambient_light_color = ambient_ratio * torch.ones_like(
57
+ self.ambient_light_color
58
+ )
59
+ elif self.training and self.cfg.soft_shading:
60
+ # otherwise if in training and soft shading is enabled, random a ambient ratio
61
+ diffuse_light_color = torch.full_like(
62
+ self.diffuse_light_color, random.random()
63
+ )
64
+ ambient_light_color = 1.0 - diffuse_light_color
65
+ else:
66
+ # otherwise use the default fixed values
67
+ diffuse_light_color = self.diffuse_light_color
68
+ ambient_light_color = self.ambient_light_color
69
+
70
+ light_directions: Float[Tensor, "B ... 3"] = F.normalize(
71
+ light_positions - positions, dim=-1
72
+ )
73
+ diffuse_light: Float[Tensor, "B ... 3"] = (
74
+ dot(shading_normal, light_directions).clamp(min=0.0) * diffuse_light_color
75
+ )
76
+ textureless_color = diffuse_light + ambient_light_color
77
+ # clamp albedo to [0, 1] to compute shading
78
+ color = albedo.clamp(0.0, 1.0) * textureless_color
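+ # Lambertian shading above: color = albedo * (ambient + diffuse_color * max(n.l, 0)),
+ # with n the shading normal and l the unit direction toward the point light.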
79
+
80
+ if shading is None:
81
+ if self.training:
82
+ # adopt the same type of augmentation for the whole batch
83
+ if self.ambient_only or random.random() > self.cfg.diffuse_prob:
84
+ shading = "albedo"
85
+ elif random.random() < self.cfg.textureless_prob:
86
+ shading = "textureless"
87
+ else:
88
+ shading = "diffuse"
89
+ else:
90
+ if self.ambient_only:
91
+ shading = "albedo"
92
+ else:
93
+ # return shaded color by default in evaluation
94
+ shading = "diffuse"
95
+
96
+ # multiply by 0 to prevent checking for unused parameters in DDP
97
+ if shading == "albedo":
98
+ return albedo + textureless_color * 0
99
+ elif shading == "textureless":
100
+ return albedo * 0 + textureless_color
101
+ elif shading == "diffuse":
102
+ return color
103
+ else:
104
+ raise ValueError(f"Unknown shading type {shading}")
105
+
106
+ def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
107
+ if global_step < self.cfg.ambient_only_steps:
108
+ self.ambient_only = True
109
+ else:
110
+ self.ambient_only = False
111
+
112
+ def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]:
113
+ albedo = get_activation(self.cfg.albedo_activation)(features[..., :3]).clamp(
114
+ 0.0, 1.0
115
+ )
116
+ return {"albedo": albedo}
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/renderer/diff_gaussian_rasterizer.py ADDED
@@ -0,0 +1,151 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import numpy as np
5
+ import threestudio
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from diff_gaussian_rasterization import (
9
+ GaussianRasterizationSettings,
10
+ GaussianRasterizer,
11
+ )
12
+ from threestudio.models.background.base import BaseBackground
13
+ from threestudio.models.geometry.base import BaseGeometry
14
+ from threestudio.models.materials.base import BaseMaterial
15
+ from threestudio.models.renderers.base import Rasterizer
16
+ from threestudio.utils.typing import *
17
+
18
+ from .gaussian_batch_renderer import GaussianBatchRenderer
19
+
20
+
21
+ @threestudio.register("diff-gaussian-rasterizer")
22
+ class DiffGaussian(Rasterizer, GaussianBatchRenderer):
23
+ @dataclass
24
+ class Config(Rasterizer.Config):
25
+ debug: bool = False
26
+ invert_bg_prob: float = 1.0
27
+ back_ground_color: Tuple[float, float, float] = (1, 1, 1)
28
+
29
+ cfg: Config
30
+
31
+ def configure(
32
+ self,
33
+ geometry: BaseGeometry,
34
+ material: BaseMaterial,
35
+ background: BaseBackground,
36
+ ) -> None:
37
+ threestudio.info(
38
+ "[Note] Gaussian Splatting doesn't support material and background now."
39
+ )
40
+ super().configure(geometry, material, background)
41
+ self.background_tensor = torch.tensor(
42
+ self.cfg.back_ground_color, dtype=torch.float32, device="cuda"
43
+ )
44
+
45
+ def forward(
46
+ self,
47
+ viewpoint_camera,
48
+ bg_color: torch.Tensor,
49
+ scaling_modifier=1.0,
50
+ override_color=None,
51
+ **kwargs
52
+ ) -> Dict[str, Any]:
53
+ """
54
+ Render the scene.
55
+
56
+ Background tensor (bg_color) must be on GPU!
57
+ """
58
+
59
+ if self.training:
60
+ invert_bg_color = np.random.rand() > self.cfg.invert_bg_prob
61
+ else:
62
+ invert_bg_color = True
63
+
64
+ bg_color = bg_color if not invert_bg_color else (1.0 - bg_color)
65
+
66
+ pc = self.geometry
67
+ # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
68
+ screenspace_points = (
69
+ torch.zeros_like(
70
+ pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
71
+ )
72
+ + 0
73
+ )
74
+ try:
75
+ screenspace_points.retain_grad()
76
+ except Exception:  # retain_grad is a no-op safeguard when the tensor is not a leaf
77
+ pass
78
+
79
+ # Set up rasterization configuration
80
+ tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
81
+ tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
82
+
83
+ raster_settings = GaussianRasterizationSettings(
84
+ image_height=int(viewpoint_camera.image_height),
85
+ image_width=int(viewpoint_camera.image_width),
86
+ tanfovx=tanfovx,
87
+ tanfovy=tanfovy,
88
+ bg=bg_color,
89
+ scale_modifier=scaling_modifier,
90
+ viewmatrix=viewpoint_camera.world_view_transform,
91
+ projmatrix=viewpoint_camera.full_proj_transform,
92
+ sh_degree=pc.active_sh_degree,
93
+ campos=viewpoint_camera.camera_center,
94
+ prefiltered=False,
95
+ debug=False,
96
+ include_feature=True
97
+ )
98
+
99
+ rasterizer = GaussianRasterizer(raster_settings=raster_settings)
100
+
101
+ means3D = pc.get_xyz
102
+ means2D = screenspace_points
103
+ opacity = pc.get_opacity
104
+
105
+ # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
106
+ # scaling / rotation by the rasterizer.
107
+ scales = None
108
+ rotations = None
109
+ cov3D_precomp = None
110
+ scales = pc.get_scaling
111
+ rotations = pc.get_rotation
112
+
113
+ # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
114
+ # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
115
+ shs = None
116
+ colors_precomp = None
117
+ if override_color is None:
118
+ shs = pc.get_features
119
+ else:
120
+ colors_precomp = override_color
121
+
122
+ language_feature_precomp = pc.get_language_feature
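+ # L2-normalize the per-Gaussian language features before rasterization; the 1e-9
+ # epsilon guards against division by zero for all-zero features.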
123
+ language_feature_precomp = language_feature_precomp / (language_feature_precomp.norm(dim=-1, keepdim=True) + 1e-9)
124
+
125
+ # Rasterize visible Gaussians to image, obtain their radii (on screen).
126
+ result_list = rasterizer(
127
+ means3D=means3D,
128
+ means2D=means2D,
129
+ shs=shs,
130
+ colors_precomp=colors_precomp,
131
+ language_feature_precomp = language_feature_precomp,
132
+ opacities=opacity,
133
+ scales=scales,
134
+ rotations=rotations,
135
+ cov3D_precomp=cov3D_precomp,
136
+ )
137
+ rendered_image, rendered_feature, radii = result_list[0], result_list[1], result_list[2]
138
+
139
+ # Retain gradients of the 2D (screen-space) means for batch dim
140
+ if self.training:
141
+ screenspace_points.retain_grad()
142
+
143
+ # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
144
+ # They will be excluded from value updates used in the splitting criteria.
145
+ return {
146
+ "render": rendered_image.clamp(0, 1),
147
+ "lang": rendered_feature,
148
+ "viewspace_points": screenspace_points,
149
+ "visibility_filter": radii > 0,
150
+ "radii": radii,
151
+ }
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/renderer/diff_gaussian_rasterizer_advanced.py ADDED
@@ -0,0 +1,152 @@
+ import math
+ from dataclasses import dataclass
+
+ import numpy as np
+ import threestudio
+ import torch
+ import torch.nn.functional as F
+ from diff_gaussian_rasterization import (
+     GaussianRasterizationSettings,
+     GaussianRasterizer,
+ )
+ from threestudio.models.background.base import BaseBackground
+ from threestudio.models.geometry.base import BaseGeometry
+ from threestudio.models.materials.base import BaseMaterial
+ from threestudio.models.renderers.base import Rasterizer
+ from threestudio.utils.typing import *
+
+ from .gaussian_batch_renderer import GaussianBatchRenderer
+
+
+ @threestudio.register("diff-gaussian-rasterizer-advanced")
+ class DiffGaussian(Rasterizer, GaussianBatchRenderer):
+     @dataclass
+     class Config(Rasterizer.Config):
+         debug: bool = False
+         invert_bg_prob: float = 1.0
+         back_ground_color: Tuple[float, float, float] = (1, 1, 1)
+
+     cfg: Config
+
+     def configure(
+         self,
+         geometry: BaseGeometry,
+         material: BaseMaterial,
+         background: BaseBackground,
+     ) -> None:
+         threestudio.info(
+             "[Note] Gaussian Splatting doesn't support material and background now."
+         )
+         super().configure(geometry, material, background)
+         self.background_tensor = torch.tensor(
+             self.cfg.back_ground_color, dtype=torch.float32, device="cuda"
+         )
+
+     def forward(
+         self,
+         viewpoint_camera,
+         bg_color: torch.Tensor,
+         scaling_modifier=1.0,
+         override_color=None,
+         **kwargs
+     ) -> Dict[str, Any]:
+         """
+         Render the scene.
+
+         Background tensor (bg_color) must be on GPU!
+         """
+
+         if self.training:
+             invert_bg_color = np.random.rand() > self.cfg.invert_bg_prob
+         else:
+             invert_bg_color = True
+
+         bg_color = bg_color if not invert_bg_color else (1.0 - bg_color)
+
+         pc = self.geometry
+         # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
+         screenspace_points = (
+             torch.zeros_like(
+                 pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
+             )
+             + 0
+         )
+         try:
+             screenspace_points.retain_grad()
+         except:
+             pass
+
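+         # include_feature=True below assumes the bundled fork of
+         # diff_gaussian_rasterization that also alpha-composites the per-Gaussian
+         # language embeddings into a feature map next to color, depth, and alpha.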
79
+         # Set up rasterization configuration
+         tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+         tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+
+         raster_settings = GaussianRasterizationSettings(
+             image_height=int(viewpoint_camera.image_height),
+             image_width=int(viewpoint_camera.image_width),
+             tanfovx=tanfovx,
+             tanfovy=tanfovy,
+             bg=bg_color,
+             scale_modifier=scaling_modifier,
+             viewmatrix=viewpoint_camera.world_view_transform,
+             projmatrix=viewpoint_camera.full_proj_transform,
+             sh_degree=pc.active_sh_degree,
+             campos=viewpoint_camera.camera_center,
+             prefiltered=False,
+             debug=False,
+             include_feature=True,
+         )
+
+         rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+         means3D = pc.get_xyz
+         means2D = screenspace_points
+         opacity = pc.get_opacity
+
+         # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
+         # scaling / rotation by the rasterizer.
+         scales = None
+         rotations = None
+         cov3D_precomp = None
+         scales = pc.get_scaling
+         rotations = pc.get_rotation
+
+         # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+         # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
+         shs = None
+         colors_precomp = None
+         if override_color is None:
+             shs = pc.get_features
+         else:
+             colors_precomp = override_color
+
+         language_feature_precomp = pc.get_language_feature
+         language_feature_precomp = language_feature_precomp / (
+             language_feature_precomp.norm(dim=-1, keepdim=True) + 1e-9
+         )
+
125
+         # Rasterize visible Gaussians to image, obtain their radii (on screen).
+         rendered_image, rendered_feature, radii, rendered_depth, rendered_alpha = rasterizer(
+             means3D=means3D,
+             means2D=means2D,
+             shs=shs,
+             colors_precomp=colors_precomp,
+             language_feature_precomp=language_feature_precomp,
+             opacities=opacity,
+             scales=scales,
+             rotations=rotations,
+             cov3D_precomp=cov3D_precomp,
+         )
+
+         # Retain gradients of the 2D (screen-space) means for batch dim
+         if self.training:
+             screenspace_points.retain_grad()
+
+         # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
+         # They will be excluded from value updates used in the splitting criteria.
+         return {
+             "render": rendered_image.clamp(0, 1),
+             "lang": rendered_feature,
+             "depth": rendered_depth,
+             "mask": rendered_alpha,
+             "viewspace_points": screenspace_points,
+             "visibility_filter": radii > 0,
+             "radii": radii,
+         }
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/renderer/diff_gaussian_rasterizer_background.py ADDED
@@ -0,0 +1,145 @@
+ import math
+ from dataclasses import dataclass
+
+ import numpy as np
+ import threestudio
+ import torch
+ import torch.nn.functional as F
+ from diff_gaussian_rasterization import (
+     GaussianRasterizationSettings,
+     GaussianRasterizer,
+ )
+ from threestudio.models.background.base import BaseBackground
+ from threestudio.models.geometry.base import BaseGeometry
+ from threestudio.models.materials.base import BaseMaterial
+ from threestudio.models.renderers.base import Rasterizer
+ from threestudio.utils.typing import *
+
+ from .gaussian_batch_renderer import GaussianBatchRenderer
+
+
+ @threestudio.register("diff-gaussian-rasterizer-background")
+ class DiffGaussian(Rasterizer, GaussianBatchRenderer):
+     @dataclass
+     class Config(Rasterizer.Config):
+         debug: bool = False
+         back_ground_color: Tuple[float, float, float] = (1, 1, 1)
+
+     cfg: Config
+
+     def configure(
+         self,
+         geometry: BaseGeometry,
+         material: BaseMaterial,
+         background: BaseBackground,
+     ) -> None:
+         threestudio.info(
+             "[Note] diff-gaussian-rasterizer-background doesn't support material."
+         )
+         super().configure(geometry, material, background)
+         self.background_tensor = torch.tensor(
+             self.cfg.back_ground_color, dtype=torch.float32, device="cuda"
+         )
+
+     def forward(
+         self,
+         viewpoint_camera,
+         bg_color: torch.Tensor,
+         scaling_modifier=1.0,
+         override_color=None,
+         **kwargs
+     ) -> Dict[str, Any]:
+         """
+         Render the scene.
+
+         Background tensor (bg_color) must be on GPU!
+         """
+         # use neural background
+         bg_color = bg_color * 0
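+         # Zeroing bg_color renders the splats over black, so rendered_image holds
+         # only the alpha-weighted foreground; the learned neural background is
+         # composited back in after rasterization (see below).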
59
+
+         pc = self.geometry
+         # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
+         screenspace_points = (
+             torch.zeros_like(
+                 pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
+             )
+             + 0
+         )
+         try:
+             screenspace_points.retain_grad()
+         except:
+             pass
+
+         # Set up rasterization configuration
+         tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+         tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+
+         raster_settings = GaussianRasterizationSettings(
+             image_height=int(viewpoint_camera.image_height),
+             image_width=int(viewpoint_camera.image_width),
+             tanfovx=tanfovx,
+             tanfovy=tanfovy,
+             bg=bg_color,
+             scale_modifier=scaling_modifier,
+             viewmatrix=viewpoint_camera.world_view_transform,
+             projmatrix=viewpoint_camera.full_proj_transform,
+             sh_degree=pc.active_sh_degree,
+             campos=viewpoint_camera.camera_center,
+             prefiltered=False,
+             debug=False,
+         )
+
+         rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+         means3D = pc.get_xyz
+         means2D = screenspace_points
+         opacity = pc.get_opacity
+
+         # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
+         # scaling / rotation by the rasterizer.
+         scales = None
+         rotations = None
+         cov3D_precomp = None
+         scales = pc.get_scaling
+         rotations = pc.get_rotation
+
+         # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+         # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
+         shs = None
+         colors_precomp = None
+         if override_color is None:
+             shs = pc.get_features
+         else:
+             colors_precomp = override_color
+
+         # Rasterize visible Gaussians to image, obtain their radii (on screen).
+         rays_d = kwargs["rays_d"][kwargs["batch_idx"]]
+         comp_rgb_bg = self.background(dirs=rays_d.unsqueeze(0))
+
+         rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
+             means3D=means3D,
+             means2D=means2D,
+             shs=shs,
+             colors_precomp=colors_precomp,
+             opacities=opacity,
+             scales=scales,
+             rotations=rotations,
+             cov3D_precomp=cov3D_precomp,
+         )
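+         # Standard "over" compositing: the neural background fills in wherever the
+         # splats leave transmittance, weighted by (1 - rendered_alpha).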
129
+         _, H, W = rendered_image.shape
+         rendered_image = rendered_image + (1 - rendered_alpha) * comp_rgb_bg.reshape(
+             H, W, 3
+         ).permute(2, 0, 1)
+
+         # Retain gradients of the 2D (screen-space) means for batch dim
+         if self.training:
+             screenspace_points.retain_grad()
+
+         # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
+         # They will be excluded from value updates used in the splitting criteria.
+         return {
+             "render": rendered_image.clamp(0, 1),
+             "viewspace_points": screenspace_points,
+             "visibility_filter": radii > 0,
+             "radii": radii,
+         }
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/renderer/diff_gaussian_rasterizer_shading.py ADDED
@@ -0,0 +1,226 @@
+ import math
+ from dataclasses import dataclass
+
+ import numpy as np
+ import threestudio
+ import torch
+ import torch.nn.functional as F
+ from diff_gaussian_rasterization import (
+     GaussianRasterizationSettings,
+     GaussianRasterizer,
+ )
+ from threestudio.models.background.base import BaseBackground
+ from threestudio.models.geometry.base import BaseGeometry
+ from threestudio.models.materials.base import BaseMaterial
+ from threestudio.models.renderers.base import Rasterizer
+ from threestudio.utils.typing import *
+
+ from ..material.gaussian_material import GaussianDiffuseWithPointLightMaterial
+ from .gaussian_batch_renderer import GaussianBatchRenderer
+
+
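+ # Depth2Normal converts a (B, 3, H, W) map of world-space positions into normals:
+ # the two 3x3 kernels below are central-difference filters along x and y, and the
+ # normal is the (negated) cross product of the resulting tangent vectors.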
22
+ class Depth2Normal(torch.nn.Module):
+     def __init__(self, *args, **kwargs) -> None:
+         super().__init__(*args, **kwargs)
+         self.delzdelxkernel = torch.tensor(
+             [
+                 [0.00000, 0.00000, 0.00000],
+                 [-1.00000, 0.00000, 1.00000],
+                 [0.00000, 0.00000, 0.00000],
+             ]
+         )
+         self.delzdelykernel = torch.tensor(
+             [
+                 [0.00000, -1.00000, 0.00000],
+                 [0.00000, 0.00000, 0.00000],
+                 [0.00000, 1.00000, 0.00000],
+             ]
+         )
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         delzdelxkernel = self.delzdelxkernel.view(1, 1, 3, 3).to(x.device)
+         delzdelx = F.conv2d(
+             x.reshape(B * C, 1, H, W), delzdelxkernel, padding=1
+         ).reshape(B, C, H, W)
+         delzdelykernel = self.delzdelykernel.view(1, 1, 3, 3).to(x.device)
+         delzdely = F.conv2d(
+             x.reshape(B * C, 1, H, W), delzdelykernel, padding=1
+         ).reshape(B, C, H, W)
+         normal = -torch.cross(delzdelx, delzdely, dim=1)
+         return normal
+
+
+ @threestudio.register("diff-gaussian-rasterizer-shading")
+ class DiffGaussian(Rasterizer, GaussianBatchRenderer):
+     @dataclass
+     class Config(Rasterizer.Config):
+         debug: bool = False
+         back_ground_color: Tuple[float, float, float] = (1, 1, 1)
+
+     cfg: Config
+
+     def configure(
+         self,
+         geometry: BaseGeometry,
+         material: BaseMaterial,
+         background: BaseBackground,
+     ) -> None:
+         if not isinstance(material, GaussianDiffuseWithPointLightMaterial):
+             raise NotImplementedError(
+                 "diff-gaussian-rasterizer-shading only supports Gaussian material."
+             )
+         super().configure(geometry, material, background)
+         self.normal_module = Depth2Normal()
+         self.background_tensor = torch.tensor(
+             self.cfg.back_ground_color, dtype=torch.float32, device="cuda"
+         )
+
+     def forward(
+         self,
+         viewpoint_camera,
+         bg_color: torch.Tensor,
+         scaling_modifier=1.0,
+         override_color=None,
+         **kwargs
+     ) -> Dict[str, Any]:
+         """
+         Render the scene.
+
+         Background tensor (bg_color) must be on GPU!
+         """
+         # use neural background
+         bg_color = bg_color * 0
+
+         pc = self.geometry
+         # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
+         screenspace_points = (
+             torch.zeros_like(
+                 pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
+             )
+             + 0
+         )
+         try:
+             screenspace_points.retain_grad()
+         except:
+             pass
+
+         # Set up rasterization configuration
+         tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+         tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+
+         raster_settings = GaussianRasterizationSettings(
+             image_height=int(viewpoint_camera.image_height),
+             image_width=int(viewpoint_camera.image_width),
+             tanfovx=tanfovx,
+             tanfovy=tanfovy,
+             bg=bg_color,
+             scale_modifier=scaling_modifier,
+             viewmatrix=viewpoint_camera.world_view_transform,
+             projmatrix=viewpoint_camera.full_proj_transform,
+             sh_degree=pc.active_sh_degree,
+             campos=viewpoint_camera.camera_center,
+             prefiltered=False,
+             debug=False,
+         )
+
+         rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+         means3D = pc.get_xyz
+         means2D = screenspace_points
+         opacity = pc.get_opacity
+
+         # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
+         # scaling / rotation by the rasterizer.
+         scales = None
+         rotations = None
+         cov3D_precomp = None
+         scales = pc.get_scaling
+         rotations = pc.get_rotation
+
+         # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+         # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
+         shs = None
+         colors_precomp = None
+         if override_color is None:
+             shs = pc.get_features
+         else:
+             colors_precomp = override_color
+
+         # Rasterize visible Gaussians to image, obtain their radii (on screen).
+         batch_idx = kwargs["batch_idx"]
+         rays_d = kwargs["rays_d"][batch_idx]
+         rays_o = kwargs["rays_o"][batch_idx]
+         # rays_d_flatten: Float[Tensor, "Nr 3"] = rays_d.unsqueeze(0)
+
+         comp_rgb_bg = self.background(dirs=rays_d.unsqueeze(0))
+
+         rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
+             means3D=means3D,
+             means2D=means2D,
+             shs=shs,
+             colors_precomp=colors_precomp,
+             opacities=opacity,
+             scales=scales,
+             rotations=rotations,
+             cov3D_precomp=cov3D_precomp,
+         )
+         _, H, W = rendered_image.shape
+
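+         # Back-project the rendered depth along each pixel's ray to a world-space
+         # position map, then differentiate it with Depth2Normal to get normals.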
170
+         xyz_map = rays_o + rendered_depth.permute(1, 2, 0) * rays_d
+         normal_map = self.normal_module(xyz_map.permute(2, 0, 1).unsqueeze(0))[0]
+         normal_map = F.normalize(normal_map, dim=0)
+         if pc.cfg.pred_normal:
+             pred_normal_map, _, _, _ = rasterizer(
+                 means3D=means3D,
+                 means2D=torch.zeros_like(means2D),
+                 shs=pc.get_normal.unsqueeze(1),
+                 colors_precomp=None,
+                 opacities=opacity,
+                 scales=scales,
+                 rotations=rotations,
+                 cov3D_precomp=cov3D_precomp,
+             )
+         else:
+             pred_normal_map = None
+
+         light_positions = kwargs["light_positions"][batch_idx, None, None, :].expand(
+             H, W, -1
+         )
+
+         if pred_normal_map is not None:
+             shading_normal = pred_normal_map.permute(1, 2, 0).detach() * 2 - 1
+             shading_normal = F.normalize(shading_normal, dim=2)
+         else:
+             shading_normal = normal_map.permute(1, 2, 0)
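+         # rendered_image is alpha-premultiplied (rendered over black), so dividing
+         # by rendered_alpha (plus a small eps) recovers an un-premultiplied albedo
+         # for the point-light shading model.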
196
+         rgb_fg = self.material(
+             positions=xyz_map,
+             shading_normal=shading_normal,
+             albedo=(rendered_image / (rendered_alpha + 1e-6)).permute(1, 2, 0),
+             light_positions=light_positions,
+         ).permute(2, 0, 1)
+         rendered_image = rgb_fg * rendered_alpha + (
+             1 - rendered_alpha
+         ) * comp_rgb_bg.reshape(H, W, 3).permute(2, 0, 1)
+         normal_map = normal_map * 0.5 * rendered_alpha + 0.5
+         mask = rendered_alpha > 0.99
+         normal_mask = mask.repeat(3, 1, 1)
+         normal_map[~normal_mask] = normal_map[~normal_mask].detach()
+         rendered_depth[~mask] = rendered_depth[~mask].detach()
+
+         # Retain gradients of the 2D (screen-space) means for batch dim
+         if self.training:
+             screenspace_points.retain_grad()
+
+         # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
+         # They will be excluded from value updates used in the splitting criteria.
+         return {
+             "render": rendered_image.clamp(0, 1),
+             "normal": normal_map,
+             "pred_normal": pred_normal_map,
+             "mask": rendered_alpha,
+             "depth": rendered_depth,
+             "viewspace_points": screenspace_points,
+             "visibility_filter": radii > 0,
+             "radii": radii,
+         }
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/renderer/gaussian_batch_renderer.py ADDED
@@ -0,0 +1,92 @@
+ import torch
+ from threestudio.utils.ops import get_cam_info_gaussian
+ from torch.cuda.amp import autocast
+
+ from ..geometry.gaussian_base import BasicPointCloud, Camera
+
+
+ class GaussianBatchRenderer:
+     def batch_forward(self, batch):
+         bs = batch["c2w"].shape[0]
+         renders = []
+         viewspace_points = []
+         visibility_filters = []
+         radiis = []
+         normals = []
+         pred_normals = []
+         depths = []
+         masks = []
+         langs = []
+         for batch_idx in range(bs):
+             batch["batch_idx"] = batch_idx
+             fovy = batch["fovy"][batch_idx]
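+             # The vertical FoV is reused for the horizontal one below, i.e. a
+             # square image / symmetric frustum is assumed.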
23
+             w2c, proj, cam_p, cam_proj = get_cam_info_gaussian(
+                 c2w=batch["c2w"][batch_idx], fovx=fovy, fovy=fovy, znear=0.1, zfar=100
+             )
+
+             viewpoint_cam = Camera(
+                 FoVx=fovy,
+                 FoVy=fovy,
+                 image_width=batch["width"],
+                 image_height=batch["height"],
+                 world_view_transform=w2c,
+                 full_proj_transform=proj,
+                 camera_center=cam_p,
+             )
+
+             with autocast(enabled=False):
+                 render_pkg = self.forward(
+                     viewpoint_cam, self.background_tensor, **batch
+                 )
+                 renders.append(render_pkg["render"])
+                 viewspace_points.append(render_pkg["viewspace_points"])
+                 visibility_filters.append(render_pkg["visibility_filter"])
+                 radiis.append(render_pkg["radii"])
+                 if "normal" in render_pkg:
+                     normals.append(render_pkg["normal"])
+                 if "pred_normal" in render_pkg and render_pkg["pred_normal"] is not None:
+                     pred_normals.append(render_pkg["pred_normal"])
+                 if "depth" in render_pkg:
+                     depths.append(render_pkg["depth"])
+                 if "mask" in render_pkg:
+                     masks.append(render_pkg["mask"])
+                 if "lang" in render_pkg:
+                     langs.append(render_pkg["lang"])
+
59
+         outputs = {
+             "comp_rgb": torch.stack(renders, dim=0).permute(0, 2, 3, 1),
+             "viewspace_points": viewspace_points,
+             "visibility_filter": visibility_filters,
+             "radii": radiis,
+         }
+         # Only renderers that rasterize language features return "lang"; stack it
+         # conditionally so the plain RGB rasterizers don't fail on an empty list.
+         if len(langs) > 0:
+             outputs.update(
+                 {
+                     "lang": torch.stack(langs, dim=0).permute(0, 2, 3, 1),
+                 }
+             )
+         if len(normals) > 0:
+             outputs.update(
+                 {
+                     "comp_normal": torch.stack(normals, dim=0).permute(0, 2, 3, 1),
+                 }
+             )
+         if len(pred_normals) > 0:
+             outputs.update(
+                 {
+                     "comp_pred_normal": torch.stack(pred_normals, dim=0).permute(
+                         0, 2, 3, 1
+                     ),
+                 }
+             )
+         if len(depths) > 0:
+             outputs.update(
+                 {
+                     "comp_depth": torch.stack(depths, dim=0).permute(0, 2, 3, 1),
+                 }
+             )
+         if len(masks) > 0:
+             outputs.update(
+                 {
+                     "comp_mask": torch.stack(masks, dim=0).permute(0, 2, 3, 1),
+                 }
+             )
+         return outputs
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/system/gaussian_mvdream.py ADDED
@@ -0,0 +1,249 @@
+ import os
+ from dataclasses import dataclass, field
+
+ import numpy as np
+ import threestudio
+ import torch
+ from threestudio.systems.base import BaseLift3DSystem
+ from threestudio.systems.utils import parse_optimizer, parse_scheduler
+ from threestudio.utils.loss import tv_loss
+ from threestudio.utils.typing import *
+
+ from ..geometry.gaussian_base import BasicPointCloud
+
+
+ @threestudio.register("gaussian-splatting-mvdream-system")
+ class MVDreamSystem(BaseLift3DSystem):
+     @dataclass
+     class Config(BaseLift3DSystem.Config):
+         visualize_samples: bool = False
+
+     cfg: Config
+
+     def configure(self) -> None:
+         # set up geometry, material, background, renderer
+         super().configure()
+         self.automatic_optimization = False
+
+         self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)
+         self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
+             self.cfg.prompt_processor
+         )
+         self.prompt_utils = self.prompt_processor()
+
+     def configure_optimizers(self):
+         optim = self.geometry.optimizer
+         if hasattr(self, "merged_optimizer"):
+             return [optim]
+         if hasattr(self.cfg.optimizer, "name"):
+             net_optim = parse_optimizer(self.cfg.optimizer, self)
+             optim = self.geometry.merge_optimizer(net_optim)
+             self.merged_optimizer = True
+         else:
+             self.merged_optimizer = False
+         return [optim]
+
+     def on_load_checkpoint(self, checkpoint):
+         num_pts = checkpoint["state_dict"]["geometry._xyz"].shape[0]
+         pcd = BasicPointCloud(
+             points=np.zeros((num_pts, 3)),
+             colors=np.zeros((num_pts, 3)),
+             normals=np.zeros((num_pts, 3)),
+         )
+         self.geometry.create_from_pcd(pcd, 10)
+         self.geometry.training_setup()
+         return
+
+     def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
+         self.geometry.update_learning_rate(self.global_step)
+         outputs = self.renderer.batch_forward(batch)
+         return outputs
+
+     def training_step(self, batch, batch_idx):
+         opt = self.optimizers()
+         out = self(batch)
+
+         visibility_filter = out["visibility_filter"]
+         radii = out["radii"]
+         guidance_inp = out["comp_rgb"]
+         viewspace_point_tensor = out["viewspace_points"]
+         guidance_out = self.guidance(
+             guidance_inp, self.prompt_utils, **batch, rgb_as_latents=False
+         )
+
+         loss_sds = 0.0
+         loss = 0.0
+
+         self.log(
+             "gauss_num",
+             int(self.geometry.get_xyz.shape[0]),
+             on_step=True,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+
+         for name, value in guidance_out.items():
+             self.log(f"train/{name}", value)
+             if name.startswith("loss_"):
+                 loss_sds += value * self.C(
+                     self.cfg.loss[name.replace("loss_", "lambda_")]
+                 )
+
+         xyz_mean = None
+         if self.cfg.loss["lambda_position"] > 0.0:
+             xyz_mean = self.geometry.get_xyz.norm(dim=-1)
+             loss_position = xyz_mean.mean()
+             self.log(f"train/loss_position", loss_position)
+             loss += self.C(self.cfg.loss["lambda_position"]) * loss_position
+
+         if self.cfg.loss["lambda_opacity"] > 0.0:
+             scaling = self.geometry.get_scaling.norm(dim=-1)
+             loss_opacity = (
+                 scaling.detach().unsqueeze(-1) * self.geometry.get_opacity
+             ).sum()
+             self.log(f"train/loss_opacity", loss_opacity)
+             loss += self.C(self.cfg.loss["lambda_opacity"]) * loss_opacity
+
+         if self.cfg.loss["lambda_sparsity"] > 0.0:
+             loss_sparsity = (out["comp_mask"] ** 2 + 0.01).sqrt().mean()
+             self.log("train/loss_sparsity", loss_sparsity)
+             loss += loss_sparsity * self.C(self.cfg.loss.lambda_sparsity)
+
+         if self.cfg.loss["lambda_scales"] > 0.0:
+             scale_sum = torch.sum(self.geometry.get_scaling)
+             self.log(f"train/scales", scale_sum)
+             loss += self.C(self.cfg.loss["lambda_scales"]) * scale_sum
+
+         if self.cfg.loss["lambda_tv_loss"] > 0.0:
+             loss_tv = self.C(self.cfg.loss["lambda_tv_loss"]) * tv_loss(
+                 out["comp_rgb"].permute(0, 3, 1, 2)
+             )
+             self.log(f"train/loss_tv", loss_tv)
+             loss += loss_tv
+
+         if "comp_depth" in out and self.cfg.loss["lambda_depth_tv_loss"] > 0.0:
+             loss_depth_tv = self.C(self.cfg.loss["lambda_depth_tv_loss"]) * (
+                 tv_loss(out["comp_depth"].permute(0, 3, 1, 2))
+             )
+             self.log(f"train/loss_depth_tv", loss_depth_tv)
+             loss += loss_depth_tv
+
+         if "comp_pred_normal" in out:
+             loss_pred_normal = torch.nn.functional.mse_loss(
+                 out["comp_pred_normal"], out["comp_normal"].detach()
+             )
+             loss += loss_pred_normal
+
+         for name, value in self.cfg.loss.items():
+             self.log(f"train_params/{name}", self.C(value))
+
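+         # loss_sds is backpropagated first with retain_graph=True because the
+         # regularization loss below reuses parts of the same graph; this pass also
+         # populates the screen-space gradients consumed by update_states().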
144
+         loss_sds.backward(retain_graph=True)
+         iteration = self.global_step
+         self.geometry.update_states(
+             iteration,
+             visibility_filter,
+             radii,
+             viewspace_point_tensor,
+         )
+         if loss > 0:
+             loss.backward()
+         opt.step()
+         opt.zero_grad(set_to_none=True)
+
+         return {"loss": loss_sds}
+
+     def validation_step(self, batch, batch_idx):
+         out = self(batch)
+         # import pdb; pdb.set_trace()
+         self.save_image_grid(
+             f"it{self.global_step}-{batch['index'][0]}.png",
+             [
+                 {
+                     "type": "rgb",
+                     "img": out["comp_rgb"][0],
+                     "kwargs": {"data_format": "HWC"},
+                 },
+             ]
+             + (
+                 [
+                     {
+                         "type": "rgb",
+                         "img": out["comp_normal"][0],
+                         "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
+                     }
+                 ]
+                 if "comp_normal" in out
+                 else []
+             )
+             + (
+                 [
+                     {
+                         "type": "rgb",
+                         "img": out["comp_pred_normal"][0],
+                         "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
+                     }
+                 ]
+                 if "comp_pred_normal" in out
+                 else []
+             ),
+             name="validation_step",
+             step=self.global_step,
+         )
+
+     def on_validation_epoch_end(self):
+         pass
+
+     def test_step(self, batch, batch_idx):
+         out = self(batch)
+         self.save_image_grid(
+             f"it{self.global_step}-test/{batch['index'][0]}.png",
+             [
+                 {
+                     "type": "rgb",
+                     "img": out["comp_rgb"][0],
+                     "kwargs": {"data_format": "HWC"},
+                 },
+             ]
+             + (
+                 [
+                     {
+                         "type": "rgb",
+                         "img": out["comp_normal"][0],
+                         "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
+                     }
+                 ]
+                 if "comp_normal" in out
+                 else []
+             )
+             + (
+                 [
+                     {
+                         "type": "rgb",
+                         "img": out["comp_pred_normal"][0],
+                         "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
+                     }
+                 ]
+                 if "comp_pred_normal" in out
+                 else []
+             ),
+             name="test_step",
+             step=self.global_step,
+         )
+         if batch["index"][0] == 0:
+             save_path = self.get_save_path("point_cloud.ply")
+             self.geometry.save_ply(save_path)
+
+     def on_test_epoch_end(self):
+         self.save_img_sequence(
+             f"it{self.true_global_step}-test",
+             f"it{self.true_global_step}-test",
+             r"(\d+)\.png",
+             save_format="mp4",
+             fps=30,
+             name="test",
+             step=self.true_global_step,
+         )
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/system/gaussian_splatting.py ADDED
@@ -0,0 +1,223 @@
+ import math
+ from dataclasses import dataclass
+
+ import numpy as np
+ import threestudio
+ import torch
+ from threestudio.systems.base import BaseLift3DSystem
+ from threestudio.systems.utils import parse_optimizer, parse_scheduler
+ from threestudio.utils.loss import tv_loss
+ from threestudio.utils.ops import get_cam_info_gaussian
+ from threestudio.utils.typing import *
+ from torch.cuda.amp import autocast
+
+ from ..geometry.gaussian_base import BasicPointCloud, Camera
+
+
+ @threestudio.register("gaussian-splatting-system")
+ class GaussianSplatting(BaseLift3DSystem):
+     @dataclass
+     class Config(BaseLift3DSystem.Config):
+         visualize_samples: bool = False
+
+     cfg: Config
+
+     def configure(self) -> None:
+         # set up geometry, material, background, renderer
+         super().configure()
+         self.automatic_optimization = False
+
+         self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)
+         self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
+             self.cfg.prompt_processor
+         )
+         self.prompt_utils = self.prompt_processor()
+
+     def configure_optimizers(self):
+         optim = self.geometry.optimizer
+         if hasattr(self, "merged_optimizer"):
+             return [optim]
+         if hasattr(self.cfg.optimizer, "name"):
+             net_optim = parse_optimizer(self.cfg.optimizer, self)
+             optim = self.geometry.merge_optimizer(net_optim)
+             self.merged_optimizer = True
+         else:
+             self.merged_optimizer = False
+         return [optim]
+
+     def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
+         self.geometry.update_learning_rate(self.global_step)
+         outputs = self.renderer.batch_forward(batch)
+         return outputs
+
+     def on_fit_start(self) -> None:
+         super().on_fit_start()
+
+     def training_step(self, batch, batch_idx):
+         opt = self.optimizers()
+         out = self(batch)
+
+         visibility_filter = out["visibility_filter"]
+         radii = out["radii"]
+         guidance_inp = out["comp_rgb"]
+         # import pdb; pdb.set_trace()
+         viewspace_point_tensor = out["viewspace_points"]
+         guidance_out = self.guidance(
+             guidance_inp, self.prompt_utils, **batch, rgb_as_latents=False
+         )
+
+         loss_sds = 0.0
+         loss = 0.0
+
+         self.log(
+             "gauss_num",
+             int(self.geometry.get_xyz.shape[0]),
+             on_step=True,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+
+         for name, value in guidance_out.items():
+             self.log(f"train/{name}", value)
+             if name.startswith("loss_"):
+                 loss_sds += value * self.C(
+                     self.cfg.loss[name.replace("loss_", "lambda_")]
+                 )
+
+         xyz_mean = None
+         if self.cfg.loss["lambda_position"] > 0.0:
+             xyz_mean = self.geometry.get_xyz.norm(dim=-1)
+             loss_position = xyz_mean.mean()
+             self.log(f"train/loss_position", loss_position)
+             loss += self.C(self.cfg.loss["lambda_position"]) * loss_position
+
+         if self.cfg.loss["lambda_opacity"] > 0.0:
+             scaling = self.geometry.get_scaling.norm(dim=-1)
+             loss_opacity = (
+                 scaling.detach().unsqueeze(-1) * self.geometry.get_opacity
+             ).sum()
+             self.log(f"train/loss_opacity", loss_opacity)
+             loss += self.C(self.cfg.loss["lambda_opacity"]) * loss_opacity
+
+         if self.cfg.loss["lambda_scales"] > 0.0:
+             scale_sum = torch.sum(self.geometry.get_scaling)
+             self.log(f"train/scales", scale_sum)
+             loss += self.C(self.cfg.loss["lambda_scales"]) * scale_sum
+
+         if self.cfg.loss["lambda_tv_loss"] > 0.0:
+             loss_tv = self.C(self.cfg.loss["lambda_tv_loss"]) * tv_loss(
+                 out["comp_rgb"].permute(0, 3, 1, 2)
+             )
+             self.log(f"train/loss_tv", loss_tv)
+             loss += loss_tv
+
+         if "comp_depth" in out and self.cfg.loss["lambda_depth_tv_loss"] > 0.0:
+             loss_depth_tv = self.C(self.cfg.loss["lambda_depth_tv_loss"]) * (
+                 tv_loss(out["comp_normal"].permute(0, 3, 1, 2))
+                 + tv_loss(out["comp_depth"].permute(0, 3, 1, 2))
+             )
+             self.log(f"train/loss_depth_tv", loss_depth_tv)
+             loss += loss_depth_tv
+
+         for name, value in self.cfg.loss.items():
+             self.log(f"train_params/{name}", self.C(value))
+
+         loss_sds.backward(retain_graph=True)
+         iteration = self.global_step
+         self.geometry.update_states(
+             iteration,
+             visibility_filter,
+             radii,
+             viewspace_point_tensor,
+         )
+         if loss > 0:
+             loss.backward()
+         opt.step()
+         opt.zero_grad(set_to_none=True)
+
+         return {"loss": loss_sds}
+
+     def validation_step(self, batch, batch_idx):
+         out = self(batch)
+         # import pdb; pdb.set_trace()
+         self.save_image_grid(
+             f"it{self.global_step}-{batch['index'][0]}.png",
+             [
+                 {
+                     "type": "rgb",
+                     "img": out["comp_rgb"][0],
+                     "kwargs": {"data_format": "HWC"},
+                 },
+             ]
+             + (
+                 [
+                     {
+                         "type": "rgb",
+                         "img": out["comp_normal"][0],
+                         "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
+                     }
+                 ]
+                 if "comp_normal" in out
+                 else []
+             ),
+             name="validation_step",
+             step=self.global_step,
+         )
+
+     def on_validation_epoch_end(self):
+         pass
+
+     def test_step(self, batch, batch_idx):
+         out = self(batch)
+         self.save_image_grid(
+             f"it{self.global_step}-test/{batch['index'][0]}.png",
+             [
+                 {
+                     "type": "rgb",
+                     "img": out["comp_rgb"][0],
+                     "kwargs": {"data_format": "HWC"},
+                 },
+             ]
+             + (
+                 [
+                     {
+                         "type": "rgb",
+                         "img": out["comp_normal"][0],
+                         "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
+                     }
+                 ]
+                 if "comp_normal" in out
+                 else []
+             ),
+             name="test_step",
+             step=self.global_step,
+         )
+         if batch["index"][0] == 0:
+             save_path = self.get_save_path("point_cloud.ply")
+             self.geometry.save_ply(save_path)
+
+     def on_test_epoch_end(self):
+         self.save_img_sequence(
+             f"it{self.global_step}-test",
+             f"it{self.global_step}-test",
+             r"(\d+)\.png",
+             save_format="mp4",
+             fps=30,
+             name="test",
+             step=self.global_step,
+         )
+
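+     # The checkpoint only stores tensor state, so the Gaussian buffers are rebuilt
+     # with the checkpointed point count before the state dict is restored.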
+     def on_load_checkpoint(self, ckpt_dict) -> None:
+         num_pts = ckpt_dict["state_dict"]["geometry._xyz"].shape[0]
+         pcd = BasicPointCloud(
+             points=np.zeros((num_pts, 3)),
+             colors=np.zeros((num_pts, 3)),
+             normals=np.zeros((num_pts, 3)),
+         )
+         self.geometry.create_from_pcd(pcd, 10)
+         self.geometry.training_setup()
+         super().on_load_checkpoint(ckpt_dict)
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/system/gaussian_zero123.py ADDED
@@ -0,0 +1,339 @@
+ import os
+ import random
+ from dataclasses import dataclass, field
+
+ import numpy as np
+ import threestudio
+ import torch
+ import torch.nn.functional as F
+ from threestudio.systems.base import BaseLift3DSystem
+ from threestudio.systems.utils import parse_optimizer, parse_scheduler
+ from threestudio.utils.loss import tv_loss
+ from threestudio.utils.ops import get_cam_info_gaussian
+ from threestudio.utils.typing import *
+ from torch.cuda.amp import autocast
+ from torchmetrics import PearsonCorrCoef
+
+ from ..geometry.gaussian_base import BasicPointCloud, Camera
+
+
+ @threestudio.register("gaussian-splatting-zero123-system")
+ class Zero123(BaseLift3DSystem):
+     @dataclass
+     class Config(BaseLift3DSystem.Config):
+         freq: dict = field(default_factory=dict)
+         refinement: bool = False
+         ambient_ratio_min: float = 0.5
+         back_ground_color: Tuple[float, float, float] = (1, 1, 1)
+
+     cfg: Config
+
+     def configure(self):
+         # create geometry, material, background, renderer
+         super().configure()
+         self.automatic_optimization = False
+
+     def configure_optimizers(self):
+         optim = self.geometry.optimizer
+         if hasattr(self, "merged_optimizer"):
+             return [optim]
+         if hasattr(self.cfg.optimizer, "name"):
+             net_optim = parse_optimizer(self.cfg.optimizer, self)
+             optim = self.geometry.merge_optimizer(net_optim)
+             self.merged_optimizer = True
+         else:
+             self.merged_optimizer = False
+         return [optim]
+
+     def on_load_checkpoint(self, checkpoint):
+         num_pts = checkpoint["state_dict"]["geometry._xyz"].shape[0]
+         pcd = BasicPointCloud(
+             points=np.zeros((num_pts, 3)),
+             colors=np.zeros((num_pts, 3)),
+             normals=np.zeros((num_pts, 3)),
+         )
+         self.geometry.create_from_pcd(pcd, 10)
+         self.geometry.training_setup()
+         return
+
+     def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
+         self.geometry.update_learning_rate(self.global_step)
+         outputs = self.renderer.batch_forward(batch)
+         return outputs
+
+     def on_fit_start(self) -> None:
+         super().on_fit_start()
+         # no prompt processor
+         self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)
+
+         # visualize all training images
+         all_images = self.trainer.datamodule.train_dataloader().dataset.get_all_images()
+         self.save_image_grid(
+             "all_training_images.png",
+             [
+                 {"type": "rgb", "img": image, "kwargs": {"data_format": "HWC"}}
+                 for image in all_images
+             ],
+             name="on_fit_start",
+             step=self.true_global_step,
+         )
+
+         self.pearson = PearsonCorrCoef().to(self.device)
+
+     def training_substep(self, batch, batch_idx, guidance: str):
+         """
+         Args:
+             guidance: one of "ref" (reference image supervision), "zero123"
+         """
+         if guidance == "ref":
+             ambient_ratio = 1.0
+             shading = "diffuse"
+             batch["shading"] = shading
+         elif guidance == "zero123":
+             batch = batch["random_camera"]
+             ambient_ratio = (
+                 self.cfg.ambient_ratio_min
+                 + (1 - self.cfg.ambient_ratio_min) * random.random()
+             )
+
+         batch["ambient_ratio"] = ambient_ratio
+
+         out = self(batch)
+         loss_prefix = f"loss_{guidance}_"
+
+         loss_terms = {}
+
+         def set_loss(name, value):
+             loss_terms[f"{loss_prefix}{name}"] = value
+
+         guidance_eval = (
+             guidance == "zero123"
+             and self.cfg.freq.guidance_eval > 0
+             and self.true_global_step % self.cfg.freq.guidance_eval == 0
+         )
+
+         if guidance == "ref":
+             gt_mask = batch["mask"]
+             gt_rgb = batch["rgb"]
+
+             # color loss
+             gt_rgb = gt_rgb * gt_mask.float()
+             set_loss("rgb", F.mse_loss(gt_rgb, out["comp_rgb"] * gt_mask.float()))
+
+             # mask loss
+             set_loss("mask", F.mse_loss(gt_mask.float(), out["comp_mask"]))
+
+             # depth loss
+             if self.C(self.cfg.loss.lambda_depth) > 0:
+                 valid_gt_depth = batch["ref_depth"][gt_mask.squeeze(-1)].unsqueeze(1)
+                 valid_pred_depth = out["comp_depth"][gt_mask].unsqueeze(1)
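+                 # Monocular reference depth is only known up to scale and shift, so
+                 # fit that affine map (least squares) from the GT depths onto the
+                 # predicted depths before taking the MSE.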
130
+                 with torch.no_grad():
+                     A = torch.cat(
+                         [valid_gt_depth, torch.ones_like(valid_gt_depth)], dim=-1
+                     )  # [B, 2]
+                     X = torch.linalg.lstsq(A, valid_pred_depth).solution  # [2, 1]
+                     valid_gt_depth = A @ X  # [B, 1]
+                 set_loss("depth", F.mse_loss(valid_gt_depth, valid_pred_depth))
+
+             # relative depth loss
+             if self.C(self.cfg.loss.lambda_depth_rel) > 0:
+                 valid_gt_depth = batch["ref_depth"][gt_mask.squeeze(-1)]  # [B,]
+                 valid_pred_depth = out["comp_depth"][gt_mask]  # [B,]
+                 set_loss(
+                     "depth_rel", 1 - self.pearson(valid_pred_depth, valid_gt_depth)
+                 )
+
+             # normal loss
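+             # Normals are stored as RGB images in [0, 1]; both sides are decoded to
+             # [-1, 1] below (the reference decode flips the sign convention relative
+             # to the rendered normals) before the cosine-similarity term.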
147
+             if self.C(self.cfg.loss.lambda_normal) > 0:
+                 valid_gt_normal = (
+                     1 - 2 * batch["ref_normal"][gt_mask.squeeze(-1)]
+                 )  # [B, 3]
+                 valid_pred_normal = (
+                     2 * out["comp_normal"][gt_mask.squeeze(-1)] - 1
+                 )  # [B, 3]
+                 set_loss(
+                     "normal",
+                     1 - F.cosine_similarity(valid_pred_normal, valid_gt_normal).mean(),
+                 )
+         elif guidance == "zero123":
+             # zero123
+             guidance_out = self.guidance(
+                 out["comp_rgb"],
+                 **batch,
+                 rgb_as_latents=False,
+                 guidance_eval=guidance_eval,
+             )
+             # claforte: TODO: rename the loss_terms keys
+             set_loss("sds", guidance_out["loss_sds"])
+
+         if self.C(self.cfg.loss.lambda_normal_smooth) > 0:
+             if "comp_normal" not in out:
+                 raise ValueError(
+                     "comp_normal is required for 2D normal smooth loss, no comp_normal is found in the output."
+                 )
+             normal = out["comp_normal"]
+             set_loss(
+                 "normal_smooth",
+                 (normal[:, 1:, :, :] - normal[:, :-1, :, :]).square().mean()
+                 + (normal[:, :, 1:, :] - normal[:, :, :-1, :]).square().mean(),
+             )
+
+         loss = 0.0
+         for name, value in loss_terms.items():
+             self.log(f"train/{name}", value)
+             if name.startswith(loss_prefix):
+                 loss_weighted = value * self.C(
+                     self.cfg.loss[name.replace(loss_prefix, "lambda_")]
+                 )
+                 self.log(f"train/{name}_w", loss_weighted)
+                 loss += loss_weighted
+
+         for name, value in self.cfg.loss.items():
+             self.log(f"train_params/{name}", self.C(value))
+
+         self.log(f"train/loss_{guidance}", loss)
+
+         out.update({"loss": loss})
+         return out
+
+     def training_step(self, batch, batch_idx):
+         opt = self.optimizers()
+
+         if self.cfg.freq.get("ref_or_zero123", "accumulate") == "accumulate":
+             do_ref = True
+             do_zero123 = True
+         elif self.cfg.freq.get("ref_or_zero123", "accumulate") == "alternate":
+             do_ref = (
+                 self.true_global_step < self.cfg.freq.ref_only_steps
+                 or self.true_global_step % self.cfg.freq.n_ref == 0
+             )
+             do_zero123 = not do_ref
+
+         total_loss = 0.0
+         if do_zero123:
+             out = self.training_substep(batch, batch_idx, guidance="zero123")
+             total_loss += out["loss"]
+
+         if do_ref:
+             out = self.training_substep(batch, batch_idx, guidance="ref")
+             total_loss += out["loss"]
+
+         self.log("train/loss", total_loss, prog_bar=True)
+
+         visibility_filter = out["visibility_filter"]
+         radii = out["radii"]
+         guidance_inp = out["comp_rgb"]
+         viewspace_point_tensor = out["viewspace_points"]
+
+         total_loss.backward()
+         iteration = self.global_step
+         self.geometry.update_states(
+             iteration,
+             visibility_filter,
+             radii,
+             viewspace_point_tensor,
+         )
+         opt.step()
+         opt.zero_grad(set_to_none=True)
+
+         return {"loss": total_loss}
+
+     def validation_step(self, batch, batch_idx):
+         out = self(batch)
+         self.save_image_grid(
+             f"it{self.true_global_step}-val/{batch['index'][0]}.png",
+             (
+                 [
+                     {
+                         "type": "rgb",
+                         "img": batch["rgb"][0],
+                         "kwargs": {"data_format": "HWC"},
+                     }
+                 ]
+                 if "rgb" in batch
+                 else []
+             )
+             + [
+                 {
+                     "type": "rgb",
+                     "img": out["comp_rgb"][0],
+                     "kwargs": {"data_format": "HWC"},
+                 },
+             ]
+             + (
+                 [
+                     {
+                         "type": "rgb",
+                         "img": out["comp_normal"][0],
+                         "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
+                     }
+                 ]
+                 if "comp_normal" in out
+                 else []
+             ),
+             # claforte: TODO: don't hardcode the frame numbers to record... read them from cfg instead.
+             name=f"validation_step_batchidx_{batch_idx}"
+             if batch_idx in [0, 7, 15, 23, 29]
+             else None,
+             step=self.true_global_step,
+         )
+
+     def on_validation_epoch_end(self):
+         filestem = f"it{self.true_global_step}-val"
+         self.save_img_sequence(
+             filestem,
+             filestem,
+             r"(\d+)\.png",
+             save_format="mp4",
+             fps=30,
+             name="validation_epoch_end",
+             step=self.true_global_step,
+         )
+
+     def test_step(self, batch, batch_idx):
+         out = self(batch)
+         self.save_image_grid(
+             f"it{self.true_global_step}-test/{batch['index'][0]}.png",
+             (
+                 [
+                     {
+                         "type": "rgb",
+                         "img": batch["rgb"][0],
+                         "kwargs": {"data_format": "HWC"},
+                     }
+                 ]
+                 if "rgb" in batch
+                 else []
+             )
+             + [
+                 {
+                     "type": "rgb",
+                     "img": out["comp_rgb"][0],
+                     "kwargs": {"data_format": "HWC"},
+                 },
+             ]
+             + (
+                 [
+                     {
+                         "type": "rgb",
+                         "img": out["comp_normal"][0],
+                         "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
+                     }
+                 ]
+                 if "comp_normal" in out
+                 else []
+             ),
+             name="test_step",
+             step=self.true_global_step,
+         )
+
+     def on_test_epoch_end(self):
+         self.save_img_sequence(
+             f"it{self.true_global_step}-test",
+             f"it{self.true_global_step}-test",
+             r"(\d+)\.png",
+             save_format="mp4",
+             fps=30,
+             name="test",
+             step=self.true_global_step,
+         )
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/system/scene_lang.py ADDED
@@ -0,0 +1,528 @@
+ import math
+ from dataclasses import dataclass, field
+
+ import os
+ import collections
+ import random
+ import numpy as np
+ import threestudio
+ import torch
+ import cv2
+ from sklearn.cluster import KMeans
+ import torchvision
+ from PIL import Image
+ from transformers import pipeline
+ from threestudio.systems.base import BaseLift3DSystem
+ from threestudio.systems.utils import parse_optimizer, parse_scheduler
+ from threestudio.utils.loss import tv_loss
+ from threestudio.utils.ops import get_cam_info_gaussian
+ from threestudio.utils.typing import *
+ from torch.cuda.amp import autocast
+ from tqdm.contrib import tenumerate
+ from tqdm import tqdm, trange
+
+ from ..geometry.gaussian_base import BasicPointCloud, Camera
+ from ..utils.sam_clip import SamClip
+ from ..utils.ae import Autoencoder_dataset, Autoencoder
+ from torch.utils.data import Dataset, DataLoader
+
+
+ def l2_loss(network_output, gt):
+     return ((network_output - gt) ** 2).mean()
+
+
+ def cos_loss(network_output, gt):
+     return 1 - torch.nn.functional.cosine_similarity(network_output, gt, dim=0).mean()
+
+
+ @threestudio.register("scene-lang-system")
+ class SceneLang(BaseLift3DSystem):
+     @dataclass
+     class Config(BaseLift3DSystem.Config):
+         visualize_samples: bool = False
+
+         distill_lang_freq: int = 800
+         outpaint_step: int = 300
+         sam_clip: dict = field(default_factory=dict)
+         encoder_hidden_dims: Optional[List] = field(default_factory=list)
+         decoder_hidden_dims: Optional[List] = field(default_factory=list)
+         ae_epoch: int = 100
+         distill_lang_epoch: int = 100
+         sam_clip_ae_lr: float = 3e-4
+         densify: bool = True
+         distill_interval: int = 2
+         xyz_noise_ratio: Any = None
+         drop_ooi_ratio: Any = field(default_factory=dict)
+         empty_prompt: str = "empty"
+         side_prompt: str = "empty"
+         crop_with_lang: bool = True
+         rotate_aug_scale: int = 15
+
+     cfg: Config
+
+     def configure(self) -> None:
+         # set up geometry, material, background, renderer
+         super().configure()
+         self.automatic_optimization = False
+
+         self.geometry.prompt = self.cfg.prompt_processor.prompt
+         self.geometry.empty_prompt = self.cfg.empty_prompt
+         self.geometry.side_prompt = self.cfg.side_prompt
+
+         self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)
+         self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
+             self.cfg.prompt_processor
+         )
+         self.prompt_utils = self.prompt_processor()
+
+         self.cfg.prompt_processor.prompt = self.cfg.empty_prompt
+         self.bg_prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
+             self.cfg.prompt_processor
+         )
+         self.bg_prompt_utils = self.bg_prompt_processor()
+
+         self.sam_clip = SamClip(self.cfg.sam_clip)
+         self.sam_clip_ae = Autoencoder(
+             self.cfg.encoder_hidden_dims, self.cfg.decoder_hidden_dims
+         ).cuda()
+
+     def configure_optimizers(self):
+         optim = self.geometry.optimizer
+         if hasattr(self, "merged_optimizer"):
+             return [optim]
+         if hasattr(self.cfg.optimizer, "name"):
+             net_optim = parse_optimizer(self.cfg.optimizer, self)
+             optim = self.geometry.merge_optimizer(net_optim)
+             self.merged_optimizer = True
+         else:
+             self.merged_optimizer = False
+         return [optim]
+
+     def on_save_checkpoint(self, checkpoint):
+         if 'optimizer_states' in checkpoint.keys():
+             del checkpoint['optimizer_states']
+
+         del_keys = [k for k in checkpoint['state_dict'].keys() if 'sam' in k]
+         for k in del_keys:
+             del checkpoint['state_dict'][k]
+
+     def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
+         self.geometry.update_learning_rate(self.global_step)
+         outputs = self.renderer.batch_forward(batch)
+         return outputs
+
+     def on_fit_start(self) -> None:
+         super().on_fit_start()
+
+     def training_step(self, batch, batch_idx):
+         self.geometry.noise_ratio = self.C(self.cfg.xyz_noise_ratio)
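+         # With probability drop_ooi_ratio, mask out every object-of-interest
+         # Gaussian so only the background is rendered; those iterations are
+         # supervised with the empty-scene (background) prompt instead.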
115
+ if random.random() < self.C(self.cfg.drop_ooi_ratio):
116
+ self.geometry._opacity_mask = (sum(self.geometry.ooi_masks)==0).float()
117
+ else:
118
+ self.geometry._opacity_mask = None
119
+
120
+ if self.true_global_step > 0 and self.true_global_step == self.cfg.distill_lang_freq : # finish rgb phase
121
+ self.distill_language_feature()
122
+
123
+ if self.true_global_step == self.cfg.outpaint_step:
124
+ self.outpaint()
125
+
126
+ apply_rotate = False
127
+ if self.true_global_step > self.cfg.distill_lang_freq:
128
+ apply_rotate = random.random() < 0.5
129
+ self.geometry.random_rotate(self.cfg.rotate_aug_scale, apply_rotate)
130
+
131
+ opt = self.optimizers()
132
+ out = self(batch)
133
+
134
+ visibility_filter = out["visibility_filter"]
135
+ radii = out["radii"]
136
+ guidance_inp = out["comp_rgb"]
137
+ viewspace_point_tensor = out["viewspace_points"]
138
+ if self.geometry._opacity_mask is None:
139
+ pu = self.prompt_utils
140
+ else:
141
+ pu = self.bg_prompt_utils
142
+ guidance_out = self.guidance(
143
+ guidance_inp, pu, **batch, rgb_as_latents=False
144
+ )
145
+
146
+ loss_sds = 0.0
147
+ loss = 0.0
148
+
149
+ self.log(
150
+ "gauss_num",
151
+ int(self.geometry.get_xyz.shape[0]),
152
+ on_step=True,
153
+ on_epoch=True,
154
+ prog_bar=True,
155
+ logger=True,
156
+ )
157
+
158
+ if self.cfg.loss["lambda_ref"] > 0.0:
159
+ ref_img = self.cfg.geometry.geometry_convert_from[len("depth:"):]
160
+ ref_img = torch.tensor(np.array(Image.open(ref_img).resize((self.dataset.cfg.width, self.dataset.cfg.height)))[None] / 255, device = out['comp_rgb'].device)
161
+ bg_ref_img = torch.tensor(self.geometry.bg_image[None] / 255, device = out['comp_rgb'].device)
162
+ bg_ref_img_mask = torch.from_numpy(self.geometry.bg_image_mask[None, ..., None].astype(float)).cuda()
163
+
164
+ if self.geometry._opacity_mask is None:
165
+ if not apply_rotate:
166
+ l1loss = torch.nn.L1Loss()(out['comp_rgb'][0:1], ref_img) # only calculate the first view (zero view)
167
+ self.log(f"train/recon_front_view", l1loss)
168
+ loss += l1loss * self.cfg.loss["lambda_ref"]
169
+
170
+ if self.true_global_step > self.cfg.outpaint_step:
171
+ for view_idx in [0, -1]:
172
+ self.geometry._opacity_mask = None
173
+ sample = self.trainer.val_dataloaders.dataset[view_idx]
174
+ for k in sample.keys():
175
+ try:
176
+ sample[k] = sample[k].cuda()[None]
177
+ except:
178
+ pass
179
+ output = self(sample)
180
+ rgb = output['comp_rgb']
181
+ target = self.outpaint_view[view_idx]
182
+ # loss += torch.nn.L1Loss()(rgb, target) * self.cfg.loss["lambda_ref"]
183
+ loss += (torch.nn.L1Loss(reduction='none')(rgb, target) * self.outpaint_mask[view_idx]).mean() * self.cfg.loss["lambda_ref"]
184
+ else:
185
+ ratio = bg_ref_img_mask.sum() / bg_ref_img_mask.shape[1] / bg_ref_img_mask.shape[2]
186
+ l1loss = torch.nn.L1Loss(reduction='none')(out['comp_rgb'][0:1], bg_ref_img) * bg_ref_img_mask # only calculate the first view (zero view)
187
+ l1loss = l1loss.mean() / ratio
188
+ loss += l1loss * self.cfg.loss["lambda_ref"]
189
+
190
+ if self.cfg.loss["lambda_scaling"] > 0.0:
191
+ scaling_loss = self.geometry.get_scaling.mean()
192
+ loss += scaling_loss * self.cfg.loss["lambda_scaling"]
193
+
194
+ for name, value in guidance_out.items():
195
+ self.log(f"train/{name}", value)
196
+ if name.startswith("loss_"):
197
+ loss_sds += value * self.C(
198
+ self.cfg.loss[name.replace("loss_", "lambda_")]
199
+ )
200
+
201
+ loss = loss + loss_sds
202
+ iteration = self.global_step
203
+ opt.zero_grad()
204
+ if loss > 0:
205
+ loss.backward(retain_graph=True)
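+ # 3DGS-style densification bookkeeping: update_states is expected to
+ # accumulate view-space gradients and max radii for visible Gaussians, then
+ # split/clone/prune on its own schedule.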
206
+ if self.cfg.densify:
207
+ self.geometry.update_states(
208
+ iteration,
209
+ visibility_filter,
210
+ radii,
211
+ viewspace_point_tensor,
212
+ )
213
+ opt.step()
214
+ opt.zero_grad(set_to_none=True)
215
+
216
+ self.log("train/loss", loss)
217
+ return {"loss": loss}
218
+
219
+ def validation_step(self, batch, batch_idx):
220
+ self.geometry._opacity_mask = None
221
+ out = self(batch)
222
+ mask, _ = self.geometry.project_pc(batch['c2w'], H=self.dataset.cfg.height, W=self.dataset.cfg.width)
223
+ self.save_image_grid(
224
+ f"it{self.global_step}-val/{batch['index'][0]}.png",
225
+ [
226
+ {
227
+ "type": "rgb",
228
+ "img": out["comp_rgb"][0],
229
+ "kwargs": {"data_format": "HWC"},
230
+ },
231
+ ]
232
+ + (
233
+ [
234
+ {
235
+ "type": "rgb",
236
+ "img": out["comp_normal"][0],
237
+ "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
238
+ }
239
+ ]
240
+ if "comp_normal" in out
241
+ else []
242
+ ),
243
+ name="validation_step",
244
+ step=self.global_step,
245
+ )
246
+
247
+ def on_validation_epoch_end(self):
248
+ self.save_img_sequence(
249
+ f"it{self.global_step}-val",
250
+ f"it{self.global_step}-val",
251
+ "(\d+)\.png",
252
+ save_format="mp4",
253
+ fps=30,
254
+ name="val",
255
+ step=self.global_step,
256
+ delete_images=True,
257
+ )
258
+
259
+ def test_step(self, batch, batch_idx):
260
+ # remove the random rotation effect!
261
+ self.geometry.recover_xyzrot()
262
+ out = self(batch)
263
+ self.save_image_grid(
264
+ f"it{self.global_step}-test/{batch['index'][0]}.png",
265
+ [
266
+ {
267
+ "type": "rgb",
268
+ "img": out["comp_rgb"][0],
269
+ "kwargs": {"data_format": "HWC"},
270
+ },
271
+ ]
272
+ + [
273
+ {
274
+ "type": "rgb",
275
+ "img": out["lang"][0],
276
+ "kwargs": {"data_format": "HWC", "data_range": (out["lang"][0].min().item(), out["lang"][0].max().item())},
277
+ },
278
+ ],
279
+ name="test_step",
280
+ step=self.global_step,
281
+ )
282
+ if batch["index"][0] == 0:
283
+ save_path = self.get_save_path("point_cloud.ply")
284
+ self.geometry.save_ply(save_path)
285
+
286
+ def on_test_epoch_end(self):
287
+ self.save_img_sequence(
288
+ f"it{self.global_step}-test",
289
+ f"it{self.global_step}-test",
290
+ "(\d+)\.png",
291
+ save_format="mp4",
292
+ fps=30,
293
+ name="test",
294
+ step=self.global_step,
295
+ )
296
+
297
+ def on_load_checkpoint(self, ckpt_dict) -> None:
298
+ for key in self.state_dict().keys():
299
+ if 'sam' in key:
300
+ ckpt_dict["state_dict"][key] = self.state_dict()[key]
301
+
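+ # Re-create the Gaussian buffers with the point count stored in the
+ # checkpoint (contents are placeholders) so that load_state_dict finds
+ # tensors of matching shape; SAM weights are copied from the live module.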
302
+ num_pts = ckpt_dict["state_dict"]["geometry._xyz"].shape[0]
303
+ pcd = BasicPointCloud(
304
+ points=np.zeros((num_pts, 3)),
305
+ colors=np.zeros((num_pts, 3)),
306
+ normals=np.zeros((num_pts, 3)),
307
+ )
308
+ self.geometry.create_from_pcd(pcd, 10)
309
+ self.geometry.training_setup()
310
+ super().on_load_checkpoint(ckpt_dict)
311
+
312
+ def outpaint(self) -> None:
313
+ threestudio.info("Start outpainting.")
314
+ self.outpaint_view = dict()
315
+ self.outpaint_mask = dict()
316
+ cnt = 0
317
+ for view_idx in [0, -1]:
318
+ self.geometry._opacity_mask = None
319
+ sample = self.trainer.val_dataloaders.dataset[view_idx]
320
+ for k in sample.keys():
321
+ try:
322
+ sample[k] = sample[k].cuda()[None]
323
+ except Exception:  # non-tensor entries stay as-is
324
+ pass
325
+ output = self(sample)
326
+ rgb = (output['comp_rgb'][0] * 255).detach().cpu().numpy().astype(np.uint8)
327
+ rgb = Image.fromarray(rgb)
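+ # Project the current point cloud into this view; pixels not covered by any
+ # point (~mask) are the holes that outpainting needs to fill.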
328
+ mask, depth = self.geometry.project_pc(sample['c2w'], H=512, W=512)
329
+ mask = ~mask[0].cpu().numpy()
330
+ mask = Image.fromarray(mask)
331
+ c2w = sample['c2w']
332
+ rgb, mask = self.geometry.add_pc_from_novel_view(rgb, mask, depth, c2w, save_path=os.path.join(self._save_dir[:-4], f'{cnt}.ply'))
333
+ rgb.save(os.path.join(self._save_dir[:-4], f"outpaint_{cnt}.png"))
334
+ mask.save(os.path.join(self._save_dir[:-4], f"mask_{cnt}.png"))
335
+ cnt += 1
336
+ self.outpaint_view[view_idx] = torch.tensor(np.array(rgb), device='cuda')[None] / 255
337
+ self.outpaint_mask[view_idx] = torch.tensor(np.array(mask).astype(float), device='cuda')[None, ..., None]
338
+
339
+ def distill_language_feature(self) -> None:
340
+ threestudio.info("Start distilling language feature.")
341
+ self.geometry._opacity_mask = None
342
+ total_embed = []
343
+ total_feat = []
344
+ total_flag = []
345
+
346
+ for idx in trange(0, len(self.trainer.val_dataloaders.dataset), self.cfg.distill_interval):
347
+ sample = self.trainer.val_dataloaders.dataset[idx]
348
+ for k in sample.keys():
349
+ try:
350
+ sample[k] = sample[k].cuda()[None]
351
+ except Exception:  # non-tensor entries stay as-is
352
+ pass
353
+ output = self(sample)
354
+ rgb = output['comp_rgb'] #shape: 1, 512, 512, 3
355
+ rgb = (rgb.permute(0, 3, 1, 2) * 255).type(torch.uint8)
356
+
357
+ try:
358
+ embed, seg, mask = self.sam_clip(rgb) # feat's shape: N * H * W
359
+ total_embed.append(embed)
360
+ total_feat.append(seg)
361
+ total_flag.append(idx)
362
+ except Exception:
363
+ threestudio.info(f'exception caught during language distillation at view {idx}')
364
+ pass
365
+
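+ # The autoencoder compresses the 512-D CLIP embeddings collected above into
+ # a low-dimensional code that can be stored per Gaussian and rendered as the
+ # language feature.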
366
+ # train VAE
367
+ threestudio.info("Start training autoencoder.")
368
+ dataset = Autoencoder_dataset(torch.cat(total_embed, 0).float().numpy())
369
+ dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0, drop_last=False)
370
+ optimizer = torch.optim.Adam(self.sam_clip_ae.parameters(), lr=self.cfg.sam_clip_ae_lr)
371
+
372
+ self.sam_clip_ae.train()
373
+ for epoch in tqdm(range(self.cfg.ae_epoch)):
374
+ for idx, data in enumerate(dataloader):
375
+ data = data.cuda()
376
+ mid = self.sam_clip_ae.encode(data)
377
+ _data = self.sam_clip_ae.decode(mid)
378
+ l2loss = l2_loss(_data, data)
379
+ cosloss = cos_loss(_data, data)
380
+ loss = l2loss + cosloss * 0.001
381
+ optimizer.zero_grad()
382
+ loss.backward()
383
+ optimizer.step()
384
+
385
+ self.sam_clip_ae.eval()
386
+ mids = dict()
387
+ with torch.no_grad():
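+ # Append an all-zero CLIP embedding as the last row so that seg index -1
+ # (pixels not covered by any SAM mask) selects this sentinel row instead of
+ # indexing out of range.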
388
+ zero_tensor = torch.zeros([1, 512], dtype=float)
389
+ for idx, seg, embed in zip(total_flag, total_feat, total_embed):
390
+ embeds = torch.cat([embed, zero_tensor], 0).float().cuda()
391
+ embeds = self.sam_clip_ae.encode(embeds)
392
+ mid = embeds[seg[:]].squeeze(0).reshape(self.dataset.cfg.height, self.dataset.cfg.width, -1)
393
+ mids[idx] = mid
394
+ rgb = ((mid - mid.min()) / (mid.max() - mid.min())).cpu()
395
+ if self.sam_clip.cfg.vis_pca_feature:
396
+ self.save_image_grid(f"it{self.global_step}-ae/{idx}.png",
397
+ [
398
+ {
399
+ "type": "rgb",
400
+ "img": rgb,
401
+ "kwargs": {"data_format": "HWC"},
402
+ },
403
+ ],
404
+ name="ae",
405
+ step=self.global_step,
406
+ )
407
+
408
+ if self.sam_clip.cfg.vis_pca_feature:
409
+ self.save_img_sequence(
410
+ f"it{self.global_step}-ae",
411
+ f"it{self.global_step}-ae",
412
+ "(\d+)\.png",
413
+ save_format="mp4",
414
+ fps=30,
415
+ name="ae",
416
+ step=self.global_step,
417
+ )
418
+
419
+ threestudio.info("Start training Lang feature.")
420
+ # distill lang feature
421
+ self.geometry.lang_training_setup()
422
+ opt = self.geometry.lang_optimizer
423
+
424
+ idx_list = list(mids.keys())
425
+ sample_dict = dict()
426
+
427
+ for idx, sample in enumerate(self.trainer.val_dataloaders.dataset):
428
+ for k in sample.keys():
429
+ try:
430
+ sample[k] = sample[k].cuda()[None]
431
+ except Exception:  # non-tensor entries stay as-is
432
+ pass
433
+ sample_dict[idx] = sample
434
+
435
+ for epoch in trange(self.cfg.distill_lang_epoch):
436
+ random.shuffle(idx_list)
437
+ for idx in idx_list:
438
+ sample = sample_dict[idx]
439
+ lang = self(sample)["lang"]
440
+ mid = mids[idx][None]
441
+ loss = l2_loss(mid, lang)
442
+ opt.zero_grad()
443
+ loss.backward()
444
+ opt.step()
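+ # Every 30 epochs, clear Adam's running moments to reset stale momentum
+ # between shuffled passes over the cached views.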
445
+ if (epoch + 1) % 30 == 0:
446
+ opt.state = collections.defaultdict(dict)
447
+
448
+ self.renderer.training = False
449
+ with torch.no_grad():
450
+ lang_min, lang_max = None, None
451
+ for idx, sample in sample_dict.items():
452
+ lang = self(sample)["lang"][0]
453
+ if lang_min is None:
454
+ lang_min, lang_max = lang.min().item(), lang.max().item()
455
+ self.save_image_grid(f"it{self.global_step}-feat/{idx}.png",
456
+ [
457
+ {
458
+ "type": "rgb",
459
+ "img": lang,
460
+ "kwargs": {"data_format": "HWC", "data_range": (lang_min, lang_max)},
461
+ },
462
+ ],
463
+ name=f"feat",
464
+ step=self.global_step,
465
+ )
466
+ self.renderer.training = True
467
+
468
+ self.save_img_sequence(
469
+ f"it{self.global_step}-feat",
470
+ f"it{self.global_step}-feat",
471
+ "(\d+)\.png",
472
+ save_format="mp4",
473
+ fps=30,
474
+ name=f"feat",
475
+ step=self.global_step,
476
+ )
477
+
478
+ self.geometry.training_setup()
479
+
480
+ threestudio.info("Use Lang feature to crop pts")
481
+ if self.cfg.crop_with_lang:
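+ # Clean up each object's Gaussians in three passes: (1) drop K-Means clusters
+ # of the language feature that sit far from the dominant cluster, (2) drop
+ # per-dimension language-feature outliers beyond p standard deviations, and
+ # (3) drop RGB (DC feature) outliers the same way. Removed points are also
+ # recorded in _delete_mask.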
482
+ p = 2
483
+ if self.geometry._delete_mask is None:
484
+ self.geometry._delete_mask = torch.ones_like(self.geometry.ooi_masks[0])
485
+ for ooi_idx, ooi_mask in enumerate(self.geometry.ooi_masks):
486
+ threestudio.info(self.geometry.ooi_masks[ooi_idx].sum())
487
+ idx = torch.arange(len(ooi_mask), device='cuda')[ooi_mask.bool()]
488
+ lang_feat = self.geometry.get_language_feature[ooi_mask.bool()]
489
+ lang_feat = lang_feat / (lang_feat.norm(2, dim=-1, keepdim=True) + 0.1)
490
+
491
+ original_ooi_mask = ooi_mask.clone()
492
+ # filter with color by KMeans
493
+ kmeans = KMeans(n_init='auto', n_clusters=10)
494
+ kmeans.fit(lang_feat.detach().cpu())
495
+ labels = kmeans.labels_
496
+ cluster_sizes = [(labels == i).sum() for i in np.unique(labels)]
497
+ max_label = cluster_sizes.index(max(cluster_sizes))
498
+ dist = ((kmeans.cluster_centers_ - kmeans.cluster_centers_[max_label:max_label+1]) ** 2).sum(-1) ** 0.5
499
+
500
+ for label, num in enumerate(cluster_sizes):
501
+ if dist[label] > 0.3:
502
+ ooi_mask[idx[labels == label]] = False
503
+ self.geometry._delete_mask[idx[labels == label]] = 0.
504
+
505
+ p = 1
506
+ # filter with color by Gaussian
507
+ mean, std = lang_feat.mean(0), lang_feat.std(0)
508
+ outlier = torch.logical_or(lang_feat < mean - p * std, lang_feat > mean + p * std).sum(-1) > 0
509
+ ooi_mask[idx[outlier]] = False
510
+ self.geometry._delete_mask[idx[outlier]] = 0.
511
+
512
+ p = 3
513
+ # filter with RGB by Gaussian
514
+ rgb = self.geometry.get_features[original_ooi_mask.bool()][:, 0]
515
+ mean, std = rgb.mean(0), rgb.std(0)
516
+ outlier = torch.logical_or(rgb < mean - p * std, rgb > mean + p * std).sum(-1) > 0
517
+ ooi_mask[idx[outlier]] = False
518
+ self.geometry._delete_mask[idx[outlier]] = 0.
519
+
520
+ def load_state_dict(self, state_dict, strict=True):
521
+ i = 0
522
+ while True:
523
+ if f'geometry.ooi_masks_{i}' not in state_dict.keys():
524
+ break
525
+ self.geometry.register_buffer(f'ooi_masks_{i}', state_dict[f'geometry.ooi_masks_{i}'])
526
+ i += 1
527
+ self.geometry.register_buffer('_delete_mask', state_dict['geometry._delete_mask'])
528
+ return super().load_state_dict(state_dict, strict)
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/utils/__init__.py ADDED
File without changes
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/utils/ae.py ADDED
@@ -0,0 +1,63 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.data import Dataset
6
+
7
+ class Autoencoder_dataset(Dataset):
8
+ def __init__(self, data):
9
+ self.data = data
10
+
11
+ def __getitem__(self, index):
12
+ data = torch.tensor(self.data[index])
13
+ return data
14
+
15
+ def __len__(self):
16
+ return self.data.shape[0]
17
+
18
+
19
+ class Autoencoder(nn.Module):
20
+ def __init__(self, encoder_hidden_dims, decoder_hidden_dims):
21
+ super(Autoencoder, self).__init__()
22
+ encoder_layers = []
23
+ for i in range(len(encoder_hidden_dims)):
24
+ if i == 0:
25
+ encoder_layers.append(nn.Linear(512, encoder_hidden_dims[i]))
26
+ else:
27
+ encoder_layers.append(torch.nn.GroupNorm(2, encoder_hidden_dims[i-1]))
28
+ # encoder_layers.append(torch.nn.BatchNorm1d(encoder_hidden_dims[i-1]))
29
+ encoder_layers.append(nn.ReLU())
30
+ encoder_layers.append(nn.Linear(encoder_hidden_dims[i-1], encoder_hidden_dims[i]))
31
+ self.encoder = nn.ModuleList(encoder_layers)
32
+
33
+ decoder_layers = []
34
+ for i in range(len(decoder_hidden_dims)):
35
+ if i == 0:
36
+ decoder_layers.append(nn.Linear(encoder_hidden_dims[-1], decoder_hidden_dims[i]))
37
+ else:
38
+ decoder_layers.append(torch.nn.GroupNorm(2, decoder_hidden_dims[i-1]))
39
+ # decoder_layers.append(torch.nn.BatchNorm1d(decoder_hidden_dims[i-1]))
40
+ decoder_layers.append(nn.ReLU())
41
+ decoder_layers.append(nn.Linear(decoder_hidden_dims[i-1], decoder_hidden_dims[i]))
42
+ self.decoder = nn.ModuleList(decoder_layers)
43
+
44
+ def forward(self, x):
45
+ for m in self.encoder:
46
+ x = m(x)
47
+ x = x / x.norm(2, dim=-1, keepdim=True)
48
+ for m in self.decoder:
49
+ x = m(x)
50
+ # x = x / x.norm(2, dim=-1, keepdim=True)
51
+ return x
52
+
53
+ def encode(self, x):
54
+ for m in self.encoder:
55
+ x = m(x)
56
+ x = x / x.norm(2, dim=-1, keepdim=True)
57
+ return x
58
+
59
+ def decode(self, x):
60
+ for m in self.decoder:
61
+ x = m(x)
62
+ # x = x / x.norm(2, dim=-1, keepdim=True)
63
+ return x
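+ # Hypothetical usage sketch (dims are illustrative, not the project config):
+ #   ae = Autoencoder(encoder_hidden_dims=[256, 128, 32, 3],
+ #                    decoder_hidden_dims=[16, 64, 256, 512])
+ #   code = ae.encode(clip_feats)   # (N, 512) -> (N, 3), L2-normalized
+ #   recon = ae.decode(code)        # (N, 3) -> (N, 512)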
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/custom/threestudio-3dgs/utils/sam_clip.py ADDED
@@ -0,0 +1,366 @@
1
+ from dataclasses import dataclass, field
2
+ import pytorch_lightning as pl
3
+ from threestudio.utils.config import parse_structured
4
+ from threestudio.utils.base import Updateable, update_if_possible
5
+ from threestudio.utils.saving import SaverMixin
6
+ from threestudio.utils.typing import *
7
+
8
+ import open_clip
9
+ import torch
10
+ import torchvision
11
+ from torch import nn
12
+ import cv2
13
+ import numpy as np
14
+ from sklearn.decomposition import PCA
15
+
16
+ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
17
+ from mobile_sam import sam_model_registry as m_sam_model_registry
18
+ from mobile_sam import SamAutomaticMaskGenerator as m_SamAutomaticMaskGenerator
19
+ from mobile_sam import SamPredictor as m_SamPredictor
20
+
21
+ @dataclass
22
+ class OpenCLIPNetworkConfig:
23
+ _target: Type = field(default_factory=lambda: OpenCLIPNetwork)
24
+ clip_model_type: str = "ViT-B-16"
25
+ clip_model_pretrained: str = "laion2b_s34b_b88k"
26
+ clip_n_dims: int = 512
27
+ negatives: Tuple[str] = ("object", "things", "stuff", "texture")
28
+ positives: Tuple[str] = ("",)
29
+
30
+ class OpenCLIPNetwork(nn.Module):
31
+ def __init__(self, config: OpenCLIPNetworkConfig):
32
+ super().__init__()
33
+ self.config = config
34
+ self.process = torchvision.transforms.Compose(
35
+ [
36
+ torchvision.transforms.Resize((224, 224)),
37
+ torchvision.transforms.Normalize(
38
+ mean=[0.48145466, 0.4578275, 0.40821073],
39
+ std=[0.26862954, 0.26130258, 0.27577711],
40
+ ),
41
+ ]
42
+ )
43
+ model, _, _ = open_clip.create_model_and_transforms(
44
+ self.config.clip_model_type, # e.g., ViT-B-16
45
+ pretrained=self.config.clip_model_pretrained, # e.g., laion2b_s34b_b88k
46
+ precision="fp16",
47
+ )
48
+ model.eval()
49
+ self.tokenizer = open_clip.get_tokenizer(self.config.clip_model_type)
50
+ self.model = model.to("cuda")
51
+ self.clip_n_dims = self.config.clip_n_dims
52
+
53
+ self.positives = self.config.positives
54
+ self.negatives = self.config.negatives
55
+ with torch.no_grad():
56
+ tok_phrases = torch.cat([self.tokenizer(phrase) for phrase in self.positives]).to("cuda")
57
+ self.pos_embeds = model.encode_text(tok_phrases)
58
+ tok_phrases = torch.cat([self.tokenizer(phrase) for phrase in self.negatives]).to("cuda")
59
+ self.neg_embeds = model.encode_text(tok_phrases)
60
+ self.pos_embeds /= self.pos_embeds.norm(dim=-1, keepdim=True)
61
+ self.neg_embeds /= self.neg_embeds.norm(dim=-1, keepdim=True)
62
+
63
+ assert (
64
+ self.pos_embeds.shape[1] == self.neg_embeds.shape[1]
65
+ ), "Positive and negative embeddings must have the same dimensionality"
66
+ assert (
67
+ self.pos_embeds.shape[1] == self.clip_n_dims
68
+ ), "Embedding dimensionality must match the model dimensionality"
69
+
70
+ @property
71
+ def name(self) -> str:
72
+ return "openclip_{}_{}".format(self.config.clip_model_type, self.config.clip_model_pretrained)
73
+
74
+ @property
75
+ def embedding_dim(self) -> int:
76
+ return self.config.clip_n_dims
77
+
78
+ def gui_cb(self,element):
79
+ self.set_positives(element.value.split(";"))
80
+
81
+ def set_positives(self, text_list):
82
+ self.positives = text_list
83
+ with torch.no_grad():
84
+ tok_phrases = torch.cat([self.tokenizer(phrase) for phrase in self.positives]).to("cuda")
85
+ self.pos_embeds = self.model.encode_text(tok_phrases)
86
+ self.pos_embeds /= self.pos_embeds.norm(dim=-1, keepdim=True)
87
+
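+ # LERF-style relevancy: pair the positive phrase with every negative, take a
+ # temperature-10 softmax over each pair, and return the pair in which the
+ # positive is least confident (a conservative relevancy score).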
88
+ def get_relevancy(self, embed: torch.Tensor, positive_id: int) -> torch.Tensor:
89
+ phrases_embeds = torch.cat([self.pos_embeds, self.neg_embeds], dim=0)
90
+ p = phrases_embeds.to(embed.dtype) # phrases x 512
91
+ output = torch.mm(embed, p.T) # rays x phrases
92
+ positive_vals = output[..., positive_id : positive_id + 1] # rays x 1
93
+ negative_vals = output[..., len(self.positives) :] # rays x N_phrase
94
+ repeated_pos = positive_vals.repeat(1, len(self.negatives)) # rays x N_phrase
95
+
96
+ sims = torch.stack((repeated_pos, negative_vals), dim=-1) # rays x N-phrase x 2
97
+ softmax = torch.softmax(10 * sims, dim=-1) # rays x n-phrase x 2
98
+ best_id = softmax[..., 0].argmin(dim=1) # rays x 2
99
+ return torch.gather(softmax, 1, best_id[..., None, None].expand(best_id.shape[0], len(self.negatives), 2))[:, 0, :]
100
+
101
+ def encode_image(self, input):
102
+ processed_input = self.process(input).half()
103
+ return self.model.encode_image(processed_input)
104
+
105
+ def get_seg_img(mask, image):
106
+ image = image.copy()
107
+ image[mask['segmentation']==0] = np.array([0, 0, 0], dtype=np.uint8)
108
+ x,y,w,h = np.int32(mask['bbox'])
109
+ seg_img = image[y:y+h, x:x+w, ...]
110
+ return seg_img
111
+
112
+ def pad_img(img):
113
+ h, w, _ = img.shape
114
+ l = max(w,h)
115
+ pad = np.zeros((l,l,3), dtype=np.uint8)
116
+ if h > w:
117
+ pad[:,(h-w)//2:(h-w)//2 + w, :] = img
118
+ else:
119
+ pad[(w-h)//2:(w-h)//2 + h, :, :] = img
120
+ return pad
121
+
122
+ def filter(keep: torch.Tensor, masks_result) -> list:
123
+ keep = keep.int().cpu().numpy()
124
+ result_keep = []
125
+ for i, m in enumerate(masks_result):
126
+ if i in keep: result_keep.append(m)
127
+ return result_keep
128
+
129
+ def sava_numpy(save_path, data):
130
+ save_path_s = save_path + '_s.npy'
131
+ save_path_f = save_path + '_f.npy'
132
+ np.save(save_path_s, data['seg_maps'].numpy())
133
+ np.save(save_path_f, data['feature'].numpy())
134
+
135
+ def mask_nms(masks, scores, iou_thr=0.7, score_thr=0.1, inner_thr=0.2, **kwargs):
136
+ """
137
+ Perform mask non-maximum suppression (NMS) on a set of masks based on their scores.
138
+
139
+ Args:
140
+ masks (torch.Tensor): has shape (num_masks, H, W)
141
+ scores (torch.Tensor): The scores of the masks, has shape (num_masks,)
142
+ iou_thr (float, optional): The threshold for IoU.
143
+ score_thr (float, optional): The threshold for the mask scores.
144
+ inner_thr (float, optional): The threshold for the overlap rate.
145
+ **kwargs: Additional keyword arguments.
146
+ Returns:
147
+ selected_idx (torch.Tensor): A tensor representing the selected indices of the masks after NMS.
148
+ """
149
+
150
+ scores, idx = scores.sort(0, descending=True)
151
+ num_masks = idx.shape[0]
152
+
153
+ masks_ord = masks[idx.view(-1), :]
154
+ masks_area = torch.sum(masks_ord, dim=(1, 2), dtype=torch.float)
155
+
156
+ iou_matrix = torch.zeros((num_masks,) * 2, dtype=torch.float, device=masks.device)
157
+ inner_iou_matrix = torch.zeros((num_masks,) * 2, dtype=torch.float, device=masks.device)
158
+ for i in range(num_masks):
159
+ for j in range(i, num_masks):
160
+ intersection = torch.sum(torch.logical_and(masks_ord[i], masks_ord[j]), dtype=torch.float)
161
+ union = torch.sum(torch.logical_or(masks_ord[i], masks_ord[j]), dtype=torch.float)
162
+ iou = intersection / union
163
+ iou_matrix[i, j] = iou
164
+ # flag pairs where one mask is largely contained inside the other
165
+ if intersection / masks_area[i] < 0.5 and intersection / masks_area[j] >= 0.85:
166
+ inner_iou = 1 - (intersection / masks_area[j]) * (intersection / masks_area[i])
167
+ inner_iou_matrix[i, j] = inner_iou
168
+ if intersection / masks_area[i] >= 0.85 and intersection / masks_area[j] < 0.5:
169
+ inner_iou = 1 - (intersection / masks_area[j]) * (intersection / masks_area[i])
170
+ inner_iou_matrix[j, i] = inner_iou
171
+
172
+ iou_matrix.triu_(diagonal=1)
173
+ iou_max, _ = iou_matrix.max(dim=0)
174
+ inner_iou_matrix_u = torch.triu(inner_iou_matrix, diagonal=1)
175
+ inner_iou_max_u, _ = inner_iou_matrix_u.max(dim=0)
176
+ inner_iou_matrix_l = torch.tril(inner_iou_matrix, diagonal=1)
177
+ inner_iou_max_l, _ = inner_iou_matrix_l.max(dim=0)
178
+
179
+ keep = iou_max <= iou_thr
180
+ keep_conf = scores > score_thr
181
+ keep_inner_u = inner_iou_max_u <= 1 - inner_thr
182
+ keep_inner_l = inner_iou_max_l <= 1 - inner_thr
183
+
184
+ # If there are no masks with scores above threshold, the top 3 masks are selected
185
+ if keep_conf.sum() == 0:
186
+ index = scores.topk(3).indices
187
+ keep_conf[index] = True
188
+ if keep_inner_u.sum() == 0:
189
+ index = scores.topk(3).indices
190
+ keep_inner_u[index] = True
191
+ if keep_inner_l.sum() == 0:
192
+ index = scores.topk(3).indices
193
+ keep_inner_l[index] = True
194
+ keep *= keep_conf
195
+ keep *= keep_inner_u
196
+ keep *= keep_inner_l
197
+
198
+ selected_idx = idx[keep]
199
+ return selected_idx
200
+
201
+ def masks_update(*args, **kwargs):
202
+ # remove redundant masks based on the scores and overlap rate between masks
203
+ masks_new = ()
204
+ for masks_lvl in (args):
205
+ seg_pred = torch.from_numpy(np.stack([m['segmentation'] for m in masks_lvl], axis=0))
206
+ iou_pred = torch.from_numpy(np.stack([m['predicted_iou'] for m in masks_lvl], axis=0))
207
+ stability = torch.from_numpy(np.stack([m['stability_score'] for m in masks_lvl], axis=0))
208
+
209
+ scores = stability * iou_pred
210
+ keep_mask_nms = mask_nms(seg_pred, scores, **kwargs)
211
+ masks_lvl = filter(keep_mask_nms, masks_lvl)
212
+
213
+ masks_new += (masks_lvl,)
214
+ return masks_new
215
+
216
+ def sam_encoder(image, mask_generator):
217
+ image = image.detach().cpu()
218
+ image = cv2.cvtColor(image[0].permute(1,2,0).numpy().astype(np.uint8), cv2.COLOR_BGR2RGB)
219
+ # pre-compute masks
220
+ masks_l = mask_generator.generate(image)
221
+ # pre-compute postprocess
222
+ masks_l = masks_update(masks_l, iou_thr=0.8, score_thr=0.7, inner_thr=0.5)[0]
223
+
224
+ def mask2segmap(masks, image):
225
+ seg_img_list = []
226
+ seg_map = -np.ones(image.shape[:2], dtype=np.int32)
227
+ for i in range(len(masks)):
228
+ mask = masks[i]
229
+ seg_img = get_seg_img(mask, image)
230
+ pad_seg_img = cv2.resize(pad_img(seg_img), (224,224))
231
+ seg_img_list.append(pad_seg_img)
232
+
233
+ seg_map[masks[i]['segmentation']] = i
234
+ seg_imgs = np.stack(seg_img_list, axis=0) # b,H,W,3
235
+ seg_imgs = (torch.from_numpy(seg_imgs.astype("float32")).permute(0,3,1,2) / 255.0).to('cuda')
236
+
237
+ return seg_imgs, seg_map
238
+
239
+ seg_images, seg_maps = {}, {}
240
+ seg_images['l'], seg_maps['l'] = mask2segmap(masks_l, image)
241
+
242
+ # 0:default 1:s 2:m 3:l
243
+ return seg_images, seg_maps
244
+
245
+ class SamClip(pl.LightningModule, Updateable, SaverMixin):
246
+ @dataclass
247
+ class Config:
248
+ clip_model_type: str = "ViT-B-16"
249
+ clip_model_pretrained: str = "laion2b_s34b_b88k"
250
+ clip_n_dims: int = 512
251
+ sam_ckpt_path: str = "ckpts/sam_vit_h_4b8939.pth"
252
+ feature_level: int = 3
253
+ vis_pca_feature: bool = True
254
+ use_mobile_sam: bool = True
255
+
256
+ cfg: Config
257
+
258
+ def __init__(self, cfg) -> None:
259
+ super().__init__()
260
+ self.cfg = parse_structured(self.Config, cfg)
261
+ self.model = OpenCLIPNetwork(OpenCLIPNetworkConfig)
262
+ self.clip_n_dims = self.cfg.clip_n_dims
263
+ self.tokenizer = open_clip.get_tokenizer(self.cfg.clip_model_type)
264
+ sam = sam_model_registry["vit_h"](checkpoint=self.cfg.sam_ckpt_path).to('cuda')
265
+ self.mask_generator = SamAutomaticMaskGenerator(
266
+ model=sam,
267
+ points_per_side=32,
268
+ points_per_batch=64,
269
+ pred_iou_thresh=0.7,
270
+ box_nms_thresh=0.7,
271
+ stability_score_thresh=0.85,
272
+ crop_n_layers=1,
273
+ crop_n_points_downscale_factor=1,
274
+ min_mask_region_area=100,
275
+ )
276
+
277
+ model_type = "vit_t"
278
+ sam_checkpoint = "./ckpts/mobile_sam.pt"
279
+ device = "cuda" if torch.cuda.is_available() else "cpu"
280
+ mobile_sam = m_sam_model_registry[model_type](checkpoint=sam_checkpoint)
281
+ mobile_sam.to(device=device)
282
+ mobile_sam.eval()
283
+ self.m_mask_generator = m_SamAutomaticMaskGenerator(mobile_sam)
284
+
285
+ # self.estimator = PCA(n_components=3)
286
+ # self.has_fit = False
287
+
288
+ self.mask_generator.predictor.model.to('cuda')
289
+ self.m_mask_generator.predictor.model.to('cuda')
290
+
291
+ def _embed_clip_sam_tiles(self, image, sam_encoder):
292
+ aug_imgs = torch.cat([image])
293
+ if self.cfg.use_mobile_sam:
294
+ seg_images, seg_map = sam_encoder(aug_imgs, self.m_mask_generator)
295
+ else:
296
+ seg_images, seg_map = sam_encoder(aug_imgs, self.mask_generator)
297
+
298
+ clip_embeds = {}
299
+ # types = ['default', 's', 'm', 'l']
300
+ types = ['l']
301
+ for mode in types:
302
+ tiles = seg_images[mode]
303
+ tiles = tiles.to("cuda")
304
+ with torch.no_grad():
305
+ clip_embed = self.model.encode_image(tiles)
306
+ clip_embed /= clip_embed.norm(dim=-1, keepdim=True)
307
+ clip_embeds[mode] = clip_embed.detach().cpu().half()
308
+
309
+ return clip_embeds, seg_map
310
+
311
+ def forward(self, img):
312
+ embed_size = 512
313
+ seg_maps = []
314
+ total_lengths = []
315
+ timer = 0
316
+ img_embeds = torch.zeros((len(img), 100, embed_size))
317
+
318
+ seg_maps = torch.zeros((len(img), 1, *img.shape[2:]))
319
+ img_embed, seg_map = self._embed_clip_sam_tiles(img, sam_encoder)
320
+
321
+ lengths = [len(v) for k, v in img_embed.items()]
322
+ total_length = sum(lengths)
323
+ # total_lengths.append(total_length)
324
+
325
+ # if total_length > img_embeds.shape[1]:
326
+ # pad = total_length - img_embeds.shape[1]
327
+ # img_embeds = torch.cat([
328
+ # img_embeds,
329
+ # torch.zeros((len(image_list), pad, embed_size))
330
+ # ], dim=1)
331
+
332
+ # img_embed = torch.cat([v for k, v in img_embed.items()], dim=0)
333
+ # assert img_embed.shape[0] == total_length
334
+ img_embeds[0, :total_length] = img_embed['l']
335
+
336
+ # seg_map_tensor = []
337
+ # lengths_cumsum = lengths.copy()
338
+ # for j in range(1, len(lengths)):
339
+ # lengths_cumsum[j] += lengths_cumsum[j-1]
340
+ # for j, (k, v) in enumerate(seg_map.items()):
341
+ # if j == 0:
342
+ # seg_map_tensor.append(torch.from_numpy(v))
343
+ # continue
344
+ # assert v.max() == lengths[j] - 1, f"{j}, {v.max()}, {lengths[j]-1}"
345
+ # v[v != -1] += lengths_cumsum[j-1]
346
+ # seg_map_tensor.append(torch.from_numpy(v))
347
+ # seg_map = torch.stack(seg_map_tensor, dim=0)
348
+ seg_maps[0] = torch.from_numpy(seg_map['l'])
349
+
350
+ # self.mask_generator.predictor.model.to('cpu')
351
+ feature_map = img_embeds[0] # (100, 512)
352
+ seg_map = seg_maps[0] # (1, H, W)
353
+
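+ # Scatter the per-mask CLIP embeddings back onto the image grid: each pixel
+ # looks up its mask id in seg_map and takes that mask's embedding row;
+ # `mask` marks pixels covered by at least one SAM mask.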
354
+ image_height, image_width = seg_map.shape[1:]
355
+ y, x = torch.meshgrid(torch.arange(0, image_height), torch.arange(0, image_width), indexing="ij")
356
+ x = x.reshape(-1, 1)
357
+ y = y.reshape(-1, 1)
358
+ seg = seg_map[:, y, x].squeeze(-1).long()
359
+ mask = seg != -1
360
+ point_feature1 = feature_map[seg[:]].squeeze(0)
361
+ mask = mask[:].reshape(1, image_height, image_width)
362
+ return img_embed['l'], seg, mask
363
+ # point_feature = point_feature1.reshape(image_height, image_width, -1).permute(2, 0, 1)
364
+
365
+ # return img_embed['l'], point_feature, mask
366
+
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/bear_background.png ADDED

Git LFS Details

  • SHA256: 950496f640077d2d1b3f28cf8f2ecaeb56bc641b2c19f6a8107e6d428f5da17f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/bear_composite.png ADDED

Git LFS Details

  • SHA256: 1445582663ddb516915adbfac9f33ba2d95e554d76a1f4b164ef9f119061be74
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/bear_layers.png ADDED
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/boy_background.png ADDED

Git LFS Details

  • SHA256: e7221341ebcc6084cf6ef9521324bea45658cebb9a3a4de487ef0f17bd83235a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.26 MB
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/boy_composite.png ADDED

Git LFS Details

  • SHA256: 4f9a04bee8f5de415251558a4401c7f597ee3e6c2dc989ddf974a9104243c8dc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/boy_layers.png ADDED
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/corgi_background.png ADDED

Git LFS Details

  • SHA256: 2e7c7c2ab126d4d26c2258160d859e03291ef745fae2e05540ac60bc8976e7d9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/corgi_composite.png ADDED

Git LFS Details

  • SHA256: 6c00e3156ad6e929df5abe236b7d98772b13142551e740866317e5e829f3bf03
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
000000000017.1/gs-sds-generation/3DitScene@20250207-015119/code/examples/corgi_layers.png ADDED