Paul Engstler committed
Commit 84eee5b · 0 Parent(s):

first commit

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +35 -0
  2. .gitignore +148 -0
  3. README.md +10 -0
  4. app.py +257 -0
  5. examples/photo-1469559845082-95b66baaf023.jpeg +0 -0
  6. examples/photo-1499916078039-922301b0eb9b.jpeg +0 -0
  7. examples/photo-1514984879728-be0aff75a6e8.jpeg +0 -0
  8. examples/photo-1546975490-e8b92a360b24.jpeg +0 -0
  9. examples/photo-1618197345638-d2df92b39fe1.jpeg +0 -0
  10. examples/photo-1628624747186-a941c476b7ef.jpeg +0 -0
  11. examples/photo-1667788000333-4e36f948de9a.jpeg +0 -0
  12. packages.txt +1 -0
  13. pre-requirements.txt +0 -0
  14. requirements.txt +26 -0
  15. utils/demo.py +54 -0
  16. utils/gaussian_renderer/__init__.py +100 -0
  17. utils/gaussian_renderer/network_gui.py +86 -0
  18. utils/gs.py +196 -0
  19. utils/models.py +119 -0
  20. utils/ops.py +95 -0
  21. utils/render.py +112 -0
  22. utils/scene/__init__.py +92 -0
  23. utils/scene/cameras.py +76 -0
  24. utils/scene/colmap_loader.py +294 -0
  25. utils/scene/dataset_readers.py +270 -0
  26. utils/scene/gaussian_model.py +416 -0
  27. utils/scene/utils/camera_utils.py +84 -0
  28. utils/scene/utils/general_utils.py +133 -0
  29. utils/scene/utils/graphics_utils.py +88 -0
  30. utils/scene/utils/image_utils.py +19 -0
  31. utils/scene/utils/loss_utils.py +65 -0
  32. utils/scene/utils/sh_utils.py +118 -0
  33. utils/scene/utils/system_utils.py +28 -0
  34. zoedepth/LICENSE +21 -0
  35. zoedepth/data/__init__.py +24 -0
  36. zoedepth/data/data_mono.py +697 -0
  37. zoedepth/data/ddad.py +117 -0
  38. zoedepth/data/diml_indoor_test.py +125 -0
  39. zoedepth/data/diml_outdoor_test.py +114 -0
  40. zoedepth/data/diode.py +125 -0
  41. zoedepth/data/hypersim.py +138 -0
  42. zoedepth/data/ibims.py +81 -0
  43. zoedepth/data/marigold_nyu.py +112 -0
  44. zoedepth/data/places365.py +118 -0
  45. zoedepth/data/preprocess.py +154 -0
  46. zoedepth/data/sun_rgbd_loader.py +106 -0
  47. zoedepth/data/transforms.py +481 -0
  48. zoedepth/data/vkitti.py +151 -0
  49. zoedepth/data/vkitti2.py +187 -0
  50. zoedepth/models/__init__.py +24 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,148 @@
+ *.png
+ **.gif
+ .vscode/
+ *.rdb
+ **.xml
+ wandb/
+ slurm/
+ tmp/
+ .logs/
+ checkpoints/
+ external_jobs/
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ ptlflow_logs/
+ output/
+ log/
+ .idea/
+ # C extensions
+ *.so
+ results/
+ **.DS_Store
+ **.pt
+ demo/
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+ ~shortcuts/
+ **/wandb_logs/
+ **.db
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ title: Invisible Stitch
+ emoji: 🪡
+ colorFrom: pink
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 4.27.0
+ app_file: app.py
+ pinned: false
+ ---
app.py ADDED
@@ -0,0 +1,257 @@
+ import spaces
+ import os
+
+ # this is a HF Spaces specific hack, as
+ # (i) building pytorch3d with GPU support is a bit tricky here
+ # (ii) installing the wheel via requirements.txt breaks ZeroGPU
+ os.system("pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt221/download.html")
+
+ import torch
+ import torch.nn.functional as F
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+ import skimage
+ from PIL import Image
+
+ import gradio as gr
+
+ from utils.render import PointsRendererWithMasks, render
+ from utils.ops import snap_high_gradients_to_nn, project_points, get_pointcloud, merge_pointclouds, outpaint_with_depth_estimation
+ from utils.gs import gs_options, read_cameras_from_optimization_bundle, Scene, run_gaussian_splatting, get_blank_gs_bundle
+
+ from pytorch3d.utils import opencv_from_cameras_projection
+ from utils.ops import focal2fov, fov2focal
+ from utils.models import infer_with_zoe_dc
+ from utils.scene import GaussianModel
+ from utils.demo import downsample_point_cloud
+ from typing import Iterable, Tuple, Dict, Optional
+ import itertools
+
+ from pytorch3d.structures import Pointclouds
+ from pytorch3d.renderer import (
+     look_at_view_transform,
+     PerspectiveCameras,
+ )
+
+ from pytorch3d.io import IO
+
+ # note: this duplicates the helper already imported from utils.gs above
+ def get_blank_gs_bundle(h, w):
+     return {
+         "camera_angle_x": focal2fov(torch.tensor([w], dtype=torch.float32), w),
+         "W": w,
+         "H": h,
+         "pcd_points": None,
+         "pcd_colors": None,
+         'frames': [],
+     }
+
+ @spaces.GPU(duration=30)
+ def extrapolate_point_cloud(prompt: str, image_size: Tuple[int, int], look_at_params: Iterable[Tuple[float, float, float, Tuple[float, float, float]]], point_cloud: Pointclouds = None, dry_run: bool = False, discard_mask: bool = False, initial_image: Optional[Image.Image] = None, depth_scaling: float = 1, **render_kwargs):
+     w, h = image_size
+     optimization_bundle_frames = []
+
+     for azim, elev, dist, at in look_at_params:
+         R, T = look_at_view_transform(device=device, azim=azim, elev=elev, dist=dist, at=at)
+         cameras = PerspectiveCameras(R=R, T=T, focal_length=torch.tensor([w], dtype=torch.float32), principal_point=(((h-1)/2, (w-1)/2),), image_size=(image_size,), device=device, in_ndc=False)
+
+         if point_cloud is not None:
+             images, masks, depths = render(cameras, point_cloud, **render_kwargs)
+
+             if not dry_run:
+                 eroded_mask = skimage.morphology.binary_erosion((depths[0] > 0).cpu().numpy(), footprint=None)  # a disk footprint (e.g. skimage.morphology.disk(1)) would erode more aggressively
+                 eroded_depth = depths[0].clone()
+                 eroded_depth[torch.from_numpy(eroded_mask).to(depths.device) <= 0] = 0
+
+                 outpainted_img, aligned_depth = outpaint_with_depth_estimation(images[0], masks[0], eroded_depth, h, w, pipe, zoe_dc_model, prompt, cameras, dilation_size=2, depth_scaling=depth_scaling, generator=torch.Generator(device=pipe.device).manual_seed(0))
+
+                 aligned_depth = torch.from_numpy(aligned_depth).to(device)
+
+             else:
+                 # in a dry run, we do not actually outpaint the image
+                 outpainted_img = Image.fromarray((255*images[0].cpu().numpy()).astype(np.uint8))
+
+         else:
+             assert initial_image is not None
+             assert not dry_run
+
+             # jumpstart the point cloud with a regular depth estimation
+             t_initial_image = torch.from_numpy(np.asarray(initial_image)/255.).permute(2,0,1).float()
+             depth = aligned_depth = infer_with_zoe_dc(zoe_dc_model, t_initial_image, torch.zeros(h, w))
+             outpainted_img = initial_image
+             images = [t_initial_image.to(device)]
+             masks = [torch.ones(h, w, dtype=torch.bool).to(device)]
+
+         if not dry_run:
+             # snap high gradients to their nearest neighbor, which eliminates "noodle" artifacts at depth discontinuities
+             aligned_depth = snap_high_gradients_to_nn(aligned_depth, threshold=12).cpu()
+             xy_depth_world = project_points(cameras, aligned_depth)
+
+         c2w = cameras.get_world_to_view_transform().get_matrix()[0]
+
+         optimization_bundle_frames.append({
+             "image": outpainted_img,
+             "mask": masks[0].cpu().numpy(),
+             "transform_matrix": c2w.tolist(),
+             "azim": azim,
+             "elev": elev,
+             "dist": dist,
+         })
+
+         if discard_mask:
+             optimization_bundle_frames[-1].pop("mask")
+
+         if not dry_run:
+             optimization_bundle_frames[-1]["center_point"] = xy_depth_world[0].mean(dim=0).tolist()
+             optimization_bundle_frames[-1]["depth"] = aligned_depth.cpu().numpy()
+             optimization_bundle_frames[-1]["mean_depth"] = aligned_depth.mean().item()
+
+         else:
+             # in a dry run, we do not modify the point cloud
+             continue
+
+         rgb = (torch.from_numpy(np.asarray(outpainted_img).copy()).reshape(-1, 3).float() / 255).to(device)
+
+         if point_cloud is None:
+             point_cloud = get_pointcloud(xy_depth_world[0], device=device, features=rgb)
+
+         else:
+             # pytorch3d's mask might be slightly too big (subpixels), so we erode it a little to avoid seams;
+             # in theory, 1 pixel is sufficient, but we use 2 to be safe
+             masks[0] = torch.from_numpy(skimage.morphology.binary_erosion(masks[0].cpu().numpy(), footprint=skimage.morphology.disk(2))).to(device)
+
+             partial_outpainted_point_cloud = get_pointcloud(xy_depth_world[0][~masks[0].view(-1)], device=device, features=rgb[~masks[0].view(-1)])
+
+             point_cloud = merge_pointclouds([point_cloud, partial_outpainted_point_cloud])
+
+     return optimization_bundle_frames, point_cloud
+
+ @spaces.GPU(duration=30)
+ def generate_point_cloud(initial_image: Image.Image, prompt: str):
+     image_size = initial_image.size
+     w, h = image_size
+
+     optimization_bundle = get_blank_gs_bundle(h, w)
+
+     step_size = 25
+
+     azim_steps = [0, step_size, -step_size]
+     look_at_params = [(azim, 0, 0.01, torch.zeros((1, 3))) for azim in azim_steps]
+
+     optimization_bundle["frames"], point_cloud = extrapolate_point_cloud(prompt, image_size, look_at_params, discard_mask=True, initial_image=initial_image, depth_scaling=0.5, fill_point_cloud_holes=True)
+
+     optimization_bundle["pcd_points"] = point_cloud.points_padded()[0].cpu().numpy()
+     optimization_bundle["pcd_colors"] = point_cloud.features_padded()[0].cpu().numpy()
+
+     return optimization_bundle, point_cloud
+
+ @spaces.GPU(duration=30)
+ def supplement_point_cloud(optimization_bundle: Dict, point_cloud: Pointclouds, prompt: str):
+     w, h = optimization_bundle["W"], optimization_bundle["H"]
+
+     supporting_frames = []
+
+     for i, frame in enumerate(optimization_bundle["frames"]):
+         # skip supporting views
+         if frame.get("supporting", False):
+             continue
+
+         center_point = torch.tensor(frame["center_point"]).to(device)
+         mean_depth = frame["mean_depth"]
+         azim, elev = frame["azim"], frame["elev"]
+
+         azim_jitters = torch.linspace(-5, 5, 3).tolist()
+         elev_jitters = torch.linspace(-5, 5, 3).tolist()
+
+         # build the product of azim and elev jitters
+         camera_jitters = [{"azim": azim + azim_jitter, "elev": elev + elev_jitter} for azim_jitter, elev_jitter in itertools.product(azim_jitters, elev_jitters)]
+
+         look_at_params = [(camera_jitter["azim"], camera_jitter["elev"], mean_depth, center_point.unsqueeze(0)) for camera_jitter in camera_jitters]
+
+         local_supporting_frames, point_cloud = extrapolate_point_cloud(prompt, (w, h), look_at_params, point_cloud, dry_run=True, depth_scaling=0.5, antialiasing=3)
+
+         for local_supporting_frame in local_supporting_frames:
+             local_supporting_frame["supporting"] = True
+
+         supporting_frames.extend(local_supporting_frames)
+
+     optimization_bundle["pcd_points"] = point_cloud.points_padded()[0].cpu().numpy()
+     optimization_bundle["pcd_colors"] = point_cloud.features_padded()[0].cpu().numpy()
+
+     return optimization_bundle, point_cloud
+
+ @spaces.GPU(duration=30)
+ def generate_scene(img: Image.Image, prompt: str):
+     assert isinstance(img, Image.Image)
+
+     # resize image, maintaining the aspect ratio, so the longest side is 720 pixels
+     max_size = 720
+     img.thumbnail((max_size, max_size))
+
+     # crop to ensure the image dimensions are divisible by 8
+     img = img.crop((0, 0, img.width - img.width % 8, img.height - img.height % 8))
+
+     from hashlib import sha1
+     from datetime import datetime
+
+     run_id = sha1(datetime.now().isoformat().encode()).hexdigest()[:6]
+
+     run_name = f"gradio_{run_id}"
+
+     gs_optimization_bundle, point_cloud = generate_point_cloud(img, prompt)
+
+     downsampled_point_cloud = downsample_point_cloud(gs_optimization_bundle, device=device)
+
+     gs_optimization_bundle["pcd_points"] = downsampled_point_cloud.points_padded()[0].cpu().numpy()
+     gs_optimization_bundle["pcd_colors"] = downsampled_point_cloud.features_padded()[0].cpu().numpy()
+
+     scene = Scene(gs_optimization_bundle, GaussianModel(gs_options.sh_degree), gs_options)
+
+     scene.gaussians._opacity = torch.ones_like(scene.gaussians._opacity)
+     # scene = run_gaussian_splatting(scene, gs_optimization_bundle)
+
+     # coordinate system transformation
+     scene.gaussians._xyz = scene.gaussians._xyz.detach()
+     scene.gaussians._xyz[:, 1] = -scene.gaussians._xyz[:, 1]
+     scene.gaussians._xyz[:, 2] = -scene.gaussians._xyz[:, 2]
+
+     save_path = os.path.join("outputs", f"{run_name}.ply")
+
+     scene.gaussians.save_ply(save_path)
+
+     return save_path
+
+ if __name__ == "__main__":
+     global device
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     from utils.models import get_zoe_dc_model, get_sd_pipeline
+
+     global zoe_dc_model
+     from huggingface_hub import hf_hub_download
+     zoe_dc_model = get_zoe_dc_model(ckpt_path=hf_hub_download(repo_id="paulengstler/invisible-stitch", filename="invisible-stitch.pt")).to(device)
+
+     global pipe
+     pipe = get_sd_pipeline().to(device)
+
+     demo = gr.Interface(
+         fn=generate_scene,
+         inputs=[
+             gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil"),
+             gr.Textbox(label="Scene Hallucination Prompt")
+         ],
+         outputs=gr.Model3D(label="Generated Scene"),
+         allow_flagging="never",
+         title="Invisible Stitch: Generating Smooth 3D Scenes with Depth Inpainting",
+         description="Hallucinate geometrically coherent 3D scenes from a single input image in less than 30 seconds.<br /> [Project Page](https://research.paulengstler.com/invisible-stitch) | [GitHub](https://github.com/paulengstler/invisible-stitch) | [Paper](#) <br /><br />To keep this demo snappy, we have limited its functionality. Scenes are generated at a low resolution without densification, supporting views are not inpainted, and we do not optimize the resulting point cloud. Imperfections are to be expected, in particular around object borders. Please allow a couple of seconds for the generated scene to be downloaded (about 40 megabytes).",
+         article="Please consider running this demo locally to obtain high-quality results (see the GitHub repository).<br /><br />Here are some observations we made that might help you to get better results:<ul><li>Use generic prompts that match the surroundings of your input image.</li><li>Ensure that the borders of your input image are free from partially visible objects.</li><li>Keep your prompts simple and avoid adding specific details.</li></ul>",
+         examples=[
+             ["examples/photo-1667788000333-4e36f948de9a.jpeg", "a street with traditional buildings in Kyoto, Japan"],
+             ["examples/photo-1628624747186-a941c476b7ef.jpeg", "a suburban street in North Carolina on a bright, sunny day"],
+             ["examples/photo-1469559845082-95b66baaf023.jpeg", "a view of Zion National Park"],
+             ["examples/photo-1514984879728-be0aff75a6e8.jpeg", "a close-up view of a muddy path in a forest"],
+             ["examples/photo-1618197345638-d2df92b39fe1.jpeg", "a close-up view of a white linen bed in a minimalistic room"],
+             ["examples/photo-1546975490-e8b92a360b24.jpeg", "a warm living room with plants"],
+             ["examples/photo-1499916078039-922301b0eb9b.jpeg", "a cozy bedroom on a bright day"],
+         ])
+     demo.queue().launch(share=True)
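
The Gradio callback above can also be exercised programmatically. A minimal local-usage sketch, assuming the model setup in the `__main__` block has already run (so `device`, `zoe_dc_model`, and `pipe` exist), that an `outputs/` directory is present, and that the bundled example image is available:

from PIL import Image

image = Image.open("examples/photo-1546975490-e8b92a360b24.jpeg")
ply_path = generate_scene(image, "a warm living room with plants")
print(f"Gaussian splat scene written to {ply_path}")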
examples/photo-1469559845082-95b66baaf023.jpeg ADDED
examples/photo-1499916078039-922301b0eb9b.jpeg ADDED
examples/photo-1514984879728-be0aff75a6e8.jpeg ADDED
examples/photo-1546975490-e8b92a360b24.jpeg ADDED
examples/photo-1618197345638-d2df92b39fe1.jpeg ADDED
examples/photo-1628624747186-a941c476b7ef.jpeg ADDED
examples/photo-1667788000333-4e36f948de9a.jpeg ADDED
packages.txt ADDED
@@ -0,0 +1 @@
+ python3-dev
pre-requirements.txt ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,26 @@
+ datasets==2.19.0
+ diffusers==0.26.3
+ fire==0.5.0
+ gradio==4.27.0
+ h5py==3.10.0
+ huggingface_hub==0.22.2
+ imageio==2.33.1
+ jaxtyping==0.2.28
+ matplotlib==3.7.5
+ numpy==1.22.4
+ opencv_python==4.8.0.76
+ pandas==1.5.1
+ Pillow==10.3.0
+ plyfile==1.0.3
+ scipy==1.8.1
+ scikit-image
+ submitit==1.5.1
+ tqdm==4.66.1
+ trimesh==3.21.7
+ wandb==0.16.3
+ xformers==0.0.25
+ spaces
+ timm==0.6.7
+ transformers==4.37.2
+ accelerate==0.27.2
+ easydict
utils/demo.py ADDED
@@ -0,0 +1,54 @@
+ import copy
+ import torch
+ import numpy as np
+
+ import skimage
+ from pytorch3d.renderer import (
+     look_at_view_transform,
+     PerspectiveCameras,
+ )
+
+ from .render import render
+ from .ops import project_points, get_pointcloud, merge_pointclouds
+
+ def downsample_point_cloud(optimization_bundle, device="cpu"):
+     point_cloud = None
+
+     for i, frame in enumerate(optimization_bundle["frames"]):
+         if frame.get("supporting", False):
+             continue
+
+         downsampled_image = copy.deepcopy(frame["image"])
+         downsampled_image.thumbnail((360, 360))
+
+         image_size = downsampled_image.size
+         w, h = image_size
+
+         # regenerate the point cloud at a lower resolution
+         R, T = look_at_view_transform(device=device, azim=frame["azim"], elev=frame["elev"], dist=frame["dist"])
+         cameras = PerspectiveCameras(R=R, T=T, focal_length=torch.tensor([w], dtype=torch.float32), principal_point=(((h-1)/2, (w-1)/2),), image_size=(image_size,), device=device, in_ndc=False)
+
+         # downsample the depth
+         downsampled_depth = torch.nn.functional.interpolate(torch.tensor(frame["depth"]).unsqueeze(0).unsqueeze(0).float().to(device), size=(h, w), mode="nearest").squeeze()
+
+         xy_depth_world = project_points(cameras, downsampled_depth)
+
+         rgb = (torch.from_numpy(np.asarray(downsampled_image).copy()).reshape(-1, 3).float() / 255).to(device)
+
+         c2w = cameras.get_world_to_view_transform().get_matrix()[0]
+
+         if i == 0:
+             point_cloud = get_pointcloud(xy_depth_world[0], device=device, features=rgb)
+
+         else:
+             images, masks, depths = render(cameras, point_cloud, radius=1e-2)
+
+             # pytorch3d's mask might be slightly too big (subpixels), so we erode it by one pixel to avoid seams
+             masks[0] = torch.from_numpy(skimage.morphology.binary_erosion(masks[0].cpu().numpy(), footprint=skimage.morphology.disk(1))).to(device)
+
+             partial_outpainted_point_cloud = get_pointcloud(xy_depth_world[0][~masks[0].view(-1)], device=device, features=rgb[~masks[0].view(-1)])
+
+             point_cloud = merge_pointclouds([point_cloud, partial_outpainted_point_cloud])
+
+     return point_cloud
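
A hypothetical usage sketch: shrinking the reconstruction before Gaussian splatting, as `app.generate_scene` does. `optimization_bundle` is assumed to come from `generate_point_cloud`, so every non-supporting frame carries "image", "depth", "azim", "elev", and "dist" entries:

from utils.demo import downsample_point_cloud

small_pc = downsample_point_cloud(optimization_bundle, device="cuda")
print(small_pc.points_padded().shape)  # (1, N, 3), regenerated at thumbnail resolution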
utils/gaussian_renderer/__init__.py ADDED
@@ -0,0 +1,100 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact [email protected]
+ #
+
+ import torch
+ import math
+ from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
+ # relative imports within the utils package
+ from ..scene.gaussian_model import GaussianModel
+ from ..scene.utils.sh_utils import eval_sh
+
+ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
+     """
+     Render the scene.
+
+     Background tensor (bg_color) must be on GPU!
+     """
+
+     # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
+     screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
+     try:
+         screenspace_points.retain_grad()
+     except Exception:
+         pass
+
+     # Set up rasterization configuration
+     tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+     tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+
+     raster_settings = GaussianRasterizationSettings(
+         image_height=int(viewpoint_camera.image_height),
+         image_width=int(viewpoint_camera.image_width),
+         tanfovx=tanfovx,
+         tanfovy=tanfovy,
+         bg=bg_color,
+         scale_modifier=scaling_modifier,
+         viewmatrix=viewpoint_camera.world_view_transform,
+         projmatrix=viewpoint_camera.full_proj_transform,
+         sh_degree=pc.active_sh_degree,
+         campos=viewpoint_camera.camera_center,
+         prefiltered=False,
+         debug=pipe.debug
+     )
+
+     rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+     means3D = pc.get_xyz
+     means2D = screenspace_points
+     opacity = pc.get_opacity
+
+     # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
+     # scaling / rotation by the rasterizer.
+     scales = None
+     rotations = None
+     cov3D_precomp = None
+     if pipe.compute_cov3D_python:
+         cov3D_precomp = pc.get_covariance(scaling_modifier)
+     else:
+         scales = pc.get_scaling
+         rotations = pc.get_rotation
+
+     # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+     # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by the rasterizer.
+     shs = None
+     colors_precomp = None
+     if override_color is None:
+         if pipe.convert_SHs_python:
+             shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
+             dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
+             dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
+             sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
+             colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
+         else:
+             shs = pc.get_features
+     else:
+         colors_precomp = override_color
+
+     # Rasterize visible Gaussians to image, obtain their radii (on screen).
+     rendered_image, radii = rasterizer(
+         means3D = means3D,
+         means2D = means2D,
+         shs = shs,
+         colors_precomp = colors_precomp,
+         opacities = opacity,
+         scales = scales,
+         rotations = rotations,
+         cov3D_precomp = cov3D_precomp)
+
+     # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
+     # They will be excluded from value updates used in the splitting criteria.
+     return {"render": rendered_image,
+             "viewspace_points": screenspace_points,
+             "visibility_filter" : radii > 0,
+             "radii": radii}
utils/gaussian_renderer/network_gui.py ADDED
@@ -0,0 +1,86 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact [email protected]
+ #
+
+ import torch
+ import traceback
+ import socket
+ import json
+ from ..scene.cameras import MiniCam
+
+ host = "127.0.0.1"
+ port = 6009
+
+ conn = None
+ addr = None
+
+ listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+ def init(wish_host, wish_port):
+     global host, port, listener
+     host = wish_host
+     port = wish_port
+     listener.bind((host, port))
+     listener.listen()
+     listener.settimeout(0)
+
+ def try_connect():
+     global conn, addr, listener
+     try:
+         conn, addr = listener.accept()
+         print(f"\nConnected by {addr}")
+         conn.settimeout(None)
+     except Exception as inst:
+         pass
+
+ def read():
+     global conn
+     messageLength = conn.recv(4)
+     messageLength = int.from_bytes(messageLength, 'little')
+     message = conn.recv(messageLength)
+     return json.loads(message.decode("utf-8"))
+
+ def send(message_bytes, verify):
+     global conn
+     if message_bytes is not None:
+         conn.sendall(message_bytes)
+     conn.sendall(len(verify).to_bytes(4, 'little'))
+     conn.sendall(bytes(verify, 'ascii'))
+
+ def receive():
+     message = read()
+
+     width = message["resolution_x"]
+     height = message["resolution_y"]
+
+     if width != 0 and height != 0:
+         try:
+             do_training = bool(message["train"])
+             fovy = message["fov_y"]
+             fovx = message["fov_x"]
+             znear = message["z_near"]
+             zfar = message["z_far"]
+             do_shs_python = bool(message["shs_python"])
+             do_rot_scale_python = bool(message["rot_scale_python"])
+             keep_alive = bool(message["keep_alive"])
+             scaling_modifier = message["scaling_modifier"]
+             world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
+             world_view_transform[:,1] = -world_view_transform[:,1]
+             world_view_transform[:,2] = -world_view_transform[:,2]
+             full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
+             full_proj_transform[:,1] = -full_proj_transform[:,1]
+             custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
+         except Exception as e:
+             print("")
+             traceback.print_exc()
+             raise e
+         return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
+     else:
+         return None, None, None, None, None, None
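
The viewer protocol implied by read() and send() is a 4-byte little-endian length prefix followed by a UTF-8 JSON payload. A hedged client sketch; the field names mirror those unpacked in receive(), and a 0x0 resolution keeps the server idle:

import json
import socket

msg = json.dumps({"resolution_x": 0, "resolution_y": 0}).encode("utf-8")

with socket.create_connection(("127.0.0.1", 6009)) as sock:
    sock.sendall(len(msg).to_bytes(4, "little"))
    sock.sendall(msg)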
utils/gs.py ADDED
@@ -0,0 +1,196 @@
+ import random
+ import torch
+ import numpy as np
+ from .scene import GaussianModel
+ from .scene.dataset_readers import SceneInfo, getNerfppNorm
+ from .scene.cameras import Camera
+ from .ops import focal2fov, fov2focal
+ from .scene.gaussian_model import BasicPointCloud
+ from easydict import EasyDict as edict
+ from PIL import Image
+
+ from tqdm.auto import tqdm
+
+ def get_blank_gs_bundle(h, w):
+     return {
+         "camera_angle_x": focal2fov(torch.tensor([w], dtype=torch.float32), w),
+         "W": w,
+         "H": h,
+         "pcd_points": None,
+         "pcd_colors": None,
+         'frames': [],
+     }
+
+ def read_cameras_from_optimization_bundle(optimization_bundle, white_background: bool = False):
+     cameras = []
+
+     fovx = optimization_bundle["camera_angle_x"]
+     frames = optimization_bundle["frames"]
+
+     # we flip the x and y axis to move from PyTorch3D's coordinate system to COLMAP's
+     coordinate_system_transform = np.array([-1, -1, 1])
+
+     for idx, frame in enumerate(frames):
+         c2w = np.array(frame["transform_matrix"])
+         c2w[:3, :3] = c2w[:3, :3] * coordinate_system_transform
+
+         # get the world-to-camera transform and set R, T
+         w2c = np.linalg.inv(c2w)
+         R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
+         T = c2w[-1, :3] * coordinate_system_transform
+
+         image = frame["image"]
+
+         im_data = np.array(image.convert("RGBA"))
+
+         bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])
+
+         norm_data = im_data / 255.0
+         arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
+         image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
+
+         fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
+         FovY = fovy
+         FovX = fovx
+
+         image = torch.Tensor(arr).permute(2, 0, 1)
+
+         cameras.append(Camera(colmap_id=idx, R=R, T=T, FoVx=FovX, FoVy=FovY, image=image, mask=frame.get("mask", None),
+                               gt_alpha_mask=None, image_name='', uid=idx, data_device='cuda'))
+
+     return cameras
+
+ class Scene:
+     gaussians: GaussianModel
+
+     def __init__(self, traindata, gaussians: GaussianModel, gs_options, shuffle: bool = True):
+         self.traindata = traindata
+         self.gaussians = gaussians
+
+         train_cameras = read_cameras_from_optimization_bundle(traindata, gs_options.white_background)
+
+         nerf_normalization = getNerfppNorm(train_cameras)
+
+         pcd = BasicPointCloud(points=traindata['pcd_points'], colors=traindata['pcd_colors'], normals=None)
+
+         scene_info = SceneInfo(point_cloud=pcd,
+                                train_cameras=train_cameras,
+                                test_cameras=[],
+                                nerf_normalization=nerf_normalization,
+                                ply_path='')
+
+         if shuffle:
+             random.shuffle(scene_info.train_cameras)  # Multi-res consistent random shuffling
+
+         self.cameras_extent = scene_info.nerf_normalization["radius"]
+
+         self.train_cameras = scene_info.train_cameras
+
+         bg_color = np.array([1, 1, 1]) if gs_options.white_background else np.array([0, 0, 0])
+         self.background = torch.tensor(bg_color, dtype=torch.float32, device='cuda')
+
+         self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
+         self.gaussians.training_setup(gs_options)
+
+     def getTrainCameras(self):
+         return self.train_cameras
+
+     def getPresetCameras(self, preset):
+         assert preset in self.preset_cameras
+         return self.preset_cameras[preset]
+
+ def run_gaussian_splatting(scene, gs_optimization_bundle):
+     torch.cuda.empty_cache()
+
+     # the demo returns here without optimizing (gs_options.iterations is 0, so the
+     # loop below would not run anyway); remove this early return for local runs
+     return scene
+
+     from random import randint
+     from .gaussian_renderer import render as gs_render
+     from .scene.utils.loss_utils import l1_loss, ssim
+
+     pbar = tqdm(range(1, gs_options.iterations + 1))
+     for iteration in pbar:
+         scene.gaussians.update_learning_rate(iteration)
+
+         # Every 1000 iterations we increase the levels of SH up to a maximum degree
+         if iteration % 1000 == 0:
+             scene.gaussians.oneupSHdegree()
+
+         # Pick a random camera
+         random_idx = randint(0, len(gs_optimization_bundle["frames"])-1)
+         viewpoint_cam = scene.getTrainCameras()[random_idx]
+
+         # Render
+         render_pkg = gs_render(viewpoint_cam, scene.gaussians, gs_options, scene.background)
+         image, viewspace_point_tensor, visibility_filter, radii = (
+             render_pkg['render'], render_pkg['viewspace_points'], render_pkg['visibility_filter'], render_pkg['radii'])
+
+         # Loss
+         gt_image = viewpoint_cam.original_image.cuda()
+         Ll1 = l1_loss(image, gt_image, reduce=False)
+         loss = (1.0 - gs_options.lambda_dssim) * Ll1
+
+         if viewpoint_cam.mask is not None:
+             mask = torch.from_numpy(viewpoint_cam.mask).to(loss.device)
+         else:
+             mask = 1
+
+         loss = (loss * mask).mean()
+         loss = loss + gs_options.lambda_dssim * (1.0 - ssim(image, gt_image))
+         loss.backward()
+
+         pbar.set_description(f"Loss: {loss.item():.4f}")
+
+         with torch.no_grad():
+             # Densification
+             if iteration < gs_options.densify_until_iter:
+                 # Keep track of max radii in image-space for pruning
+                 scene.gaussians.max_radii2D[visibility_filter] = torch.max(
+                     scene.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
+                 scene.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
+
+                 if iteration > gs_options.densify_from_iter and iteration % gs_options.densification_interval == 0:
+                     size_threshold = 20 if iteration > gs_options.opacity_reset_interval else None
+                     scene.gaussians.densify_and_prune(
+                         gs_options.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
+
+                 if (iteration % gs_options.opacity_reset_interval == 0
+                         or (gs_options.white_background and iteration == gs_options.densify_from_iter)):
+                     scene.gaussians.reset_opacity()
+
+             # Optimizer step
+             if iteration < gs_options.iterations:
+                 scene.gaussians.optimizer.step()
+                 scene.gaussians.optimizer.zero_grad(set_to_none=True)
+
+     return scene
+
+ gs_options = edict({
+     "sh_degree": 3,
+     "images": "images",
+     "resolution": -1,
+     "white_background": False,
+     "data_device": "cuda",
+     "eval": False,
+     "use_depth": False,
+     "iterations": 0,  # 250 when the optimization loop is enabled
+     "position_lr_init": 0.00016,
+     "position_lr_final": 0.0000016,
+     "position_lr_delay_mult": 0.01,
+     "position_lr_max_steps": 2990,
+     "feature_lr": 0.0,  # 0.0025 when the optimization loop is enabled
+     "opacity_lr": 0.0,  # 0.05
+     "scaling_lr": 0.0,  # 0.005
+     "rotation_lr": 0.0,  # 0.001
+     "percent_dense": 0.01,
+     "lambda_dssim": 0.2,
+     "densification_interval": 100,
+     "opacity_reset_interval": 3000,
+     "densify_from_iter": 10_000,
+     "densify_until_iter": 15_000,
+     "densify_grad_threshold": 0.0002,
+     "convert_SHs_python": False,
+     "compute_cov3D_python": False,
+     "debug": False,
+ })
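
A hedged sketch of the intended flow, mirroring app.generate_scene: `bundle` is assumed to be an optimization bundle whose pcd_points/pcd_colors have been filled in, and the target directory is assumed to exist. With iterations set to 0, the optimization step is a no-op:

from utils.gs import Scene, gs_options, run_gaussian_splatting
from utils.scene import GaussianModel

scene = Scene(bundle, GaussianModel(gs_options.sh_degree), gs_options)
scene = run_gaussian_splatting(scene, bundle)  # returns immediately in the demo configuration
scene.gaussians.save_ply("outputs/scene.ply")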
utils/models.py ADDED
@@ -0,0 +1,119 @@
+ import glob
+ import os
+
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+
+ from zoedepth.utils.misc import colorize
+ from zoedepth.utils.config import get_config
+ from zoedepth.models.builder import build_model
+ from zoedepth.models.model_io import load_wts
+
+ from diffusers import AsymmetricAutoencoderKL, StableDiffusionInpaintPipeline
+
+ def load_ckpt(config, model, checkpoint_dir: str = "./checkpoints", ckpt_type: str = "best"):
+     if hasattr(config, "checkpoint"):
+         checkpoint = config.checkpoint
+     elif hasattr(config, "ckpt_pattern"):
+         pattern = config.ckpt_pattern
+         matches = glob.glob(os.path.join(
+             checkpoint_dir, f"*{pattern}*{ckpt_type}*"))
+         if not matches:
+             raise ValueError(f"No matches found for the pattern {pattern}")
+
+         checkpoint = matches[0]
+
+     else:
+         return model
+     model = load_wts(model, checkpoint)
+     print("Loaded weights from {0}".format(checkpoint))
+     return model
+
+ def get_zoe_dc_model(vanilla: bool = False, ckpt_path: str = None, **kwargs):
+     def ZoeD_N(midas_model_type="DPT_BEiT_L_384", vanilla=False, **kwargs):
+         if midas_model_type != "DPT_BEiT_L_384":
+             raise ValueError(f"Only DPT_BEiT_L_384 MiDaS model is supported for pretrained Zoe_N model, got: {midas_model_type}")
+
+         zoedepth_config = get_config("zoedepth", "train", **kwargs)
+         model = build_model(zoedepth_config)
+
+         if vanilla:
+             model.__setattr__("vanilla", True)
+             return model
+         else:
+             model.__setattr__("vanilla", False)
+
+         if zoedepth_config.add_depth_channel and not vanilla:
+             # widen the first patch-embedding convolution by two channels
+             # (sparse depth and its validity mask)
+             model.core.core.pretrained.model.patch_embed.proj = torch.nn.Conv2d(
+                 model.core.core.pretrained.model.patch_embed.proj.in_channels+2,
+                 model.core.core.pretrained.model.patch_embed.proj.out_channels,
+                 kernel_size=model.core.core.pretrained.model.patch_embed.proj.kernel_size,
+                 stride=model.core.core.pretrained.model.patch_embed.proj.stride,
+                 padding=model.core.core.pretrained.model.patch_embed.proj.padding,
+                 bias=True)
+
+         if ckpt_path is not None:
+             assert os.path.exists(ckpt_path)
+             zoedepth_config.__setattr__("checkpoint", ckpt_path)
+         else:
+             assert vanilla, "ckpt_path must be provided for non-vanilla model"
+
+         model = load_ckpt(zoedepth_config, model)
+
+         return model
+
+     return ZoeD_N(vanilla=vanilla, ckpt_path=ckpt_path, **kwargs)
+
+ def infer_with_pad(zoe, x, pad_input: bool = True, fh: float = 3, fw: float = 3, upsampling_mode: str = "bicubic", padding_mode: str = "reflect", **kwargs):
+     assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim())
+
+     if pad_input:
+         assert fh > 0 or fw > 0, "at least one of fh and fw must be greater than 0"
+         pad_h = int(np.sqrt(x.shape[2]/2) * fh)
+         pad_w = int(np.sqrt(x.shape[3]/2) * fw)
+         padding = [pad_w, pad_w]
+         if pad_h > 0:
+             padding += [pad_h, pad_h]
+
+         x_rgb = x[:, :3]
+         x_remaining = x[:, 3:]
+         x_rgb = F.pad(x_rgb, padding, mode=padding_mode, **kwargs)
+         x_remaining = F.pad(x_remaining, padding, mode="constant", value=0, **kwargs)
+         x = torch.cat([x_rgb, x_remaining], dim=1)
+     out = zoe(x)["metric_depth"]
+     if out.shape[-2:] != x.shape[-2:]:
+         out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False)
+     if pad_input:
+         # crop to the original size, handling the case where pad_h or pad_w is 0
+         if pad_h > 0:
+             out = out[:, :, pad_h:-pad_h, :]
+         if pad_w > 0:
+             out = out[:, :, :, pad_w:-pad_w]
+     return out
+
+ @torch.no_grad()
+ def infer_with_zoe_dc(zoe_dc, image, sparse_depth, scaling: float = 1):
+     sparse_depth_mask = (sparse_depth[None, None, ...] > 0).float()
+     # the metric depth range defined during training is [1e-3, 10]
+     x = torch.cat([image[None, ...], sparse_depth[None, None, ...] / (float(scaling) * 10.0), sparse_depth_mask], dim=1).to(zoe_dc.device)
+
+     out = infer_with_pad(zoe_dc, x)
+     out_flip = infer_with_pad(zoe_dc, torch.flip(x, dims=[3]))
+     out = (out + torch.flip(out_flip, dims=[3])) / 2
+
+     pred_depth = float(scaling) * out
+
+     return torch.nn.functional.interpolate(pred_depth, image.shape[-2:], mode='bilinear', align_corners=True)[0, 0]
+
+ def get_sd_pipeline():
+     pipe = StableDiffusionInpaintPipeline.from_pretrained(
+         "stabilityai/stable-diffusion-2-inpainting",
+         torch_dtype=torch.float16,
+     )
+     pipe.vae = AsymmetricAutoencoderKL.from_pretrained(
+         "cross-attention/asymmetric-autoencoder-kl-x-2",
+         torch_dtype=torch.float16
+     )
+
+     return pipe
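
A hedged usage sketch for the depth-completion model: a plain depth estimate is the special case where the sparse-depth channel is all zeros. Assumes a CUDA device and mirrors the checkpoint download in app.py:

import numpy as np
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from utils.models import get_zoe_dc_model, infer_with_zoe_dc

ckpt = hf_hub_download(repo_id="paulengstler/invisible-stitch", filename="invisible-stitch.pt")
model = get_zoe_dc_model(ckpt_path=ckpt).to("cuda")

img = Image.open("examples/photo-1546975490-e8b92a360b24.jpeg")
x = torch.from_numpy(np.asarray(img) / 255.).permute(2, 0, 1).float()
depth = infer_with_zoe_dc(model, x, torch.zeros(img.height, img.width))
print(depth.shape)  # (H, W) metric depth at the input resolution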
utils/ops.py ADDED
@@ -0,0 +1,95 @@
+ import numpy as np
+ import torch
+ import skimage
+ from scipy import ndimage
+ from PIL import Image
+ from .models import infer_with_zoe_dc
+ from pytorch3d.structures import Pointclouds
+
+ import math
+
+ def nearest_neighbor_fill(img, mask, erosion=0):
+     img_ = np.copy(img.cpu().numpy())
+
+     if erosion > 0:
+         eroded_mask = skimage.morphology.binary_erosion(mask.cpu().numpy(), footprint=skimage.morphology.disk(erosion))
+     else:
+         eroded_mask = mask.cpu().numpy()
+
+     img_[eroded_mask <= 0] = np.nan
+
+     distance_to_boundary = ndimage.distance_transform_bf((~eroded_mask > 0), metric="cityblock")
+
+     for current_dist in np.unique(distance_to_boundary)[1:]:
+         ii, jj = np.where(distance_to_boundary == current_dist)
+
+         ii_ = np.array([ii - 1, ii, ii + 1, ii - 1, ii, ii + 1, ii - 1, ii, ii + 1]).reshape(9, -1)
+         jj_ = np.array([jj - 1, jj - 1, jj - 1, jj, jj, jj, jj + 1, jj + 1, jj + 1]).reshape(9, -1)
+
+         ii_ = ii_.clip(0, img_.shape[0] - 1)
+         jj_ = jj_.clip(0, img_.shape[1] - 1)
+
+         img_[ii, jj] = np.nanmax(img_[ii_, jj_], axis=0)
+
+     return torch.from_numpy(img_).to(img.device)
+
+ def snap_high_gradients_to_nn(depth, threshold=20):
+     grad_depth = np.copy(depth.cpu().numpy())
+     grad_depth = grad_depth - grad_depth.min()
+     grad_depth = grad_depth / grad_depth.max()
+
+     grad = skimage.filters.rank.gradient(grad_depth, skimage.morphology.disk(1))
+     return nearest_neighbor_fill(depth, torch.from_numpy(grad < threshold), erosion=3)
+
+ def project_points(cameras, depth, use_pixel_centers=True):
+     if len(cameras) > 1:
+         import warnings
+         warnings.warn("project_points assumes only a single camera is used")
+
+     depth_t = torch.from_numpy(depth) if isinstance(depth, np.ndarray) else depth
+     depth_t = depth_t.to(cameras.device)
+
+     pixel_center = 0.5 if use_pixel_centers else 0
+
+     fx, fy = cameras.focal_length[0, 1], cameras.focal_length[0, 0]
+     cx, cy = cameras.principal_point[0, 1], cameras.principal_point[0, 0]
+
+     i, j = torch.meshgrid(
+         torch.arange(cameras.image_size[0][0], dtype=torch.float32, device=cameras.device) + pixel_center,
+         torch.arange(cameras.image_size[0][1], dtype=torch.float32, device=cameras.device) + pixel_center,
+         indexing="xy",
+     )
+
+     directions = torch.stack(
+         [-(i - cx) * depth_t / fx, -(j - cy) * depth_t / fy, depth_t], -1
+     )
+
+     xy_depth_world = cameras.get_world_to_view_transform().inverse().transform_points(directions.view(-1, 3)).unsqueeze(0)
+
+     return xy_depth_world
+
+ def get_pointcloud(xy_depth_world, device="cpu", features=None):
+     point_cloud = Pointclouds(points=[xy_depth_world.to(device)], features=[features] if features is not None else None)
+     return point_cloud
+
+ def merge_pointclouds(point_clouds):
+     points = torch.cat([pc.points_padded() for pc in point_clouds], dim=1)
+     features = torch.cat([pc.features_padded() for pc in point_clouds], dim=1)
+     return Pointclouds(points=[points[0]], features=[features[0]])
+
+ def outpaint_with_depth_estimation(image, mask, previous_depth, h, w, pipe, zoe_dc, prompt, cameras, dilation_size: int = 2, depth_scaling: float = 1, generator = None):
+     img_input = Image.fromarray((255*image[..., :3].cpu().numpy()).astype(np.uint8))
+
+     # we slightly dilate the mask as aliasing might cause us to receive a too small mask from pytorch3d
+     img_mask = Image.fromarray((255*skimage.morphology.isotropic_dilation((~mask).cpu().numpy(), radius=dilation_size)).astype(np.uint8))
+
+     out_image = pipe(prompt=prompt, image=img_input, mask_image=img_mask, height=h, width=w, generator=generator).images[0]
+     out_depth = infer_with_zoe_dc(zoe_dc, torch.from_numpy(np.asarray(out_image)/255.).permute(2,0,1).float().to(zoe_dc.device), (previous_depth * mask).to(zoe_dc.device), scaling=depth_scaling).cpu().numpy()
+
+     return out_image, out_depth
+
+ def fov2focal(fov, pixels):
+     return pixels / (2 * math.tan(fov / 2))
+
+ def focal2fov(focal, pixels):
+     return 2*math.atan(pixels/(2*focal))
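
A quick sanity check of the pinhole relations above: fov2focal and focal2fov are inverses, and a 90° field of view across 1000 pixels corresponds to a 500-pixel focal length:

import math
from utils.ops import fov2focal, focal2fov

f = fov2focal(math.pi / 2, 1000)  # 1000 / (2 * tan(45°)) = 500.0
assert abs(f - 500.0) < 1e-9
assert abs(focal2fov(f, 1000) - math.pi / 2) < 1e-9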
utils/render.py ADDED
@@ -0,0 +1,112 @@
+ import torch
+ import torch.nn.functional as F
+ import skimage
+ from pytorch3d.structures import Pointclouds
+ from pytorch3d.renderer import (
+     look_at_view_transform,
+     FoVOrthographicCameras,
+     FoVPerspectiveCameras,
+     PerspectiveCameras,
+     PointsRasterizationSettings,
+     PointsRenderer,
+     PulsarPointsRenderer,
+     PointsRasterizer,
+     AlphaCompositor,
+     NormWeightedCompositor
+ )
+ from .ops import nearest_neighbor_fill
+
+ from typing import cast, Optional
+
+ class PointsRendererWithMasks(PointsRenderer):
+     def forward(self, point_clouds, **kwargs) -> torch.Tensor:
+         fragments = self.rasterizer(point_clouds, **kwargs)
+
+         # Construct weights based on the distance of a point to the true point.
+         # However, this could be done differently: e.g. predicted as opposed
+         # to a function of the weights.
+         r = self.rasterizer.raster_settings.radius
+
+         dists2 = fragments.dists
+         weights = torch.ones_like(dists2)  # alternatively: 1 - dists2 / (r * r)
+         ok = cast(torch.BoolTensor, (fragments.idx >= 0)).float()
+
+         weights = weights * ok
+
+         fragments_prm = fragments.idx.long().permute(0, 3, 1, 2)
+         weights_prm = weights.permute(0, 3, 1, 2)
+         images = self.compositor(
+             fragments_prm,
+             weights_prm,
+             point_clouds.features_packed().permute(1, 0),
+             **kwargs,
+         )
+
+         cumprod = torch.cumprod(1 - weights, dim=-1)
+         cumprod = torch.cat((torch.ones_like(cumprod[..., :1]), cumprod[..., :-1]), dim=-1)
+         depths = (weights * cumprod * fragments.zbuf).sum(dim=-1)
+
+         # permute so image comes at the end
+         images = images.permute(0, 2, 3, 1)
+         masks = fragments.idx.long()[..., 0] >= 0
+
+         return images, masks, depths
+
+ def render_with_settings(cameras, point_cloud, raster_settings, antialiasing: int = 1):
+     if antialiasing > 1:
+         raster_settings.image_size = (raster_settings.image_size[0] * antialiasing, raster_settings.image_size[1] * antialiasing)
+
+     rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
+
+     renderer = PointsRendererWithMasks(
+         rasterizer=rasterizer,
+         compositor=AlphaCompositor()
+     )
+
+     if antialiasing > 1:
+         images, masks, depths = renderer(point_cloud)
+
+         images = images.permute(0, 3, 1, 2)  # NHWC -> NCHW
+         images = F.avg_pool2d(images, kernel_size=antialiasing, stride=antialiasing)
+         images = images.permute(0, 2, 3, 1)  # NCHW -> NHWC
+
+         # pool the masks and depths down to the target resolution as well
+         masks = F.max_pool2d(masks.float().unsqueeze(1), kernel_size=antialiasing, stride=antialiasing).squeeze(1).bool()
+         depths = F.avg_pool2d(depths.unsqueeze(1), kernel_size=antialiasing, stride=antialiasing).squeeze(1)
+
+         return images, masks, depths
+
+     else:
+         return renderer(point_cloud)
+
+ def render(cameras, point_cloud, fill_point_cloud_holes: bool = False, radius: Optional[float] = None, antialiasing: int = 1):
+     if fill_point_cloud_holes:
+         coarse_raster_settings = PointsRasterizationSettings(
+             image_size=(int(cameras.image_size[0, 1]), int(cameras.image_size[0, 0])),
+             radius = 1e-2,
+             points_per_pixel = 1
+         )
+
+         _, coarse_mask, _ = render_with_settings(cameras, point_cloud, coarse_raster_settings)
+
+         eroded_coarse_mask = torch.from_numpy(skimage.morphology.binary_erosion(coarse_mask[0].cpu().numpy(), footprint=skimage.morphology.disk(2)))
+
+         raster_settings = PointsRasterizationSettings(
+             image_size=(int(cameras.image_size[0, 1]), int(cameras.image_size[0, 0])),
+             radius = (1 / float(max(cameras.image_size[0, 1], cameras.image_size[0, 0])) * 2.0) if radius is None else radius,
+             points_per_pixel = 16
+         )
+
+         # Render the scene
+         images, masks, depths = render_with_settings(cameras, point_cloud, raster_settings, antialiasing=antialiasing)
+
+         holes_in_rendering = masks[0].cpu() ^ eroded_coarse_mask
+
+         images[0] = nearest_neighbor_fill(images[0], ~holes_in_rendering, 0)
+         depths[0] = nearest_neighbor_fill(depths[0], ~holes_in_rendering, 0)
+
+         return images, eroded_coarse_mask.unsqueeze(0).to(masks.device), depths
+
+     else:
+         raster_settings = PointsRasterizationSettings(
+             image_size=(int(cameras.image_size[0, 1]), int(cameras.image_size[0, 0])),
+             radius = (1 / float(max(cameras.image_size[0, 1], cameras.image_size[0, 0])) * 2.0) if radius is None else radius,
+             points_per_pixel = 16
+         )
+
+         # Render the scene
+         return render_with_settings(cameras, point_cloud, raster_settings, antialiasing=antialiasing)
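
A hedged usage sketch: rendering a point cloud into RGB, mask, and depth with the helpers above. Camera construction mirrors app.py; `point_cloud` is assumed to be a pytorch3d Pointclouds instance (e.g. from utils.ops.get_pointcloud) on the same device:

import torch
from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform
from utils.render import render

h = w = 256
R, T = look_at_view_transform(azim=0, elev=0, dist=0.01, device="cuda")
cameras = PerspectiveCameras(
    R=R, T=T,
    focal_length=torch.tensor([w], dtype=torch.float32),
    principal_point=(((h - 1) / 2, (w - 1) / 2),),
    image_size=((h, w),), in_ndc=False, device="cuda",
)
images, masks, depths = render(cameras, point_cloud)  # each batched with N=1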
utils/scene/__init__.py ADDED
@@ -0,0 +1,92 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact [email protected]
+ #
+
+ import os
+ import random
+ import json
+ from .utils.system_utils import searchForMaxIteration
+ from .dataset_readers import sceneLoadTypeCallbacks
+ from .gaussian_model import GaussianModel
+ from .utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
+
+ class Scene:
+
+     gaussians : GaussianModel
+
+     def __init__(self, args, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
+         """
+         :param path: Path to colmap scene main folder.
+         """
+         self.model_path = args.model_path
+         self.loaded_iter = None
+         self.gaussians = gaussians
+
+         if load_iteration:
+             if load_iteration == -1:
+                 self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
+             else:
+                 self.loaded_iter = load_iteration
+             print("Loading trained model at iteration {}".format(self.loaded_iter))
+
+         self.train_cameras = {}
+         self.test_cameras = {}
+
+         if os.path.exists(os.path.join(args.source_path, "sparse")):
+             scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
+         elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
+             print("Found transforms_train.json file, assuming Blender data set!")
+             scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
+         else:
+             assert False, "Could not recognize scene type!"
+
+         if not self.loaded_iter:
+             with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply"), 'wb') as dest_file:
+                 dest_file.write(src_file.read())
+             json_cams = []
+             camlist = []
+             if scene_info.test_cameras:
+                 camlist.extend(scene_info.test_cameras)
+             if scene_info.train_cameras:
+                 camlist.extend(scene_info.train_cameras)
+             for id, cam in enumerate(camlist):
+                 json_cams.append(camera_to_JSON(id, cam))
+             with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
+                 json.dump(json_cams, file)
+
+         if shuffle:
+             random.shuffle(scene_info.train_cameras)  # Multi-res consistent random shuffling
+             random.shuffle(scene_info.test_cameras)  # Multi-res consistent random shuffling
+
+         self.cameras_extent = scene_info.nerf_normalization["radius"]
+
+         for resolution_scale in resolution_scales:
+             print("Loading Training Cameras")
+             self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
+             print("Loading Test Cameras")
+             self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
+
+         if self.loaded_iter:
+             self.gaussians.load_ply(os.path.join(self.model_path,
+                                                  "point_cloud",
+                                                  "iteration_" + str(self.loaded_iter),
+                                                  "point_cloud.ply"))
+         else:
+             self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
+
+     def save(self, iteration):
+         point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
+         self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
+
+     def getTrainCameras(self, scale=1.0):
+         return self.train_cameras[scale]
+
+     def getTestCameras(self, scale=1.0):
+         return self.test_cameras[scale]
utils/scene/cameras.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+ import math
14
+ from torch import nn
15
+ import numpy as np
16
+ from .utils.graphics_utils import getWorld2View2, getProjectionMatrix
17
+
18
+ class Camera(nn.Module):
19
+ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
20
+ image_name, uid, crop_box=None, mask=None,
21
+ trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
22
+ ):
23
+ super(Camera, self).__init__()
24
+
25
+ self.uid = uid
26
+ self.colmap_id = colmap_id
27
+ self.R = R
28
+ self.T = T
29
+ self.FoVx = FoVx
30
+ self.FoVy = FoVy
31
+ self.image_name = image_name
32
+ self.crop_box = crop_box
33
+ self.mask = mask
34
+
35
+ try:
36
+ self.data_device = torch.device(data_device)
37
+ except Exception as e:
38
+ print(e)
39
+ print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
40
+ self.data_device = torch.device("cuda")
41
+
42
+ self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
43
+ self.image_width = self.original_image.shape[2]
44
+ self.image_height = self.original_image.shape[1]
45
+
46
+ self.gt_alpha_mask = gt_alpha_mask
47
+
48
+ #if gt_alpha_mask is not None:
49
+ # self.original_image *= gt_alpha_mask.to(self.data_device)
50
+ #else:
51
+ # self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
52
+
53
+ self.zfar = 100.0
54
+ self.znear = 0.01
55
+
56
+ self.trans = trans
57
+ self.scale = scale
58
+
59
+ self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
60
+ self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy, crop_box=self.crop_box, width=self.image_width, height=self.image_height).transpose(0,1).cuda()
61
+ self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
62
+ self.camera_center = self.world_view_transform.inverse()[3, :3]
63
+
64
+ class MiniCam:
65
+ def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
66
+ self.image_width = width
67
+ self.image_height = height
68
+ self.FoVy = fovy
69
+ self.FoVx = fovx
70
+ self.znear = znear
71
+ self.zfar = zfar
72
+ self.world_view_transform = world_view_transform
73
+ self.full_proj_transform = full_proj_transform
74
+ view_inv = torch.inverse(self.world_view_transform)
75
+ self.camera_center = view_inv[3][:3]
76
+
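The full Camera carries image tensors; MiniCam keeps only what rasterization needs. A sketch of deriving one from the other (hypothetical helper, not in this commit):

    from utils.scene.cameras import MiniCam

    def to_minicam(cam):
        # Reuses the transforms precomputed in Camera.__init__; drops image data.
        return MiniCam(cam.image_width, cam.image_height,
                       cam.FoVy, cam.FoVx,
                       cam.znear, cam.zfar,
                       cam.world_view_transform, cam.full_proj_transform)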
utils/scene/colmap_loader.py ADDED
@@ -0,0 +1,294 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import numpy as np
13
+ import collections
14
+ import struct
15
+
16
+ CameraModel = collections.namedtuple(
17
+ "CameraModel", ["model_id", "model_name", "num_params"])
18
+ Camera = collections.namedtuple(
19
+ "Camera", ["id", "model", "width", "height", "params"])
20
+ BaseImage = collections.namedtuple(
21
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
22
+ Point3D = collections.namedtuple(
23
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
24
+ CAMERA_MODELS = {
25
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
26
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
27
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
28
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
29
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
30
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
31
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
32
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
33
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
34
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
35
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
36
+ }
37
+ CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
38
+ for camera_model in CAMERA_MODELS])
39
+ CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
40
+ for camera_model in CAMERA_MODELS])
41
+
42
+
43
+ def qvec2rotmat(qvec):
44
+ return np.array([
45
+ [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
46
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
47
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
48
+ [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
49
+ 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
50
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
51
+ [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
52
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
53
+ 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
54
+
55
+ def rotmat2qvec(R):
56
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
57
+ K = np.array([
58
+ [Rxx - Ryy - Rzz, 0, 0, 0],
59
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
60
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
61
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
62
+ eigvals, eigvecs = np.linalg.eigh(K)
63
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
64
+ if qvec[0] < 0:
65
+ qvec *= -1
66
+ return qvec
67
+
68
+ class Image(BaseImage):
69
+ def qvec2rotmat(self):
70
+ return qvec2rotmat(self.qvec)
71
+
72
+ def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
73
+ """Read and unpack the next bytes from a binary file.
74
+ :param fid: open binary file object to read from.
75
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
76
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
77
+ :param endian_character: Any of {@, =, <, >, !}
78
+ :return: Tuple of read and unpacked values.
79
+ """
80
+ data = fid.read(num_bytes)
81
+ return struct.unpack(endian_character + format_char_sequence, data)
82
+
83
+ def read_points3D_text(path):
84
+ """
85
+ see: src/base/reconstruction.cc
86
+ void Reconstruction::ReadPoints3DText(const std::string& path)
87
+ void Reconstruction::WritePoints3DText(const std::string& path)
88
+ """
89
+ xyzs = None
90
+ rgbs = None
91
+ errors = None
92
+ num_points = 0
93
+ with open(path, "r") as fid:
94
+ while True:
95
+ line = fid.readline()
96
+ if not line:
97
+ break
98
+ line = line.strip()
99
+ if len(line) > 0 and line[0] != "#":
100
+ num_points += 1
101
+
102
+
103
+ xyzs = np.empty((num_points, 3))
104
+ rgbs = np.empty((num_points, 3))
105
+ errors = np.empty((num_points, 1))
106
+ count = 0
107
+ with open(path, "r") as fid:
108
+ while True:
109
+ line = fid.readline()
110
+ if not line:
111
+ break
112
+ line = line.strip()
113
+ if len(line) > 0 and line[0] != "#":
114
+ elems = line.split()
115
+ xyz = np.array(tuple(map(float, elems[1:4])))
116
+ rgb = np.array(tuple(map(int, elems[4:7])))
117
+ error = np.array(float(elems[7]))
118
+ xyzs[count] = xyz
119
+ rgbs[count] = rgb
120
+ errors[count] = error
121
+ count += 1
122
+
123
+ return xyzs, rgbs, errors
124
+
125
+ def read_points3D_binary(path_to_model_file):
126
+ """
127
+ see: src/base/reconstruction.cc
128
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
129
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
130
+ """
131
+
132
+
133
+ with open(path_to_model_file, "rb") as fid:
134
+ num_points = read_next_bytes(fid, 8, "Q")[0]
135
+
136
+ xyzs = np.empty((num_points, 3))
137
+ rgbs = np.empty((num_points, 3))
138
+ errors = np.empty((num_points, 1))
139
+
140
+ for p_id in range(num_points):
141
+ binary_point_line_properties = read_next_bytes(
142
+ fid, num_bytes=43, format_char_sequence="QdddBBBd")
143
+ xyz = np.array(binary_point_line_properties[1:4])
144
+ rgb = np.array(binary_point_line_properties[4:7])
145
+ error = np.array(binary_point_line_properties[7])
146
+ track_length = read_next_bytes(
147
+ fid, num_bytes=8, format_char_sequence="Q")[0]
148
+ track_elems = read_next_bytes(
149
+ fid, num_bytes=8*track_length,
150
+ format_char_sequence="ii"*track_length)
151
+ xyzs[p_id] = xyz
152
+ rgbs[p_id] = rgb
153
+ errors[p_id] = error
154
+ return xyzs, rgbs, errors
155
+
156
+ def read_intrinsics_text(path):
157
+ """
158
+ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
159
+ """
160
+ cameras = {}
161
+ with open(path, "r") as fid:
162
+ while True:
163
+ line = fid.readline()
164
+ if not line:
165
+ break
166
+ line = line.strip()
167
+ if len(line) > 0 and line[0] != "#":
168
+ elems = line.split()
169
+ camera_id = int(elems[0])
170
+ model = elems[1]
171
+ assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE"
172
+ width = int(elems[2])
173
+ height = int(elems[3])
174
+ params = np.array(tuple(map(float, elems[4:])))
175
+ cameras[camera_id] = Camera(id=camera_id, model=model,
176
+ width=width, height=height,
177
+ params=params)
178
+ return cameras
179
+
180
+ def read_extrinsics_binary(path_to_model_file):
181
+ """
182
+ see: src/base/reconstruction.cc
183
+ void Reconstruction::ReadImagesBinary(const std::string& path)
184
+ void Reconstruction::WriteImagesBinary(const std::string& path)
185
+ """
186
+ images = {}
187
+ with open(path_to_model_file, "rb") as fid:
188
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
189
+ for _ in range(num_reg_images):
190
+ binary_image_properties = read_next_bytes(
191
+ fid, num_bytes=64, format_char_sequence="idddddddi")
192
+ image_id = binary_image_properties[0]
193
+ qvec = np.array(binary_image_properties[1:5])
194
+ tvec = np.array(binary_image_properties[5:8])
195
+ camera_id = binary_image_properties[8]
196
+ image_name = ""
197
+ current_char = read_next_bytes(fid, 1, "c")[0]
198
+ while current_char != b"\x00": # look for the ASCII 0 entry
199
+ image_name += current_char.decode("utf-8")
200
+ current_char = read_next_bytes(fid, 1, "c")[0]
201
+ num_points2D = read_next_bytes(fid, num_bytes=8,
202
+ format_char_sequence="Q")[0]
203
+ x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
204
+ format_char_sequence="ddq"*num_points2D)
205
+ xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
206
+ tuple(map(float, x_y_id_s[1::3]))])
207
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
208
+ images[image_id] = Image(
209
+ id=image_id, qvec=qvec, tvec=tvec,
210
+ camera_id=camera_id, name=image_name,
211
+ xys=xys, point3D_ids=point3D_ids)
212
+ return images
213
+
214
+
215
+ def read_intrinsics_binary(path_to_model_file):
216
+ """
217
+ see: src/base/reconstruction.cc
218
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
219
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
220
+ """
221
+ cameras = {}
222
+ with open(path_to_model_file, "rb") as fid:
223
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
224
+ for _ in range(num_cameras):
225
+ camera_properties = read_next_bytes(
226
+ fid, num_bytes=24, format_char_sequence="iiQQ")
227
+ camera_id = camera_properties[0]
228
+ model_id = camera_properties[1]
229
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
230
+ width = camera_properties[2]
231
+ height = camera_properties[3]
232
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
233
+ params = read_next_bytes(fid, num_bytes=8*num_params,
234
+ format_char_sequence="d"*num_params)
235
+ cameras[camera_id] = Camera(id=camera_id,
236
+ model=model_name,
237
+ width=width,
238
+ height=height,
239
+ params=np.array(params))
240
+ assert len(cameras) == num_cameras
241
+ return cameras
242
+
243
+
244
+ def read_extrinsics_text(path):
245
+ """
246
+ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
247
+ """
248
+ images = {}
249
+ with open(path, "r") as fid:
250
+ while True:
251
+ line = fid.readline()
252
+ if not line:
253
+ break
254
+ line = line.strip()
255
+ if len(line) > 0 and line[0] != "#":
256
+ elems = line.split()
257
+ image_id = int(elems[0])
258
+ qvec = np.array(tuple(map(float, elems[1:5])))
259
+ tvec = np.array(tuple(map(float, elems[5:8])))
260
+ camera_id = int(elems[8])
261
+ image_name = elems[9]
262
+ elems = fid.readline().split()
263
+ xys = np.column_stack([tuple(map(float, elems[0::3])),
264
+ tuple(map(float, elems[1::3]))])
265
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
266
+ images[image_id] = Image(
267
+ id=image_id, qvec=qvec, tvec=tvec,
268
+ camera_id=camera_id, name=image_name,
269
+ xys=xys, point3D_ids=point3D_ids)
270
+ return images
271
+
272
+
273
+ def read_colmap_bin_array(path):
274
+ """
275
+ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
276
+
277
+ :param path: path to the colmap binary file.
278
+ :return: nd array with the floating point values read from the file
279
+ """
280
+ with open(path, "rb") as fid:
281
+ width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
282
+ usecols=(0, 1, 2), dtype=int)
283
+ fid.seek(0)
284
+ num_delimiter = 0
285
+ byte = fid.read(1)
286
+ while True:
287
+ if byte == b"&":
288
+ num_delimiter += 1
289
+ if num_delimiter >= 3:
290
+ break
291
+ byte = fid.read(1)
292
+ array = np.fromfile(fid, np.float32)
293
+ array = array.reshape((width, height, channels), order="F")
294
+ return np.transpose(array, (1, 0, 2)).squeeze()
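The binary and text readers above pair up by COLMAP output format. A sketch of loading a sparse reconstruction directly (the path is illustrative, not from this commit):

    import os
    from utils.scene.colmap_loader import (read_extrinsics_binary,
                                           read_intrinsics_binary,
                                           read_points3D_binary)

    sparse = "data/demo/sparse/0"                       # hypothetical scene folder
    images = read_extrinsics_binary(os.path.join(sparse, "images.bin"))
    cameras = read_intrinsics_binary(os.path.join(sparse, "cameras.bin"))
    xyz, rgb, err = read_points3D_binary(os.path.join(sparse, "points3D.bin"))
    R = next(iter(images.values())).qvec2rotmat()       # world-to-camera rotation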
utils/scene/dataset_readers.py ADDED
@@ -0,0 +1,270 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import os
13
+ import sys
14
+ from PIL import Image
15
+ from typing import NamedTuple
16
+ from .colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
17
+ read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
18
+ from .utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
19
+ import numpy as np
20
+ import json
21
+ from pathlib import Path
22
+ from plyfile import PlyData, PlyElement
23
+ from .utils.sh_utils import SH2RGB
24
+ from .gaussian_model import BasicPointCloud
25
+
26
+ class CameraInfo(NamedTuple):
27
+ uid: int
28
+ R: np.array
29
+ T: np.array
30
+ FovY: np.array
31
+ FovX: np.array
32
+ image: np.array
33
+ image_path: str
34
+ image_name: str
35
+ mask: np.array
36
+ mask_path: str
37
+ width: int
38
+ height: int
39
+
40
+ class SceneInfo(NamedTuple):
41
+ point_cloud: BasicPointCloud
42
+ train_cameras: list
43
+ test_cameras: list
44
+ nerf_normalization: dict
45
+ ply_path: str
46
+
47
+ def getNerfppNorm(cam_info):
48
+ def get_center_and_diag(cam_centers):
49
+ cam_centers = np.hstack(cam_centers)
50
+ avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
51
+ center = avg_cam_center
52
+ dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
53
+ diagonal = np.max(dist)
54
+ return center.flatten(), diagonal
55
+
56
+ cam_centers = []
57
+
58
+ for cam in cam_info:
59
+ W2C = getWorld2View2(cam.R, cam.T)
60
+ C2W = np.linalg.inv(W2C)
61
+ cam_centers.append(C2W[:3, 3:4])
62
+
63
+ center, diagonal = get_center_and_diag(cam_centers)
64
+ radius = diagonal * 1.1
65
+
66
+ translate = -center
67
+
68
+ return {"translate": translate, "radius": radius}
69
+
70
+ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder, masks_folder):
71
+ cam_infos = []
72
+ for idx, key in enumerate(cam_extrinsics):
73
+ sys.stdout.write('\r')
74
+ # single-line progress update, overwritten in place via the '\r' above
75
+ sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
76
+ sys.stdout.flush()
77
+
78
+ extr = cam_extrinsics[key]
79
+ intr = cam_intrinsics[extr.camera_id]
80
+ height = intr.height
81
+ width = intr.width
82
+
83
+ uid = intr.id
84
+ R = np.transpose(qvec2rotmat(extr.qvec))
85
+ T = np.array(extr.tvec)
86
+
87
+ if intr.model=="SIMPLE_PINHOLE":
88
+ focal_length_x = intr.params[0]
89
+ FovY = focal2fov(focal_length_x, height)
90
+ FovX = focal2fov(focal_length_x, width)
91
+ elif intr.model=="PINHOLE":
92
+ focal_length_x = intr.params[0]
93
+ focal_length_y = intr.params[1]
94
+ FovY = focal2fov(focal_length_y, height)
95
+ FovX = focal2fov(focal_length_x, width)
96
+ else:
97
+ assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
98
+
99
+ image_path = os.path.join(images_folder, os.path.basename(extr.name))
100
+ image_name = os.path.basename(image_path).split(".")[0]
101
+ image = Image.open(image_path)
102
+
103
+ mask_path = os.path.join(masks_folder, os.path.basename(extr.name).replace(".jpg", ".png"))
104
+ try:
105
+ mask = Image.open(mask_path)
106
+ except OSError:  # no mask for this image
107
+ mask = None
108
+
109
+ cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, mask=mask, mask_path=mask_path,
110
+ image_path=image_path, image_name=image_name, width=width, height=height)
111
+ cam_infos.append(cam_info)
112
+ sys.stdout.write('\n')
113
+ return cam_infos
114
+
115
+ def fetchPly(path):
116
+ plydata = PlyData.read(path)
117
+ vertices = plydata['vertex']
118
+ positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
119
+ colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
120
+ normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
121
+ return BasicPointCloud(points=positions, colors=colors, normals=normals)
122
+
123
+ def storePly(path, xyz, rgb):
124
+ # Define the dtype for the structured array
125
+ dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
126
+ ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
127
+ ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
128
+
129
+ normals = np.zeros_like(xyz)
130
+
131
+ elements = np.empty(xyz.shape[0], dtype=dtype)
132
+ attributes = np.concatenate((xyz, normals, rgb), axis=1)
133
+ elements[:] = list(map(tuple, attributes))
134
+
135
+ # Create the PlyData object and write to file
136
+ vertex_element = PlyElement.describe(elements, 'vertex')
137
+ ply_data = PlyData([vertex_element])
138
+ ply_data.write(path)
139
+
140
+ def readColmapSceneInfo(path, images, eval, llffhold=8):
141
+ try:
142
+ cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
143
+ cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
144
+ cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
145
+ cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
146
+ except Exception:  # fall back to the text-format model files
147
+ cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
148
+ cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
149
+ cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
150
+ cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
151
+
152
+ reading_dir = "images" if images == None else images
153
+ # FIXME: make the mask directory configurable, like the images directory above
154
+ mask_reading_dir = "masks"# if images == None else images
155
+ cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir), masks_folder=os.path.join(path, mask_reading_dir))
156
+ cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
157
+
158
+ if eval:
159
+ train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
160
+ test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
161
+ else:
162
+ train_cam_infos = cam_infos
163
+ test_cam_infos = []
164
+
165
+ nerf_normalization = getNerfppNorm(train_cam_infos)
166
+
167
+ ply_path = os.path.join(path, "sparse/0/points3D.ply")
168
+ bin_path = os.path.join(path, "sparse/0/points3D.bin")
169
+ txt_path = os.path.join(path, "sparse/0/points3D.txt")
170
+ if not os.path.exists(ply_path):
171
+ print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
172
+ try:
173
+ xyz, rgb, _ = read_points3D_binary(bin_path)
174
+ except Exception:
175
+ xyz, rgb, _ = read_points3D_text(txt_path)
176
+ storePly(ply_path, xyz, rgb)
177
+ try:
178
+ pcd = fetchPly(ply_path)
179
+ except Exception:
180
+ pcd = None
181
+
182
+ scene_info = SceneInfo(point_cloud=pcd,
183
+ train_cameras=train_cam_infos,
184
+ test_cameras=test_cam_infos,
185
+ nerf_normalization=nerf_normalization,
186
+ ply_path=ply_path)
187
+ return scene_info
188
+
189
+ def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
190
+ cam_infos = []
191
+
192
+ with open(os.path.join(path, transformsfile)) as json_file:
193
+ contents = json.load(json_file)
194
+ fovx = contents["camera_angle_x"]
195
+
196
+ frames = contents["frames"]
197
+ for idx, frame in enumerate(frames):
198
+ cam_name = os.path.join(path, frame["file_path"] + extension)
199
+
200
+ # NeRF 'transform_matrix' is a camera-to-world transform
201
+ c2w = np.array(frame["transform_matrix"])
202
+ # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
203
+ c2w[:3, 1:3] *= -1
204
+
205
+ # get the world-to-camera transform and set R, T
206
+ w2c = np.linalg.inv(c2w)
207
+ R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code
208
+ T = w2c[:3, 3]
209
+
210
+ image_path = os.path.join(path, cam_name)
211
+ image_name = Path(cam_name).stem
212
+ image = Image.open(image_path)
213
+
214
+ im_data = np.array(image.convert("RGBA"))
215
+
216
+ bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
217
+
218
+ norm_data = im_data / 255.0
219
+ arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
220
+ image = Image.fromarray(np.array(arr*255.0, dtype=np.uint8), "RGB")  # uint8: np.byte is signed and overflows above 127
221
+
222
+ fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
223
+ FovY = fovy
224
+ FovX = fovx
225
+
226
+ cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, mask=None, mask_path=None,
227
+ image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
228
+
229
+ return cam_infos
230
+
231
+ def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
232
+ print("Reading Training Transforms")
233
+ train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
234
+ print("Reading Test Transforms")
235
+ test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
236
+
237
+ if not eval:
238
+ train_cam_infos.extend(test_cam_infos)
239
+ test_cam_infos = []
240
+
241
+ nerf_normalization = getNerfppNorm(train_cam_infos)
242
+
243
+ ply_path = os.path.join(path, "points3d.ply")
244
+ if not os.path.exists(ply_path):
245
+ # Since this data set has no colmap data, we start with random points
246
+ num_pts = 100_000
247
+ print(f"Generating random point cloud ({num_pts})...")
248
+
249
+ # We create random points inside the bounds of the synthetic Blender scenes
250
+ xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
251
+ shs = np.random.random((num_pts, 3)) / 255.0
252
+ pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
253
+
254
+ storePly(ply_path, xyz, SH2RGB(shs) * 255)
255
+ try:
256
+ pcd = fetchPly(ply_path)
257
+ except Exception:
258
+ pcd = None
259
+
260
+ scene_info = SceneInfo(point_cloud=pcd,
261
+ train_cameras=train_cam_infos,
262
+ test_cameras=test_cam_infos,
263
+ nerf_normalization=nerf_normalization,
264
+ ply_path=ply_path)
265
+ return scene_info
266
+
267
+ sceneLoadTypeCallbacks = {
268
+ "Colmap": readColmapSceneInfo,
269
+ "Blender" : readNerfSyntheticInfo
270
+ }
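sceneLoadTypeCallbacks is the dispatch table Scene.__init__ uses; a sketch of the same dispatch stand-alone (the path is illustrative):

    import os
    from utils.scene.dataset_readers import sceneLoadTypeCallbacks

    source_path = "data/demo"                           # hypothetical scene folder
    if os.path.exists(os.path.join(source_path, "sparse")):
        info = sceneLoadTypeCallbacks["Colmap"](source_path, None, False)
    else:
        info = sceneLoadTypeCallbacks["Blender"](source_path, True, False)
    print(len(info.train_cameras), "train /", len(info.test_cameras), "test cameras")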
utils/scene/gaussian_model.py ADDED
@@ -0,0 +1,416 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+ import numpy as np
14
+ from .utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
15
+ from torch import nn
16
+ import os
17
+ from .utils.system_utils import mkdir_p
18
+ from plyfile import PlyData, PlyElement
19
+ from .utils.sh_utils import RGB2SH
20
+ from .utils.graphics_utils import BasicPointCloud
21
+ from .utils.general_utils import strip_symmetric, build_scaling_rotation
22
+
23
+ from scipy.spatial import KDTree
24
+
25
+ # credit to https://github.com/graphdeco-inria/gaussian-splatting/issues/292#issuecomment-2007934451
26
+ def distCUDA2(points):
27
+ points_np = points.detach().cpu().float().numpy()
28
+ dists, inds = KDTree(points_np).query(points_np, k=4)
29
+ meanDists = (dists[:, 1:] ** 2).mean(1)
30
+
31
+ return torch.tensor(meanDists, dtype=points.dtype, device=points.device)
32
+
33
+ class GaussianModel:
34
+
35
+ def setup_functions(self):
36
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
37
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
38
+ actual_covariance = L @ L.transpose(1, 2)
39
+ symm = strip_symmetric(actual_covariance)
40
+ return symm
41
+
42
+ self.scaling_activation = torch.exp
43
+ self.scaling_inverse_activation = torch.log
44
+
45
+ self.covariance_activation = build_covariance_from_scaling_rotation
46
+
47
+ self.opacity_activation = torch.sigmoid
48
+ self.inverse_opacity_activation = inverse_sigmoid
49
+
50
+ self.rotation_activation = torch.nn.functional.normalize
51
+
52
+
53
+ def __init__(self, sh_degree : int):
54
+ self.active_sh_degree = 0
55
+ self.max_sh_degree = sh_degree
56
+ self._xyz = torch.empty(0)
57
+ self._features_dc = torch.empty(0)
58
+ self._features_rest = torch.empty(0)
59
+ self._scaling = torch.empty(0)
60
+ self._rotation = torch.empty(0)
61
+ self._opacity = torch.empty(0)
62
+ self.max_radii2D = torch.empty(0)
63
+ self.xyz_gradient_accum = torch.empty(0)
64
+ self.denom = torch.empty(0)
65
+ self.optimizer = None
66
+ self.percent_dense = 0
67
+ self.spatial_lr_scale = 0
68
+ self.setup_functions()
69
+
70
+ def capture(self):
71
+ return (
72
+ self.active_sh_degree,
73
+ self._xyz,
74
+ self._features_dc,
75
+ self._features_rest,
76
+ self._scaling,
77
+ self._rotation,
78
+ self._opacity,
79
+ self.max_radii2D,
80
+ self.xyz_gradient_accum,
81
+ self.denom,
82
+ self.optimizer.state_dict(),
83
+ self.spatial_lr_scale,
84
+ )
85
+
86
+ def restore(self, model_args, training_args):
87
+ (self.active_sh_degree,
88
+ self._xyz,
89
+ self._features_dc,
90
+ self._features_rest,
91
+ self._scaling,
92
+ self._rotation,
93
+ self._opacity,
94
+ self.max_radii2D,
95
+ xyz_gradient_accum,
96
+ denom,
97
+ opt_dict,
98
+ self.spatial_lr_scale) = model_args
99
+ self.training_setup(training_args)
100
+ self.xyz_gradient_accum = xyz_gradient_accum
101
+ self.denom = denom
102
+ self.optimizer.load_state_dict(opt_dict)
103
+
104
+ @property
105
+ def get_scaling(self):
106
+ return self.scaling_activation(self._scaling)
107
+
108
+ @property
109
+ def get_rotation(self):
110
+ return self.rotation_activation(self._rotation)
111
+
112
+ @property
113
+ def get_xyz(self):
114
+ return self._xyz
115
+
116
+ @property
117
+ def get_features(self):
118
+ features_dc = self._features_dc
119
+ features_rest = self._features_rest
120
+ return torch.cat((features_dc, features_rest), dim=1)
121
+
122
+ @property
123
+ def get_opacity(self):
124
+ return self.opacity_activation(self._opacity)
125
+
126
+ def get_covariance(self, scaling_modifier = 1):
127
+ return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
128
+
129
+ def oneupSHdegree(self):
130
+ if self.active_sh_degree < self.max_sh_degree:
131
+ self.active_sh_degree += 1
132
+
133
+ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
134
+ self.spatial_lr_scale = spatial_lr_scale
135
+ fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
136
+ fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
137
+ features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
138
+ features[:, :3, 0 ] = fused_color
139
+ features[:, 3:, 1:] = 0.0
140
+
141
+ print("Number of points at initialisation : ", fused_point_cloud.shape[0])
142
+
143
+ dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
144
+ scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
145
+ rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
146
+ rots[:, 0] = 1
147
+
148
+ opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
149
+
150
+ self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
151
+ self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
152
+ self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
153
+ self._scaling = nn.Parameter(scales.requires_grad_(True))
154
+ self._rotation = nn.Parameter(rots.requires_grad_(True))
155
+ self._opacity = nn.Parameter(opacities.requires_grad_(True))
156
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
157
+
158
+ def training_setup(self, training_args):
159
+ self.percent_dense = training_args.percent_dense
160
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
161
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
162
+
163
+ l = [
164
+ {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
165
+ {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
166
+ {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
167
+ {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
168
+ {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
169
+ {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"}
170
+ ]
171
+
172
+ self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
173
+ self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
174
+ lr_final=training_args.position_lr_final*self.spatial_lr_scale,
175
+ lr_delay_mult=training_args.position_lr_delay_mult,
176
+ max_steps=training_args.position_lr_max_steps)
177
+
178
+ def update_learning_rate(self, iteration):
179
+ ''' Learning rate scheduling per step '''
180
+ for param_group in self.optimizer.param_groups:
181
+ if param_group["name"] == "xyz":
182
+ lr = self.xyz_scheduler_args(iteration)
183
+ param_group['lr'] = lr
184
+ return lr
185
+
186
+ def construct_list_of_attributes(self):
187
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
188
+ # All channels except the 3 DC
189
+ for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
190
+ l.append('f_dc_{}'.format(i))
191
+ for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
192
+ l.append('f_rest_{}'.format(i))
193
+ l.append('opacity')
194
+ for i in range(self._scaling.shape[1]):
195
+ l.append('scale_{}'.format(i))
196
+ for i in range(self._rotation.shape[1]):
197
+ l.append('rot_{}'.format(i))
198
+ return l
199
+
200
+ def save_ply(self, path):
201
+ mkdir_p(os.path.dirname(path))
202
+
203
+ xyz = self._xyz.detach().cpu().numpy()
204
+ normals = np.zeros_like(xyz)
205
+ f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
206
+ f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
207
+ opacities = self._opacity.detach().cpu().numpy()
208
+ scale = self._scaling.detach().cpu().numpy()
209
+ rotation = self._rotation.detach().cpu().numpy()
210
+
211
+ dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
212
+
213
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
214
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
215
+ elements[:] = list(map(tuple, attributes))
216
+ el = PlyElement.describe(elements, 'vertex')
217
+ PlyData([el]).write(path)
218
+
219
+ def reset_opacity(self):
220
+ opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
221
+ optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
222
+ self._opacity = optimizable_tensors["opacity"]
223
+
224
+ def load_ply(self, path):
225
+ plydata = PlyData.read(path)
226
+
227
+ xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
228
+ np.asarray(plydata.elements[0]["y"]),
229
+ np.asarray(plydata.elements[0]["z"])), axis=1)
230
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
231
+
232
+ features_dc = np.zeros((xyz.shape[0], 3, 1))
233
+ features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
234
+ features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
235
+ features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
236
+
237
+ extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
238
+ extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
239
+ assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
240
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
241
+ for idx, attr_name in enumerate(extra_f_names):
242
+ features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
243
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
244
+ features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
245
+
246
+ scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
247
+ scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
248
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
249
+ for idx, attr_name in enumerate(scale_names):
250
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
251
+
252
+ rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
253
+ rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
254
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
255
+ for idx, attr_name in enumerate(rot_names):
256
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
257
+
258
+ self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
259
+ self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
260
+ self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
261
+ self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
262
+ self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
263
+ self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
264
+
265
+ self.active_sh_degree = self.max_sh_degree
266
+
267
+ def replace_tensor_to_optimizer(self, tensor, name):
268
+ optimizable_tensors = {}
269
+ for group in self.optimizer.param_groups:
270
+ if group["name"] == name:
271
+ stored_state = self.optimizer.state.get(group['params'][0], None)
272
+ stored_state["exp_avg"] = torch.zeros_like(tensor)
273
+ stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
274
+
275
+ del self.optimizer.state[group['params'][0]]
276
+ group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
277
+ self.optimizer.state[group['params'][0]] = stored_state
278
+
279
+ optimizable_tensors[group["name"]] = group["params"][0]
280
+ return optimizable_tensors
281
+
282
+ def _prune_optimizer(self, mask):
283
+ optimizable_tensors = {}
284
+ for group in self.optimizer.param_groups:
285
+ stored_state = self.optimizer.state.get(group['params'][0], None)
286
+ if stored_state is not None:
287
+ stored_state["exp_avg"] = stored_state["exp_avg"][mask]
288
+ stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
289
+
290
+ del self.optimizer.state[group['params'][0]]
291
+ group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
292
+ self.optimizer.state[group['params'][0]] = stored_state
293
+
294
+ optimizable_tensors[group["name"]] = group["params"][0]
295
+ else:
296
+ group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
297
+ optimizable_tensors[group["name"]] = group["params"][0]
298
+ return optimizable_tensors
299
+
300
+ def prune_points(self, mask):
301
+ valid_points_mask = ~mask
302
+ optimizable_tensors = self._prune_optimizer(valid_points_mask)
303
+
304
+ self._xyz = optimizable_tensors["xyz"]
305
+ self._features_dc = optimizable_tensors["f_dc"]
306
+ self._features_rest = optimizable_tensors["f_rest"]
307
+ self._opacity = optimizable_tensors["opacity"]
308
+ self._scaling = optimizable_tensors["scaling"]
309
+ self._rotation = optimizable_tensors["rotation"]
310
+
311
+ self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
312
+
313
+ self.denom = self.denom[valid_points_mask]
314
+ self.max_radii2D = self.max_radii2D[valid_points_mask]
315
+
316
+ def cat_tensors_to_optimizer(self, tensors_dict):
317
+ optimizable_tensors = {}
318
+ for group in self.optimizer.param_groups:
319
+ assert len(group["params"]) == 1
320
+ extension_tensor = tensors_dict[group["name"]]
321
+ stored_state = self.optimizer.state.get(group['params'][0], None)
322
+ if stored_state is not None:
323
+
324
+ stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
325
+ stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
326
+
327
+ del self.optimizer.state[group['params'][0]]
328
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
329
+ self.optimizer.state[group['params'][0]] = stored_state
330
+
331
+ optimizable_tensors[group["name"]] = group["params"][0]
332
+ else:
333
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
334
+ optimizable_tensors[group["name"]] = group["params"][0]
335
+
336
+ return optimizable_tensors
337
+
338
+ def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
339
+ d = {"xyz": new_xyz,
340
+ "f_dc": new_features_dc,
341
+ "f_rest": new_features_rest,
342
+ "opacity": new_opacities,
343
+ "scaling" : new_scaling,
344
+ "rotation" : new_rotation}
345
+
346
+ optimizable_tensors = self.cat_tensors_to_optimizer(d)
347
+ self._xyz = optimizable_tensors["xyz"]
348
+ self._features_dc = optimizable_tensors["f_dc"]
349
+ self._features_rest = optimizable_tensors["f_rest"]
350
+ self._opacity = optimizable_tensors["opacity"]
351
+ self._scaling = optimizable_tensors["scaling"]
352
+ self._rotation = optimizable_tensors["rotation"]
353
+
354
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
355
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
356
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
357
+
358
+ def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
359
+ n_init_points = self.get_xyz.shape[0]
360
+ # Extract points that satisfy the gradient condition
361
+ padded_grad = torch.zeros((n_init_points), device="cuda")
362
+ padded_grad[:grads.shape[0]] = grads.squeeze()
363
+ selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
364
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
365
+ torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
366
+
367
+ stds = self.get_scaling[selected_pts_mask].repeat(N,1)
368
+ means = torch.zeros((stds.size(0), 3), device="cuda")
369
+ samples = torch.normal(mean=means, std=stds)
370
+ rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
371
+ new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
372
+ new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
373
+ new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
374
+ new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
375
+ new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
376
+ new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
377
+
378
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
379
+
380
+ prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
381
+ self.prune_points(prune_filter)
382
+
383
+ def densify_and_clone(self, grads, grad_threshold, scene_extent):
384
+ # Extract points that satisfy the gradient condition
385
+ selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
386
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
387
+ torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
388
+
389
+ new_xyz = self._xyz[selected_pts_mask]
390
+ new_features_dc = self._features_dc[selected_pts_mask]
391
+ new_features_rest = self._features_rest[selected_pts_mask]
392
+ new_opacities = self._opacity[selected_pts_mask]
393
+ new_scaling = self._scaling[selected_pts_mask]
394
+ new_rotation = self._rotation[selected_pts_mask]
395
+
396
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)
397
+
398
+ def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
399
+ grads = self.xyz_gradient_accum / self.denom
400
+ grads[grads.isnan()] = 0.0
401
+
402
+ self.densify_and_clone(grads, max_grad, extent)
403
+ self.densify_and_split(grads, max_grad, extent)
404
+
405
+ prune_mask = (self.get_opacity < min_opacity).squeeze()
406
+ if max_screen_size:
407
+ big_points_vs = self.max_radii2D > max_screen_size
408
+ big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
409
+ prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
410
+ self.prune_points(prune_mask)
411
+
412
+ torch.cuda.empty_cache()
413
+
414
+ def add_densification_stats(self, viewspace_point_tensor, update_filter):
415
+ self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
416
+ self.denom[update_filter] += 1
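A sketch of the PLY round-trip the save_ply/load_ply pair supports (paths are illustrative; a CUDA device is assumed, since load_ply places its tensors on "cuda"):

    from utils.scene.gaussian_model import GaussianModel

    gm = GaussianModel(sh_degree=3)          # must match the SH degree in the PLY
    gm.load_ply("output/demo/point_cloud/iteration_7000/point_cloud.ply")
    print(gm.get_xyz.shape[0], "Gaussians")  # positions are (N, 3)
    gm.save_ply("output/demo/point_cloud/iteration_7000/point_cloud_copy.ply")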
utils/scene/utils/camera_utils.py ADDED
@@ -0,0 +1,84 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ from ..cameras import Camera
13
+ import numpy as np
14
+ from .general_utils import PILtoTorch
15
+ from .graphics_utils import fov2focal
16
+
17
+ WARNED = False
18
+
19
+ def loadCam(args, id, cam_info, resolution_scale):
20
+ orig_w, orig_h = cam_info.image.size
21
+
22
+ if args.resolution in [1, 2, 4, 8]:
23
+ resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
24
+ else: # should be a type that converts to float
25
+ if args.resolution == -1:
26
+ if orig_w > 1600:
27
+ global WARNED
28
+ if not WARNED:
29
+ print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
30
+ "If this is not desired, please explicitly specify '--resolution/-r' as 1")
31
+ WARNED = True
32
+ global_down = orig_w / 1600
33
+ else:
34
+ global_down = 1
35
+ else:
36
+ global_down = orig_w / args.resolution
37
+
38
+ scale = float(global_down) * float(resolution_scale)
39
+ resolution = (int(orig_w / scale), int(orig_h / scale))
40
+
41
+ resized_image_rgb = PILtoTorch(cam_info.image, resolution)
42
+
43
+ gt_image = resized_image_rgb[:3, ...]
44
+ loaded_mask = None
45
+
46
+ if resized_image_rgb.shape[1] == 4:
47
+ loaded_mask = resized_image_rgb[3:4, ...]
48
+ elif cam_info.mask is not None:
49
+ loaded_mask = ~(PILtoTorch(cam_info.mask, resolution)[0:1, ...] > 0)
50
+
51
+ return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
52
+ FoVx=cam_info.FovX, FoVy=cam_info.FovY,
53
+ image=gt_image, gt_alpha_mask=loaded_mask,
54
+ image_name=cam_info.image_name, uid=id, data_device=args.data_device)
55
+
56
+ def cameraList_from_camInfos(cam_infos, resolution_scale, args):
57
+ camera_list = []
58
+
59
+ for id, c in enumerate(cam_infos):
60
+ camera_list.append(loadCam(args, id, c, resolution_scale))
61
+
62
+ return camera_list
63
+
64
+ def camera_to_JSON(id, camera):  # called with CameraInfo entries, which carry .width/.height
65
+ Rt = np.zeros((4, 4))
66
+ Rt[:3, :3] = camera.R.transpose()
67
+ Rt[:3, 3] = camera.T
68
+ Rt[3, 3] = 1.0
69
+
70
+ W2C = np.linalg.inv(Rt)
71
+ pos = W2C[:3, 3]
72
+ rot = W2C[:3, :3]
73
+ serializable_array_2d = [x.tolist() for x in rot]
74
+ camera_entry = {
75
+ 'id' : id,
76
+ 'img_name' : camera.image_name,
77
+ 'width' : camera.width,
78
+ 'height' : camera.height,
79
+ 'position': pos.tolist(),
80
+ 'rotation': serializable_array_2d,
81
+ 'fy' : fov2focal(camera.FovY, camera.height),
82
+ 'fx' : fov2focal(camera.FovX, camera.width)
83
+ }
84
+ return camera_entry
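To make loadCam's resolution branch concrete, a small stand-alone re-implementation of just the width computation (illustrative, not part of this commit):

    def target_width(orig_w, resolution, resolution_scale=1.0):
        # Mirrors the branch structure of loadCam above.
        if resolution in [1, 2, 4, 8]:
            return round(orig_w / (resolution_scale * resolution))
        if resolution == -1:
            global_down = orig_w / 1600 if orig_w > 1600 else 1
        else:
            global_down = orig_w / resolution
        return int(orig_w / (global_down * resolution_scale))

    assert target_width(3200, 2) == 1600
    assert target_width(3200, -1) == 1600   # wide inputs are auto-capped at 1.6K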
utils/scene/utils/general_utils.py ADDED
@@ -0,0 +1,133 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+ import sys
14
+ from datetime import datetime
15
+ import numpy as np
16
+ import random
17
+
18
+ def inverse_sigmoid(x):
19
+ return torch.log(x/(1-x))
20
+
21
+ def PILtoTorch(pil_image, resolution):
22
+ resized_image_PIL = pil_image.resize(resolution)
23
+ resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
24
+ if len(resized_image.shape) == 3:
25
+ return resized_image.permute(2, 0, 1)
26
+ else:
27
+ return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
28
+
29
+ def get_expon_lr_func(
30
+ lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
31
+ ):
32
+ """
33
+ Copied from Plenoxels
34
+
35
+ Continuous learning rate decay function. Adapted from JaxNeRF
36
+ The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
37
+ is log-linearly interpolated elsewhere (equivalent to exponential decay).
38
+ If lr_delay_steps>0 then the learning rate will be scaled by some smooth
39
+ function of lr_delay_mult, such that the initial learning rate is
40
+ lr_init*lr_delay_mult at the beginning of optimization but will be eased back
41
+ to the normal learning rate when steps>lr_delay_steps.
42
+ :param conf: config subtree 'lr' or similar
43
+ :param max_steps: int, the number of steps during optimization.
44
+ :return HoF which takes step as input
45
+ """
46
+
47
+ def helper(step):
48
+ if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
49
+ # Disable this parameter
50
+ return 0.0
51
+ if lr_delay_steps > 0:
52
+ # A kind of reverse cosine decay.
53
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
54
+ 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
55
+ )
56
+ else:
57
+ delay_rate = 1.0
58
+ t = np.clip(step / max_steps, 0, 1)
59
+ log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
60
+ return delay_rate * log_lerp
61
+
62
+ return helper
63
+
64
+ def strip_lowerdiag(L):
65
+ uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
66
+
67
+ uncertainty[:, 0] = L[:, 0, 0]
68
+ uncertainty[:, 1] = L[:, 0, 1]
69
+ uncertainty[:, 2] = L[:, 0, 2]
70
+ uncertainty[:, 3] = L[:, 1, 1]
71
+ uncertainty[:, 4] = L[:, 1, 2]
72
+ uncertainty[:, 5] = L[:, 2, 2]
73
+ return uncertainty
74
+
75
+ def strip_symmetric(sym):
76
+ return strip_lowerdiag(sym)
77
+
78
+ def build_rotation(r):
79
+ norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
80
+
81
+ q = r / norm[:, None]
82
+
83
+ R = torch.zeros((q.size(0), 3, 3), device='cuda')
84
+
85
+ r = q[:, 0]
86
+ x = q[:, 1]
87
+ y = q[:, 2]
88
+ z = q[:, 3]
89
+
90
+ R[:, 0, 0] = 1 - 2 * (y*y + z*z)
91
+ R[:, 0, 1] = 2 * (x*y - r*z)
92
+ R[:, 0, 2] = 2 * (x*z + r*y)
93
+ R[:, 1, 0] = 2 * (x*y + r*z)
94
+ R[:, 1, 1] = 1 - 2 * (x*x + z*z)
95
+ R[:, 1, 2] = 2 * (y*z - r*x)
96
+ R[:, 2, 0] = 2 * (x*z - r*y)
97
+ R[:, 2, 1] = 2 * (y*z + r*x)
98
+ R[:, 2, 2] = 1 - 2 * (x*x + y*y)
99
+ return R
100
+
101
+ def build_scaling_rotation(s, r):
102
+ L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
103
+ R = build_rotation(r)
104
+
105
+ L[:,0,0] = s[:,0]
106
+ L[:,1,1] = s[:,1]
107
+ L[:,2,2] = s[:,2]
108
+
109
+ L = R @ L
110
+ return L
111
+
112
+ def safe_state(silent):
113
+ old_f = sys.stdout
114
+ class F:
115
+ def __init__(self, silent):
116
+ self.silent = silent
117
+
118
+ def write(self, x):
119
+ if not self.silent:
120
+ if x.endswith("\n"):
121
+ old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
122
+ else:
123
+ old_f.write(x)
124
+
125
+ def flush(self):
126
+ old_f.flush()
127
+
128
+ sys.stdout = F(silent)
129
+
130
+ random.seed(0)
131
+ np.random.seed(0)
132
+ torch.manual_seed(0)
133
+ torch.cuda.set_device(torch.device("cuda:0"))
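The schedule returned by get_expon_lr_func is log-linear, so the midpoint of training sits at the geometric mean of the endpoints. A quick check (the values below are illustrative):

    from utils.scene.utils.general_utils import get_expon_lr_func

    lr_fn = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)
    print(lr_fn(0))        # 1.6e-4
    print(lr_fn(15_000))   # ~1.6e-5, the geometric mean of init and final
    print(lr_fn(30_000))   # 1.6e-6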
utils/scene/utils/graphics_utils.py ADDED
@@ -0,0 +1,88 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+ import math
14
+ import numpy as np
15
+ from typing import NamedTuple
16
+
17
+ class BasicPointCloud(NamedTuple):
18
+ points : np.array
19
+ colors : np.array
20
+ normals : np.array
21
+
22
+ def geom_transform_points(points, transf_matrix):
23
+ P, _ = points.shape
24
+ ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
25
+ points_hom = torch.cat([points, ones], dim=1)
26
+ points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
27
+
28
+ denom = points_out[..., 3:] + 0.0000001
29
+ return (points_out[..., :3] / denom).squeeze(dim=0)
30
+
31
+ def getWorld2View(R, t):
32
+ Rt = np.zeros((4, 4))
33
+ Rt[:3, :3] = R.transpose()
34
+ Rt[:3, 3] = t
35
+ Rt[3, 3] = 1.0
36
+ return np.float32(Rt)
37
+
38
+ def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
39
+ Rt = np.zeros((4, 4))
40
+ Rt[:3, :3] = R.transpose()
41
+ Rt[:3, 3] = t
42
+ Rt[3, 3] = 1.0
43
+
44
+ C2W = np.linalg.inv(Rt)
45
+ cam_center = C2W[:3, 3]
46
+ cam_center = (cam_center + translate) * scale
47
+ C2W[:3, 3] = cam_center
48
+ Rt = np.linalg.inv(C2W)
49
+ return np.float32(Rt)
50
+
51
+ def getProjectionMatrix(znear, zfar, fovX, fovY, crop_box=None, width=None, height=None):
52
+ tanHalfFovY = math.tan((fovY / 2))
53
+ tanHalfFovX = math.tan((fovX / 2))
54
+
55
+ top = tanHalfFovY * znear
56
+ bottom = -top
57
+ right = tanHalfFovX * znear
58
+ left = -right
59
+
60
+ frustum_width = right - left
61
+ frustum_height = top - bottom
62
+
63
+ if crop_box is not None:
64
+ assert width is not None and height is not None
65
+ x, y, w, h = crop_box
66
+ left = left + x / width * frustum_width
67
+ right = left + w / width * frustum_width
68
+ top = top - y / height * frustum_height
69
+ bottom = top - h / height * frustum_height
70
+
71
+ P = torch.zeros(4, 4)
72
+
73
+ z_sign = 1.0
74
+
75
+ P[0, 0] = 2.0 * znear / (right - left)
76
+ P[1, 1] = 2.0 * znear / (top - bottom)
77
+ P[0, 2] = (right + left) / (right - left)
78
+ P[1, 2] = (top + bottom) / (top - bottom)
79
+ P[3, 2] = z_sign
80
+ P[2, 2] = z_sign * zfar / (zfar - znear)
81
+ P[2, 3] = -(zfar * znear) / (zfar - znear)
82
+ return P
83
+
84
+ def fov2focal(fov, pixels):
85
+ return pixels / (2 * math.tan(fov / 2))
86
+
87
+ def focal2fov(focal, pixels):
88
+ return 2*math.atan(pixels/(2*focal))
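fov2focal and focal2fov are exact inverses for a fixed image size, which the projection code above relies on. A quick round-trip check:

    import math
    from utils.scene.utils.graphics_utils import fov2focal, focal2fov

    fov = math.radians(60)
    f = fov2focal(fov, 800)                 # ~692.8 px for an 800 px image
    assert abs(focal2fov(f, 800) - fov) < 1e-9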
utils/scene/utils/image_utils.py ADDED
@@ -0,0 +1,19 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+
14
+ def mse(img1, img2):
15
+ return ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
16
+
17
+ def psnr(img1, img2):
18
+ mse = ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
19
+ return 20 * torch.log10(1.0 / torch.sqrt(mse))
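psnr expects batched images in [0, 1] and returns one value per batch element; identical inputs give +inf. A sketch:

    import torch
    from utils.scene.utils.image_utils import psnr

    a = torch.rand(1, 3, 64, 64)
    b = (a + 0.01 * torch.randn_like(a)).clamp(0, 1)   # lightly perturbed copy
    print(psnr(a, b))                                  # shape (1, 1), in dB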
utils/scene/utils/loss_utils.py ADDED
@@ -0,0 +1,65 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch.autograd import Variable
15
+ from math import exp
16
+
17
+ def l1_loss(network_output, gt, reduce=True):
18
+ l1_loss = torch.abs((network_output - gt))
19
+ return l1_loss.mean() if reduce else l1_loss
20
+
21
+ def l2_loss(network_output, gt):
22
+ return ((network_output - gt) ** 2).mean()
23
+
24
+ def gaussian(window_size, sigma):
25
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
26
+ return gauss / gauss.sum()
27
+
28
+ def create_window(window_size, channel):
29
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
30
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
31
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
32
+ return window
33
+
34
+ def ssim(img1, img2, window_size=11, size_average=True):
35
+ channel = img1.size(-3)
36
+ window = create_window(window_size, channel)
37
+
38
+ if img1.is_cuda:
39
+ window = window.cuda(img1.get_device())
40
+ window = window.type_as(img1)
41
+
42
+ return _ssim(img1, img2, window, window_size, channel, size_average)
43
+
44
+ def _ssim(img1, img2, window, window_size, channel, size_average=True):
45
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
46
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
47
+
48
+ mu1_sq = mu1.pow(2)
49
+ mu2_sq = mu2.pow(2)
50
+ mu1_mu2 = mu1 * mu2
51
+
52
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
53
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
54
+ sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
55
+
56
+ C1 = 0.01 ** 2
57
+ C2 = 0.03 ** 2
58
+
59
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
60
+
61
+ if size_average:
62
+ return ssim_map.mean()
63
+ else:
64
+ return ssim_map.mean(1).mean(1).mean(1)
65
+
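These two terms are conventionally combined into the 3DGS photometric objective, a weighted sum of L1 and DSSIM; a sketch, where the 0.2 weight follows the original 3D Gaussian Splatting paper and is an assumption about how this repo configures it:

import torch

lambda_dssim = 0.2                     # 3DGS paper default; this repo may use a different value
rendered = torch.rand(1, 3, 128, 128)  # dummy render and ground truth
gt = torch.rand(1, 3, 128, 128)
loss = (1.0 - lambda_dssim) * l1_loss(rendered, gt) \
     + lambda_dssim * (1.0 - ssim(rendered, gt))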
utils/scene/utils/sh_utils.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright 2021 The PlenOctree Authors.
2
+ # Redistribution and use in source and binary forms, with or without
3
+ # modification, are permitted provided that the following conditions are met:
4
+ #
5
+ # 1. Redistributions of source code must retain the above copyright notice,
6
+ # this list of conditions and the following disclaimer.
7
+ #
8
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
9
+ # this list of conditions and the following disclaimer in the documentation
10
+ # and/or other materials provided with the distribution.
11
+ #
12
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22
+ # POSSIBILITY OF SUCH DAMAGE.
23
+
24
+ import torch
25
+
26
+ C0 = 0.28209479177387814
27
+ C1 = 0.4886025119029199
28
+ C2 = [
29
+ 1.0925484305920792,
30
+ -1.0925484305920792,
31
+ 0.31539156525252005,
32
+ -1.0925484305920792,
33
+ 0.5462742152960396
34
+ ]
35
+ C3 = [
36
+ -0.5900435899266435,
37
+ 2.890611442640554,
38
+ -0.4570457994644658,
39
+ 0.3731763325901154,
40
+ -0.4570457994644658,
41
+ 1.445305721320277,
42
+ -0.5900435899266435
43
+ ]
44
+ C4 = [
45
+ 2.5033429417967046,
46
+ -1.7701307697799304,
47
+ 0.9461746957575601,
48
+ -0.6690465435572892,
49
+ 0.10578554691520431,
50
+ -0.6690465435572892,
51
+ 0.47308734787878004,
52
+ -1.7701307697799304,
53
+ 0.6258357354491761,
54
+ ]
55
+
56
+
57
+ def eval_sh(deg, sh, dirs):
58
+ """
59
+ Evaluate spherical harmonics at unit directions
60
+ using hardcoded SH polynomials.
61
+ Works with torch/np/jnp.
62
+ ... Can be 0 or more batch dimensions.
63
+ Args:
64
+ deg: int SH deg. Currently, 0-3 supported
65
+ sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66
+ dirs: jnp.ndarray unit directions [..., 3]
67
+ Returns:
68
+ [..., C]
69
+ """
70
+ assert deg <= 4 and deg >= 0
71
+ coeff = (deg + 1) ** 2
72
+ assert sh.shape[-1] >= coeff
73
+
74
+ result = C0 * sh[..., 0]
75
+ if deg > 0:
76
+ x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77
+ result = (result -
78
+ C1 * y * sh[..., 1] +
79
+ C1 * z * sh[..., 2] -
80
+ C1 * x * sh[..., 3])
81
+
82
+ if deg > 1:
83
+ xx, yy, zz = x * x, y * y, z * z
84
+ xy, yz, xz = x * y, y * z, x * z
85
+ result = (result +
86
+ C2[0] * xy * sh[..., 4] +
87
+ C2[1] * yz * sh[..., 5] +
88
+ C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89
+ C2[3] * xz * sh[..., 7] +
90
+ C2[4] * (xx - yy) * sh[..., 8])
91
+
92
+ if deg > 2:
93
+ result = (result +
94
+ C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95
+ C3[1] * xy * z * sh[..., 10] +
96
+ C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
97
+ C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98
+ C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99
+ C3[5] * z * (xx - yy) * sh[..., 14] +
100
+ C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101
+
102
+ if deg > 3:
103
+ result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104
+ C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105
+ C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106
+ C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107
+ C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108
+ C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109
+ C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110
+ C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111
+ C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112
+ return result
113
+
114
+ def RGB2SH(rgb):
115
+ return (rgb - 0.5) / C0
116
+
117
+ def SH2RGB(sh):
118
+ return sh * C0 + 0.5
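A small round-trip check of the degree-0 path (per the docstring, sh is [..., C, (deg + 1) ** 2]); at degree 0 the result is view-independent, so dirs can be anything:

import torch

rgb = torch.rand(10, 3)                          # colors in [0, 1]
sh = RGB2SH(rgb).unsqueeze(-1)                   # [10, 3, 1]: one DC coefficient per channel
dirs = torch.nn.functional.normalize(torch.randn(10, 3), dim=-1)
color = eval_sh(0, sh, dirs) + 0.5               # eval_sh returns rgb - 0.5 at degree 0
assert torch.allclose(color, rgb, atol=1e-6)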
utils/scene/utils/system_utils.py ADDED
@@ -0,0 +1,28 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ from errno import EEXIST
13
+ from os import makedirs, path
14
+ import os
15
+
16
+ def mkdir_p(folder_path):
17
+ # Creates a directory; equivalent to mkdir -p on the command line
18
+ try:
19
+ makedirs(folder_path)
20
+ except OSError as exc: # Python >2.5
21
+ if exc.errno == EEXIST and path.isdir(folder_path):
22
+ pass
23
+ else:
24
+ raise
25
+
26
+ def searchForMaxIteration(folder):
27
+ saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
28
+ return max(saved_iters)
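searchForMaxIteration assumes every entry in the folder ends in _&lt;integer&gt; (the iteration_&lt;N&gt; checkpoint convention of 3DGS) and raises ValueError otherwise; a sketch:

import os
import tempfile

with tempfile.TemporaryDirectory() as folder:
    for it in (7000, 30000):
        os.mkdir(os.path.join(folder, f"iteration_{it}"))
    assert searchForMaxIteration(folder) == 30000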
zoedepth/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
zoedepth/data/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
zoedepth/data/data_mono.py ADDED
@@ -0,0 +1,697 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ # This file is partly inspired from BTS (https://github.com/cleinc/bts/blob/master/pytorch/bts_dataloader.py); author: Jin Han Lee
26
+
27
+ import itertools
28
+ import os
29
+ import random
30
+ from random import choice
31
+
32
+ import numpy as np
33
+ import cv2
34
+ import torch
35
+ import torch.nn as nn
36
+ import torch.utils.data.distributed
37
+ from zoedepth.utils.easydict import EasyDict as edict
38
+ from PIL import Image, ImageOps
39
+ from torch.utils.data import DataLoader, Dataset
40
+ from torchvision import transforms
41
+
42
+ from zoedepth.utils.config import change_dataset
43
+
44
+ from .ddad import get_ddad_loader
45
+ from .diml_indoor_test import get_diml_indoor_loader
46
+ from .diml_outdoor_test import get_diml_outdoor_loader
47
+ from .diode import get_diode_loader
48
+ from .hypersim import get_hypersim_loader
49
+ from .ibims import get_ibims_loader
50
+ from .sun_rgbd_loader import get_sunrgbd_loader
51
+ from .vkitti import get_vkitti_loader
52
+ from .vkitti2 import get_vkitti2_loader
53
+ from .places365 import get_places365_loader, Places365
54
+ from .marigold_nyu import get_marigold_nyu_loader, MarigoldNYU
55
+
56
+ from .preprocess import CropParams, get_white_border, get_black_border
57
+
58
+
59
+ def _is_pil_image(img):
60
+ return isinstance(img, Image.Image)
61
+
62
+
63
+ def _is_numpy_image(img):
64
+ return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
65
+
66
+
67
+ def preprocessing_transforms(mode, **kwargs):
68
+ return transforms.Compose([
69
+ ToTensor(mode=mode, **kwargs)
70
+ ])
71
+
72
+
73
+ class DepthDataLoader(object):
74
+ def __init__(self, config, mode, device='cpu', transform=None, **kwargs):
75
+ """
76
+ Data loader for depth datasets
77
+
78
+ Args:
79
+ config (dict): Config dictionary. Refer to utils/config.py
80
+ mode (str): "train" or "online_eval"
81
+ device (str, optional): Device to load the data on. Defaults to 'cpu'.
82
+ transform (torchvision.transforms, optional): Transform to apply to the data. Defaults to None.
83
+ """
84
+
85
+ self.config = config
86
+
87
+ if config.dataset == 'ibims':
88
+ self.data = get_ibims_loader(config, batch_size=1, num_workers=1)
89
+ return
90
+
91
+ if config.dataset == 'sunrgbd':
92
+ self.data = get_sunrgbd_loader(
93
+ data_dir_root=config.sunrgbd_root, batch_size=1, num_workers=1)
94
+ return
95
+
96
+ if config.dataset == 'diml_indoor':
97
+ self.data = get_diml_indoor_loader(
98
+ data_dir_root=config.diml_indoor_root, batch_size=1, num_workers=1)
99
+ return
100
+
101
+ if config.dataset == 'diml_outdoor':
102
+ self.data = get_diml_outdoor_loader(
103
+ data_dir_root=config.diml_outdoor_root, batch_size=1, num_workers=1)
104
+ return
105
+
106
+ if "diode" in config.dataset:
107
+ self.data = get_diode_loader(
108
+ config[config.dataset+"_root"], batch_size=1, num_workers=1)
109
+ return
110
+
111
+ if config.dataset == 'hypersim_test':
112
+ self.data = get_hypersim_loader(
113
+ config.hypersim_test_root, batch_size=1, num_workers=1)
114
+ return
115
+
116
+ if config.dataset == 'vkitti':
117
+ self.data = get_vkitti_loader(
118
+ config.vkitti_root, batch_size=1, num_workers=1)
119
+ return
120
+
121
+ if config.dataset == 'vkitti2':
122
+ self.data = get_vkitti2_loader(
123
+ config.vkitti2_root, batch_size=1, num_workers=1)
124
+ return
125
+
126
+ if config.dataset == 'ddad':
127
+ self.data = get_ddad_loader(config.ddad_root, resize_shape=(
128
+ 352, 1216), batch_size=1, num_workers=1)
129
+ return
130
+
131
+ img_size = self.config.get("img_size", None)
132
+ img_size = img_size if self.config.get(
133
+ "do_input_resize", False) else None
134
+
135
+ if transform is None:
136
+ transform = preprocessing_transforms(mode, size=img_size)
137
+
138
+ if mode == 'train':
139
+
140
+ Dataset = DataLoadPreprocess
141
+ self.training_samples = Dataset(
142
+ config, mode, transform=transform, device=device)
143
+
144
+ if config.distributed and not config.debug_mode:
145
+ self.train_sampler = torch.utils.data.distributed.DistributedSampler(
146
+ self.training_samples)
147
+ else:
148
+ self.train_sampler = None
149
+
150
+ if not config.debug_mode:
151
+ self.data = DataLoader(self.training_samples,
152
+ batch_size=config.batch_size,
153
+ shuffle=(self.train_sampler is None),
154
+ num_workers=config.workers,
155
+ pin_memory=True,
156
+ persistent_workers=True,
157
+ # prefetch_factor=2,
158
+ sampler=self.train_sampler)
159
+ else:
160
+ self.data = DataLoader(self.training_samples,
161
+ batch_size=config.batch_size,
162
+ shuffle=(self.train_sampler is None),
163
+ num_workers=0,
164
+ pin_memory=True,
165
+ # prefetch_factor=2,
166
+ sampler=self.train_sampler)
167
+
168
+ elif mode == 'online_eval':
169
+ self.testing_samples = DataLoadPreprocess(
170
+ config, mode, transform=transform)
171
+ if config.distributed: # redundant. here only for readability and to be more explicit
172
+ # Give whole test set to all processes (and report evaluation only on one) regardless
173
+ self.eval_sampler = None
174
+ else:
175
+ self.eval_sampler = None
176
+ self.data = DataLoader(self.testing_samples, 1,
177
+ shuffle=kwargs.get("shuffle_test", False),
178
+ num_workers=1,
179
+ pin_memory=False,
180
+ sampler=self.eval_sampler)
181
+
182
+ elif mode == 'test':
183
+ self.testing_samples = DataLoadPreprocess(
184
+ config, mode, transform=transform)
185
+ self.data = DataLoader(self.testing_samples,
186
+ 1, shuffle=False, num_workers=1)
187
+
188
+ else:
189
+ print(
190
+ 'mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
191
+
192
+
193
+ def repetitive_roundrobin(*iterables):
194
+ """
195
+ cycles through iterables but sample wise
196
+ first yield first sample from first iterable then first sample from second iterable and so on
197
+ then second sample from first iterable then second sample from second iterable and so on
198
+
199
+ If one iterable is shorter than the others, it is repeated until all iterables are exhausted
200
+ repetitive_roundrobin('ABC', 'D', 'EF') --> A D E B D F C D E
201
+ """
202
+ # Repetitive roundrobin
203
+ iterables_ = [iter(it) for it in iterables]
204
+ exhausted = [False] * len(iterables)
205
+ while not all(exhausted):
206
+ for i, it in enumerate(iterables_):
207
+ try:
208
+ yield next(it)
209
+ except StopIteration:
210
+ exhausted[i] = True
211
+ iterables_[i] = itertools.cycle(iterables[i])
212
+ # First elements may get repeated if one iterable is shorter than the others
213
+ yield next(iterables_[i])
214
+
215
+
216
+ class RepetitiveRoundRobinDataLoader(object):
217
+ def __init__(self, *dataloaders):
218
+ self.dataloaders = dataloaders
219
+
220
+ def __iter__(self):
221
+ return repetitive_roundrobin(*self.dataloaders)
222
+
223
+ def __len__(self):
224
+ # First samples get repeated, thats why the plus one
225
+ return len(self.dataloaders) * (max(len(dl) for dl in self.dataloaders) + 1)
226
+
227
+
228
+ class MixedNYUKITTI(object):
229
+ def __init__(self, config, mode, device='cpu', **kwargs):
230
+ config = edict(config)
231
+ config.workers = config.workers // 2
232
+ self.config = config
233
+ nyu_conf = change_dataset(edict(config), 'nyu')
234
+ kitti_conf = change_dataset(edict(config), 'kitti')
235
+
236
+ # make nyu default for testing
237
+ self.config = config = nyu_conf
238
+ img_size = self.config.get("img_size", None)
239
+ img_size = img_size if self.config.get(
240
+ "do_input_resize", False) else None
241
+ if mode == 'train':
242
+ nyu_loader = DepthDataLoader(
243
+ nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
244
+ kitti_loader = DepthDataLoader(
245
+ kitti_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
246
+ # It has been changed to repetitive roundrobin
247
+ self.data = RepetitiveRoundRobinDataLoader(
248
+ nyu_loader, kitti_loader)
249
+ else:
250
+ self.data = DepthDataLoader(nyu_conf, mode, device=device).data
251
+
252
+ class MixedNYUPlaces365(object):
253
+ def __init__(self, config, mode, device='cpu', **kwargs):
254
+ config = edict(config)
255
+ config.workers = config.workers // 2
256
+ self.config = config
257
+ nyu_conf = change_dataset(edict(config), 'nyu')
258
+ places365_conf = change_dataset(edict(config), 'places365')
259
+
260
+ # make nyu default for testing
261
+ self.config = config = nyu_conf
262
+ img_size = self.config.get("img_size", None)
263
+ img_size = img_size if self.config.get(
264
+ "do_input_resize", False) else None
265
+ if mode == 'train':
266
+ nyu_loader = DepthDataLoader(
267
+ nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
268
+ places365_loader = DepthDataLoader(
269
+ places365_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
270
+ # It has been changed to repetitive roundrobin
271
+ self.data = RepetitiveRoundRobinDataLoader(
272
+ nyu_loader, places365_loader)
273
+ else:
274
+ self.data = DepthDataLoader(nyu_conf, mode, device=device).data
275
+
276
+ def remove_leading_slash(s):
277
+ if s[0] == '/' or s[0] == '\\':
278
+ return s[1:]
279
+ return s
280
+
281
+
282
+ class CachedReader:
283
+ def __init__(self, shared_dict=None):
284
+ if shared_dict:
285
+ self._cache = shared_dict
286
+ else:
287
+ self._cache = {}
288
+
289
+ def open(self, fpath):
290
+ im = self._cache.get(fpath, None)
291
+ if im is None:
292
+ im = self._cache[fpath] = Image.open(fpath)
293
+ return im
294
+
295
+
296
+ class ImReader:
297
+ def __init__(self):
298
+ pass
299
+
300
+ # @cache
301
+ def open(self, fpath):
302
+ return Image.open(fpath)
303
+
304
+
305
+ class DataLoadPreprocess(Dataset):
306
+ def __init__(self, config, mode, transform=None, is_for_online_eval=False, device="cpu", **kwargs):
307
+ self.config = config
308
+ if mode == 'online_eval':
309
+ with open(config.filenames_file_eval, 'r') as f:
310
+ self.filenames = f.readlines()
311
+ else:
312
+ with open(config.filenames_file, 'r') as f:
313
+ self.filenames = f.readlines()
314
+
315
+ self.device = torch.device(device)
316
+ self.mode = mode
317
+ self.transform = transform
318
+ self.to_tensor = ToTensor(mode)
319
+ self.is_for_online_eval = is_for_online_eval
320
+ if config.use_shared_dict:
321
+ self.reader = CachedReader(config.shared_dict)
322
+ else:
323
+ self.reader = ImReader()
324
+
325
+ if config.dataset == "places365" or config.inpaint_task_probability > 0:
326
+ places365_conf = change_dataset(edict(config), 'places365')
327
+ self.places365_data = self.data = Places365(places365_conf.places365_root, places365_conf.places365_depth_root, places365_conf.places365_depth_masks_root, randomize_masks=places365_conf.get("randomize_masks", True), debug_mode=self.config.debug_mode)
328
+
329
+ if config.dataset == "marigold_nyu":
330
+ self.marigold_data = self.data = MarigoldNYU(config.nyu_dir_root, config.marigold_depth_root, debug_mode=self.config.debug_mode)
331
+ self.config.avoid_boundary = True
332
+
333
+ def postprocess(self, sample):
334
+ return sample
335
+
336
+ def __getitem__(self, idx):
337
+ sample_path = self.filenames[idx] if self.config.dataset not in ('places365', "marigold_nyu") else self.filenames[0]
338
+ focal = float(sample_path.split()[2])
339
+ sample = {}
340
+
341
+ if self.mode == 'train':
342
+ depth_mask = None
343
+ if self.config.dataset == 'kitti' and self.config.use_right and random.random() > 0.5:
344
+ image_path = os.path.join(
345
+ self.config.data_path, remove_leading_slash(sample_path.split()[3]))
346
+ depth_path = os.path.join(
347
+ self.config.gt_path, remove_leading_slash(sample_path.split()[4]))
348
+
349
+ image = self.reader.open(image_path)
350
+ depth_gt = self.reader.open(depth_path)
351
+ w, h = image.size
352
+
353
+ elif self.config.dataset == 'places365':
354
+ image, depth_gt, depth_mask, image_path, depth_path, _ = self.places365_data[idx]
355
+ h, w = image.shape[:2]
356
+
357
+ if image.ndim == 2:
358
+ image = image.reshape(image.shape[0], image.shape[1], 1)
359
+ image = np.repeat(image, 3, axis=-1)
360
+
361
+ elif self.config.dataset == 'marigold_nyu':
362
+ image, depth_gt, marigold_gt, image_path, depth_path = self.marigold_data[idx]
363
+
364
+ h, w = image.shape[:2]
365
+
366
+ if image.ndim == 2:
367
+ image = image.reshape(image.shape[0], image.shape[1], 1)
368
+ image = np.repeat(image, 3, axis=-1)
369
+
370
+ else:
371
+ image_path = os.path.join(
372
+ self.config.data_path, remove_leading_slash(sample_path.split()[0]))
373
+ depth_path = os.path.join(
374
+ self.config.gt_path, remove_leading_slash(sample_path.split()[1]))
375
+
376
+ image = self.reader.open(image_path)
377
+ depth_gt = self.reader.open(depth_path)
378
+ w, h = image.size
379
+
380
+ if self.config.inpaint_task_probability > 0:
381
+ _, _, depth_mask, _, _, _ = self.places365_data[idx]
382
+
383
+ if self.config.do_kb_crop:
384
+ height = image.height
385
+ width = image.width
386
+ top_margin = int(height - 352)
387
+ left_margin = int((width - 1216) / 2)
388
+ depth_gt = depth_gt.crop(
389
+ (left_margin, top_margin, left_margin + 1216, top_margin + 352))
390
+ image = image.crop(
391
+ (left_margin, top_margin, left_margin + 1216, top_margin + 352))
392
+
393
+ # Avoid blank boundaries due to pixel registration?
394
+ # Train images have white border. Test images have black border.
395
+ if self.config.dataset in ('nyu', 'marigold_nyu') and self.config.avoid_boundary:
396
+ # print("Avoiding Blank Boundaries!")
397
+ # We just crop and pad again with reflect padding to original size
398
+ # original_size = image.size
399
+ #crop_params = get_white_border(np.array(255*image, dtype=np.uint8))
400
+ # crop image down from 640x480 to 624x464
401
+ crop_params = CropParams(8, 472, 8, 632)
402
+
403
+ image = image[crop_params.top:crop_params.bottom, crop_params.left:crop_params.right]
404
+ depth_gt = depth_gt[crop_params.top:crop_params.bottom, crop_params.left:crop_params.right]
405
+
406
+ # Use reflect padding to fill the blank
407
+ #image = np.pad(image, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), mode='reflect')
408
+ #image = Image.fromarray(image)
409
+
410
+ #depth_gt = np.pad(depth_gt, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), 'constant', constant_values=0)
411
+ #depth_gt = Image.fromarray(depth_gt)
412
+
413
+ if self.config.dataset == "marigold_nyu":
414
+ marigold_gt = marigold_gt[crop_params.top:crop_params.bottom, crop_params.left:crop_params.right]
415
+
416
+ if self.config.do_random_rotate and (self.config.aug) and self.config.dataset not in ('places365', "marigold_nyu"):
417
+ random_angle = (random.random() - 0.5) * 2 * self.config.degree
418
+ image = self.rotate_image(image, random_angle)
419
+ depth_gt = self.rotate_image(
420
+ depth_gt, random_angle, flag=Image.NEAREST)
421
+
422
+ if self.config.dataset not in ('places365', "marigold_nyu"):
423
+ image = np.asarray(image, dtype=np.float32) / 255.0
424
+ depth_gt = np.asarray(depth_gt, dtype=np.float32)
425
+ depth_gt = np.expand_dims(depth_gt, axis=2)
426
+
427
+ if self.config.dataset in ('nyu', 'marigold_nyu'):
428
+ depth_gt = depth_gt / 1000.0
429
+ elif self.config.dataset != 'places365':
430
+ depth_gt = depth_gt / 256.0
431
+
432
+ if self.config.aug and (self.config.random_crop) and self.config.dataset not in ('places365', "marigold_nyu"):
433
+ image, depth_gt = self.random_crop(
434
+ image, depth_gt, self.config.input_height, self.config.input_width)
435
+
436
+ if self.config.aug and self.config.random_translate and self.config.dataset not in ('places365', "marigold_nyu"):
437
+ # print("Random Translation!")
438
+ image, depth_gt = self.random_translate(image, depth_gt, self.config.max_translation)
439
+
440
+ mask = np.logical_and(depth_gt > self.config.min_depth,
441
+ depth_gt < self.config.max_depth).squeeze()[None, ...]
442
+
443
+ is_inpainting_sample = self.config.inpaint_task_probability > 0 and (torch.rand(1).item() < self.config.inpaint_task_probability)
444
+
445
+ def randomly_scale_depth(depth_to_scale):
446
+ # scale the mask
447
+ max_scale_factor = self.config.max_depth / depth_to_scale.max()
448
+ min_scale_factor = self.config.min_depth / depth_to_scale.min()
449
+
450
+ scale_factor = torch.rand(1).item() * (max_scale_factor - min_scale_factor) + min_scale_factor
451
+ scaled_depth = depth_to_scale * scale_factor
452
+
453
+ scaled_depth = scaled_depth.clip(self.config.min_depth, self.config.max_depth)
454
+
455
+ return scaled_depth
456
+
457
+ if self.config.dataset in ("marigold_nyu"):
458
+ marigold_mask = (marigold_gt > -1).squeeze()[None, ...]
459
+
460
+ if is_inpainting_sample and self.config.random_inpainting_scaling:
461
+ marigold_gt = randomly_scale_depth(marigold_gt)
462
+
463
+ marigold_gt[~marigold_mask[0]] = 0
464
+
465
+ depth_gt = marigold_gt
466
+ mask = marigold_mask
467
+
468
+ image, depth_gt, mask = self.train_preprocess(image, depth_gt, mask)
469
+
470
+ sample = {'image': image, 'depth': depth_gt, 'focal': focal,
471
+ 'mask': mask, **sample}
472
+
473
+ if self.config["depth_channel_mask_augment"]:
474
+ if self.config.dataset in ("marigold_nyu",):
475
+ if (not self.config.inpaint_task_probability > 0) and depth_mask is None:
476
+ depth_mask = np.zeros_like(depth_gt)
477
+ elif self.config.inpaint_task_probability > 0:
478
+ # we randomly mask with places365, or provide no sparse input at all
479
+ if is_inpainting_sample:
480
+ # upsample depth_mask to match depth_gt
481
+ depth_mask = torch.nn.functional.interpolate(torch.from_numpy(depth_mask).permute(2, 0, 1).unsqueeze(0), size=depth_gt.shape[:2], mode='nearest').squeeze(0).permute(1, 2, 0).numpy()
482
+ else:
483
+ depth_mask = np.zeros_like(depth_gt)
484
+
485
+ sample["masked_depth"] = depth_gt * depth_mask
486
+
487
+ else:
488
+ if self.mode == 'online_eval':
489
+ data_path = self.config.data_path_eval
490
+ else:
491
+ data_path = self.config.data_path
492
+
493
+ image_path = os.path.join(
494
+ data_path, remove_leading_slash(sample_path.split()[0]))
495
+ image = np.asarray(self.reader.open(image_path),
496
+ dtype=np.float32) / 255.0
497
+
498
+ if self.mode == 'online_eval':
499
+ gt_path = self.config.gt_path_eval
500
+ depth_path = os.path.join(
501
+ gt_path, remove_leading_slash(sample_path.split()[1]))
502
+ has_valid_depth = False
503
+ try:
504
+ depth_gt = self.reader.open(depth_path)
505
+ has_valid_depth = True
506
+ except IOError:
507
+ depth_gt = False
508
+ # print('Missing gt for {}'.format(image_path))
509
+
510
+ if has_valid_depth:
511
+ depth_gt = np.asarray(depth_gt, dtype=np.float32)
512
+ depth_gt = np.expand_dims(depth_gt, axis=2)
513
+ if self.config.dataset == 'nyu':
514
+ depth_gt = depth_gt / 1000.0
515
+ elif self.config.dataset != 'places365':
516
+ depth_gt = depth_gt / 256.0
517
+
518
+ mask = np.logical_and(
519
+ depth_gt >= self.config.min_depth, depth_gt <= self.config.max_depth).squeeze()[None, ...]
520
+ else:
521
+ mask = False
522
+
523
+ if self.config.do_kb_crop:
524
+ height = image.shape[0]
525
+ width = image.shape[1]
526
+ top_margin = int(height - 352)
527
+ left_margin = int((width - 1216) / 2)
528
+ image = image[top_margin:top_margin + 352,
529
+ left_margin:left_margin + 1216, :]
530
+ if self.mode == 'online_eval' and has_valid_depth:
531
+ depth_gt = depth_gt[top_margin:top_margin +
532
+ 352, left_margin:left_margin + 1216, :]
533
+
534
+ if self.mode == 'online_eval':
535
+ sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth,
536
+ 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1],
537
+ 'mask': mask}
538
+ else:
539
+ sample = {'image': image, 'focal': focal}
540
+
541
+ if (self.mode == 'train') or ('has_valid_depth' in sample and sample['has_valid_depth']):
542
+ if (self.config.dataset not in ('places365', "marigold_nyu")):
543
+ mask = np.logical_and(depth_gt > self.config.min_depth,
544
+ depth_gt < self.config.max_depth).squeeze()[None, ...]
545
+ sample['mask'] = mask
546
+
547
+ if self.transform:
548
+ sample = self.transform(sample)
549
+
550
+ sample = self.postprocess(sample)
551
+ sample['dataset'] = self.config.dataset
552
+
553
+ if self.config.dataset != 'places365':
554
+ sample = {**sample, 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1]}
555
+ else:
556
+ sample = {**sample, 'image_path': image_path, 'depth_path': depth_path}
557
+
558
+ return sample
559
+
560
+ def rotate_image(self, image, angle, flag=Image.BILINEAR):
561
+ result = image.rotate(angle, resample=flag)
562
+ return result
563
+
564
+ def random_crop(self, img, depth, height, width):
565
+ assert img.shape[0] >= height
566
+ assert img.shape[1] >= width
567
+ assert img.shape[0] == depth.shape[0]
568
+ assert img.shape[1] == depth.shape[1]
569
+ x = random.randint(0, img.shape[1] - width)
570
+ y = random.randint(0, img.shape[0] - height)
571
+ img = img[y:y + height, x:x + width, :]
572
+ depth = depth[y:y + height, x:x + width, :]
573
+
574
+ return img, depth
575
+
576
+ def random_translate(self, img, depth, max_t=20):
577
+ assert img.shape[0] == depth.shape[0]
578
+ assert img.shape[1] == depth.shape[1]
579
+ p = self.config.translate_prob
580
+ do_translate = random.random()
581
+ if do_translate > p:
582
+ return img, depth
583
+ x = random.randint(-max_t, max_t)
584
+ y = random.randint(-max_t, max_t)
585
+ M = np.float32([[1, 0, x], [0, 1, y]])
586
+ # print(img.shape, depth.shape)
587
+ img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
588
+ depth = cv2.warpAffine(depth, M, (depth.shape[1], depth.shape[0]))
589
+ depth = depth.squeeze()[..., None] # add channel dim back. Affine warp removes it
590
+ # print("after", img.shape, depth.shape)
591
+ return img, depth
592
+
593
+ def train_preprocess(self, image, depth_gt, mask):
594
+ if self.config.aug:
595
+ # Random flipping
596
+ do_flip = random.random()
597
+ if do_flip > 0.5:
598
+ # image is H x W x 3
599
+ image = (image[:, ::-1, :]).copy()
600
+ # depth_gt is H x W x 1
601
+ depth_gt = (depth_gt[:, ::-1, :]).copy()
602
+ # mask is B x H x W
603
+ mask = (mask[:, :, ::-1]).copy()
604
+
605
+ # Random gamma, brightness, color augmentation
606
+ do_augment = random.random()
607
+ if do_augment > 0.5:
608
+ image = self.augment_image(image)
609
+
610
+ return image, depth_gt, mask
611
+
612
+ def augment_image(self, image):
613
+ # gamma augmentation
614
+ gamma = random.uniform(0.9, 1.1)
615
+ image_aug = image ** gamma
616
+
617
+ # brightness augmentation
618
+ if self.config.dataset == 'nyu':
619
+ brightness = random.uniform(0.75, 1.25)
620
+ else:
621
+ brightness = random.uniform(0.9, 1.1)
622
+ image_aug = image_aug * brightness
623
+
624
+ # color augmentation
625
+ colors = np.random.uniform(0.9, 1.1, size=3)
626
+ white = np.ones((image.shape[0], image.shape[1]))
627
+ color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
628
+ image_aug *= color_image
629
+ image_aug = np.clip(image_aug, 0, 1)
630
+
631
+ return image_aug
632
+
633
+ def __len__(self):
634
+ return len(self.data) if (self.config.dataset in ('places365', "marigold_nyu") and self.mode != 'online_eval') else len(self.filenames)
635
+
636
+
637
+ class ToTensor(object):
638
+ def __init__(self, mode, do_normalize=False, size=None):
639
+ self.mode = mode
640
+ self.normalize = transforms.Normalize(
641
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if do_normalize else nn.Identity()
642
+ self.size = size
643
+ if size is not None:
644
+ self.resize = transforms.Resize(size=size)
645
+ else:
646
+ self.resize = nn.Identity()
647
+
648
+ def __call__(self, sample):
649
+ image, focal = sample['image'], sample['focal']
650
+ image = self.to_tensor(image)
651
+ image = self.normalize(image)
652
+ image = self.resize(image)
653
+
654
+ if self.mode == 'test':
655
+ return {'image': image, 'focal': focal}
656
+
657
+ depth = sample['depth']
658
+ if self.mode == 'train':
659
+ depth = self.to_tensor(depth)
660
+ return {**sample, 'image': image, 'depth': depth, 'focal': focal}
661
+ else:
662
+ has_valid_depth = sample['has_valid_depth']
663
+ image = self.resize(image)
664
+ return {**sample, 'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth,
665
+ 'image_path': sample['image_path'], 'depth_path': sample['depth_path']}
666
+
667
+ def to_tensor(self, pic):
668
+ if not (_is_pil_image(pic) or _is_numpy_image(pic)):
669
+ raise TypeError(
670
+ 'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
671
+
672
+ if isinstance(pic, np.ndarray):
673
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
674
+ return img
675
+
676
+ # handle PIL Image
677
+ if pic.mode == 'I':
678
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
679
+ elif pic.mode == 'I;16':
680
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
681
+ else:
682
+ img = torch.ByteTensor(
683
+ torch.ByteStorage.from_buffer(pic.tobytes()))
684
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
685
+ if pic.mode == 'YCbCr':
686
+ nchannel = 3
687
+ elif pic.mode == 'I;16':
688
+ nchannel = 1
689
+ else:
690
+ nchannel = len(pic.mode)
691
+ img = img.view(pic.size[1], pic.size[0], nchannel)
692
+
693
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
694
+ if isinstance(img, torch.ByteTensor):
695
+ return img.float()
696
+ else:
697
+ return img
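One subtlety in the mixing logic above: repetitive_roundrobin only detects exhaustion when a next() call fails, at which point it refills the stream and yields from it again, so every run ends with one extra round of repeats; RepetitiveRoundRobinDataLoader.__len__ accounts for exactly that with its + 1. A quick trace:

out = list(repetitive_roundrobin('ABC', 'D', 'EF'))
print(out)       # ['A','D','E','B','D','F','C','D','E','A','D','F']
print(len(out))  # 12 == number of iterables * (max length + 1) == 3 * 4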
zoedepth/data/ddad.py ADDED
@@ -0,0 +1,117 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+
27
+ import numpy as np
28
+ import torch
29
+ from PIL import Image
30
+ from torch.utils.data import DataLoader, Dataset
31
+ from torchvision import transforms
32
+
33
+
34
+ class ToTensor(object):
35
+ def __init__(self, resize_shape):
36
+ # self.normalize = transforms.Normalize(
37
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
38
+ self.normalize = lambda x : x
39
+ self.resize = transforms.Resize(resize_shape)
40
+
41
+ def __call__(self, sample):
42
+ image, depth = sample['image'], sample['depth']
43
+ image = self.to_tensor(image)
44
+ image = self.normalize(image)
45
+ depth = self.to_tensor(depth)
46
+
47
+ image = self.resize(image)
48
+
49
+ return {'image': image, 'depth': depth, 'dataset': "ddad"}
50
+
51
+ def to_tensor(self, pic):
52
+
53
+ if isinstance(pic, np.ndarray):
54
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
55
+ return img
56
+
57
+ # # handle PIL Image
58
+ if pic.mode == 'I':
59
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
60
+ elif pic.mode == 'I;16':
61
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
62
+ else:
63
+ img = torch.ByteTensor(
64
+ torch.ByteStorage.from_buffer(pic.tobytes()))
65
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
66
+ if pic.mode == 'YCbCr':
67
+ nchannel = 3
68
+ elif pic.mode == 'I;16':
69
+ nchannel = 1
70
+ else:
71
+ nchannel = len(pic.mode)
72
+ img = img.view(pic.size[1], pic.size[0], nchannel)
73
+
74
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
75
+
76
+ if isinstance(img, torch.ByteTensor):
77
+ return img.float()
78
+ else:
79
+ return img
80
+
81
+
82
+ class DDAD(Dataset):
83
+ def __init__(self, data_dir_root, resize_shape):
84
+ import glob
85
+
86
+ # image paths are of the form <data_dir_root>/{outleft, depthmap}/*.png
87
+ self.image_files = glob.glob(os.path.join(data_dir_root, '*.png'))
88
+ self.depth_files = [r.replace("_rgb.png", "_depth.npy")
89
+ for r in self.image_files]
90
+ self.transform = ToTensor(resize_shape)
91
+
92
+ def __getitem__(self, idx):
93
+
94
+ image_path = self.image_files[idx]
95
+ depth_path = self.depth_files[idx]
96
+
97
+ image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
98
+ depth = np.load(depth_path) # meters
99
+
100
+ # depth[depth > 8] = -1
101
+ depth = depth[..., None]
102
+
103
+ sample = dict(image=image, depth=depth)
104
+ sample = self.transform(sample)
105
+
106
+ if idx == 0:
107
+ print(sample["image"].shape)
108
+
109
+ return sample
110
+
111
+ def __len__(self):
112
+ return len(self.image_files)
113
+
114
+
115
+ def get_ddad_loader(data_dir_root, resize_shape, batch_size=1, **kwargs):
116
+ dataset = DDAD(data_dir_root, resize_shape)
117
+ return DataLoader(dataset, batch_size, **kwargs)
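Usage mirrors the other evaluation loaders in this package; the dataset path below is a placeholder:

loader = get_ddad_loader("datasets/ddad", resize_shape=(352, 1216))
batch = next(iter(loader))
print(batch["image"].shape)  # [1, 3, 352, 1216] after the resize
print(batch["depth"].shape)  # [1, 1, H, W]; depth keeps its native resolution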
zoedepth/data/diml_indoor_test.py ADDED
@@ -0,0 +1,125 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+
27
+ import numpy as np
28
+ import torch
29
+ from PIL import Image
30
+ from torch.utils.data import DataLoader, Dataset
31
+ from torchvision import transforms
32
+
33
+
34
+ class ToTensor(object):
35
+ def __init__(self):
36
+ # self.normalize = transforms.Normalize(
37
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
38
+ self.normalize = lambda x : x
39
+ self.resize = transforms.Resize((480, 640))
40
+
41
+ def __call__(self, sample):
42
+ image, depth = sample['image'], sample['depth']
43
+ image = self.to_tensor(image)
44
+ image = self.normalize(image)
45
+ depth = self.to_tensor(depth)
46
+
47
+ image = self.resize(image)
48
+
49
+ return {'image': image, 'depth': depth, 'dataset': "diml_indoor"}
50
+
51
+ def to_tensor(self, pic):
52
+
53
+ if isinstance(pic, np.ndarray):
54
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
55
+ return img
56
+
57
+ # # handle PIL Image
58
+ if pic.mode == 'I':
59
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
60
+ elif pic.mode == 'I;16':
61
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
62
+ else:
63
+ img = torch.ByteTensor(
64
+ torch.ByteStorage.from_buffer(pic.tobytes()))
65
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
66
+ if pic.mode == 'YCbCr':
67
+ nchannel = 3
68
+ elif pic.mode == 'I;16':
69
+ nchannel = 1
70
+ else:
71
+ nchannel = len(pic.mode)
72
+ img = img.view(pic.size[1], pic.size[0], nchannel)
73
+
74
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
75
+ if isinstance(img, torch.ByteTensor):
76
+ return img.float()
77
+ else:
78
+ return img
79
+
80
+
81
+ class DIML_Indoor(Dataset):
82
+ def __init__(self, data_dir_root):
83
+ import glob
84
+
85
+ # image paths are of the form <data_dir_root>/{HR, LR}/<scene>/{color, depth_filled}/*.png
86
+ self.image_files = glob.glob(os.path.join(
87
+ data_dir_root, "LR", '*', 'color', '*.png'))
88
+ self.depth_files = [r.replace("color", "depth_filled").replace(
89
+ "_c.png", "_depth_filled.png") for r in self.image_files]
90
+ self.transform = ToTensor()
91
+
92
+ def __getitem__(self, idx):
93
+ image_path = self.image_files[idx]
94
+ depth_path = self.depth_files[idx]
95
+
96
+ image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
97
+ depth = np.asarray(Image.open(depth_path),
98
+ dtype='uint16') / 1000.0 # mm to meters
99
+
100
+ # print(np.shape(image))
101
+ # print(np.shape(depth))
102
+
103
+ # depth[depth > 8] = -1
104
+ depth = depth[..., None]
105
+
106
+ sample = dict(image=image, depth=depth)
107
+
108
+ # return sample
109
+ sample = self.transform(sample)
110
+
111
+ if idx == 0:
112
+ print(sample["image"].shape)
113
+
114
+ return sample
115
+
116
+ def __len__(self):
117
+ return len(self.image_files)
118
+
119
+
120
+ def get_diml_indoor_loader(data_dir_root, batch_size=1, **kwargs):
121
+ dataset = DIML_Indoor(data_dir_root)
122
+ return DataLoader(dataset, batch_size, **kwargs)
123
+
124
+ # get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/HR")
125
+ # get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/LR")
zoedepth/data/diml_outdoor_test.py ADDED
@@ -0,0 +1,114 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+
27
+ import numpy as np
28
+ import torch
29
+ from PIL import Image
30
+ from torch.utils.data import DataLoader, Dataset
31
+ from torchvision import transforms
32
+
33
+
34
+ class ToTensor(object):
35
+ def __init__(self):
36
+ # self.normalize = transforms.Normalize(
37
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
38
+ self.normalize = lambda x : x
39
+
40
+ def __call__(self, sample):
41
+ image, depth = sample['image'], sample['depth']
42
+ image = self.to_tensor(image)
43
+ image = self.normalize(image)
44
+ depth = self.to_tensor(depth)
45
+
46
+ return {'image': image, 'depth': depth, 'dataset': "diml_outdoor"}
47
+
48
+ def to_tensor(self, pic):
49
+
50
+ if isinstance(pic, np.ndarray):
51
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
52
+ return img
53
+
54
+ # # handle PIL Image
55
+ if pic.mode == 'I':
56
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
57
+ elif pic.mode == 'I;16':
58
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
59
+ else:
60
+ img = torch.ByteTensor(
61
+ torch.ByteStorage.from_buffer(pic.tobytes()))
62
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
63
+ if pic.mode == 'YCbCr':
64
+ nchannel = 3
65
+ elif pic.mode == 'I;16':
66
+ nchannel = 1
67
+ else:
68
+ nchannel = len(pic.mode)
69
+ img = img.view(pic.size[1], pic.size[0], nchannel)
70
+
71
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
72
+ if isinstance(img, torch.ByteTensor):
73
+ return img.float()
74
+ else:
75
+ return img
76
+
77
+
78
+ class DIML_Outdoor(Dataset):
79
+ def __init__(self, data_dir_root):
80
+ import glob
81
+
82
+ # image paths are of the form <data_dir_root>/{outleft, depthmap}/*.png
83
+ self.image_files = glob.glob(os.path.join(
84
+ data_dir_root, "*", 'outleft', '*.png'))
85
+ self.depth_files = [r.replace("outleft", "depthmap")
86
+ for r in self.image_files]
87
+ self.transform = ToTensor()
88
+
89
+ def __getitem__(self, idx):
90
+ image_path = self.image_files[idx]
91
+ depth_path = self.depth_files[idx]
92
+
93
+ image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
94
+ depth = np.asarray(Image.open(depth_path),
95
+ dtype='uint16') / 1000.0 # mm to meters
96
+
97
+ # depth[depth > 8] = -1
98
+ depth = depth[..., None]
99
+
100
+ sample = dict(image=image, depth=depth, dataset="diml_outdoor")
101
+
102
+ # return sample
103
+ return self.transform(sample)
104
+
105
+ def __len__(self):
106
+ return len(self.image_files)
107
+
108
+
109
+ def get_diml_outdoor_loader(data_dir_root, batch_size=1, **kwargs):
110
+ dataset = DIML_Outdoor(data_dir_root)
111
+ return DataLoader(dataset, batch_size, **kwargs)
112
+
113
+ # get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/HR")
114
+ # get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/LR")
zoedepth/data/diode.py ADDED
@@ -0,0 +1,125 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+
27
+ import numpy as np
28
+ import torch
29
+ from PIL import Image
30
+ from torch.utils.data import DataLoader, Dataset
31
+ from torchvision import transforms
32
+
33
+
34
+ class ToTensor(object):
35
+ def __init__(self):
36
+ # self.normalize = transforms.Normalize(
37
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
38
+ self.normalize = lambda x : x
39
+ self.resize = transforms.Resize(480)
40
+
41
+ def __call__(self, sample):
42
+ image, depth = sample['image'], sample['depth']
43
+ image = self.to_tensor(image)
44
+ image = self.normalize(image)
45
+ depth = self.to_tensor(depth)
46
+
47
+ image = self.resize(image)
48
+
49
+ return {'image': image, 'depth': depth, 'dataset': "diode"}
50
+
51
+ def to_tensor(self, pic):
52
+
53
+ if isinstance(pic, np.ndarray):
54
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
55
+ return img
56
+
57
+ # # handle PIL Image
58
+ if pic.mode == 'I':
59
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
60
+ elif pic.mode == 'I;16':
61
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
62
+ else:
63
+ img = torch.ByteTensor(
64
+ torch.ByteStorage.from_buffer(pic.tobytes()))
65
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
66
+ if pic.mode == 'YCbCr':
67
+ nchannel = 3
68
+ elif pic.mode == 'I;16':
69
+ nchannel = 1
70
+ else:
71
+ nchannel = len(pic.mode)
72
+ img = img.view(pic.size[1], pic.size[0], nchannel)
73
+
74
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
75
+
76
+ if isinstance(img, torch.ByteTensor):
77
+ return img.float()
78
+ else:
79
+ return img
80
+
81
+
82
+ class DIODE(Dataset):
83
+ def __init__(self, data_dir_root):
84
+ import glob
85
+
86
+ # image paths are of the form <data_dir_root>/scene_#/scan_#/*.png
87
+ self.image_files = glob.glob(
88
+ os.path.join(data_dir_root, '*', '*', '*.png'))
89
+ self.depth_files = [r.replace(".png", "_depth.npy")
90
+ for r in self.image_files]
91
+ self.depth_mask_files = [
92
+ r.replace(".png", "_depth_mask.npy") for r in self.image_files]
93
+ self.transform = ToTensor()
94
+
95
+ def __getitem__(self, idx):
96
+ image_path = self.image_files[idx]
97
+ depth_path = self.depth_files[idx]
98
+ depth_mask_path = self.depth_mask_files[idx]
99
+
100
+ image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
101
+ depth = np.load(depth_path) # in meters
102
+ valid = np.load(depth_mask_path) # binary
103
+
104
+ # depth[depth > 8] = -1
105
+ # depth = depth[..., None]
106
+
107
+ sample = dict(image=image, depth=depth, valid=valid)
108
+
109
+ # return sample
110
+ sample = self.transform(sample)
111
+
112
+ if idx == 0:
113
+ print(sample["image"].shape)
114
+
115
+ return sample
116
+
117
+ def __len__(self):
118
+ return len(self.image_files)
119
+
120
+
121
+ def get_diode_loader(data_dir_root, batch_size=1, **kwargs):
122
+ dataset = DIODE(data_dir_root)
123
+ return DataLoader(dataset, batch_size, **kwargs)
124
+
125
+ # get_diode_loader(data_dir_root="datasets/diode/val/outdoor")
zoedepth/data/hypersim.py ADDED
@@ -0,0 +1,138 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import glob
26
+ import os
27
+
28
+ import h5py
29
+ import numpy as np
30
+ import torch
31
+ from PIL import Image
32
+ from torch.utils.data import DataLoader, Dataset
33
+ from torchvision import transforms
34
+
35
+
36
+ def hypersim_distance_to_depth(npyDistance):
37
+ intWidth, intHeight, fltFocal = 1024, 768, 886.81
38
+
39
+ npyImageplaneX = np.linspace((-0.5 * intWidth) + 0.5, (0.5 * intWidth) - 0.5, intWidth).reshape(
40
+ 1, intWidth).repeat(intHeight, 0).astype(np.float32)[:, :, None]
41
+ npyImageplaneY = np.linspace((-0.5 * intHeight) + 0.5, (0.5 * intHeight) - 0.5,
42
+ intHeight).reshape(intHeight, 1).repeat(intWidth, 1).astype(np.float32)[:, :, None]
43
+ npyImageplaneZ = np.full([intHeight, intWidth, 1], fltFocal, np.float32)
44
+ npyImageplane = np.concatenate(
45
+ [npyImageplaneX, npyImageplaneY, npyImageplaneZ], 2)
46
+
47
+ npyDepth = npyDistance / np.linalg.norm(npyImageplane, 2, 2) * fltFocal
48
+ return npyDepth
49
+
50
+
51
+ class ToTensor(object):
52
+ def __init__(self):
53
+ # self.normalize = transforms.Normalize(
54
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
55
+ self.normalize = lambda x: x
56
+ self.resize = transforms.Resize((480, 640))
57
+
58
+ def __call__(self, sample):
59
+ image, depth = sample['image'], sample['depth']
60
+ image = self.to_tensor(image)
61
+ image = self.normalize(image)
62
+ depth = self.to_tensor(depth)
63
+
64
+ image = self.resize(image)
65
+
66
+ return {'image': image, 'depth': depth, 'dataset': "hypersim"}
67
+
68
+ def to_tensor(self, pic):
69
+
70
+ if isinstance(pic, np.ndarray):
71
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
72
+ return img
73
+
74
+ # handle PIL Image
75
+ if pic.mode == 'I':
76
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
77
+ elif pic.mode == 'I;16':
78
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
79
+ else:
80
+ img = torch.ByteTensor(
81
+ torch.ByteStorage.from_buffer(pic.tobytes()))
82
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
83
+ if pic.mode == 'YCbCr':
84
+ nchannel = 3
85
+ elif pic.mode == 'I;16':
86
+ nchannel = 1
87
+ else:
88
+ nchannel = len(pic.mode)
89
+ img = img.view(pic.size[1], pic.size[0], nchannel)
90
+
91
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
92
+ if isinstance(img, torch.ByteTensor):
93
+ return img.float()
94
+ else:
95
+ return img
96
+
97
+
98
+ class HyperSim(Dataset):
99
+ def __init__(self, data_dir_root):
100
+ # image paths are of the form <data_dir_root>/<scene>/images/scene_cam_#_final_preview/*.tonemap.jpg
101
+ # depth paths are of the form <data_dir_root>/<scene>/images/scene_cam_#_final_preview/*.depth_meters.hdf5
102
+ self.image_files = glob.glob(os.path.join(
103
+ data_dir_root, '*', 'images', 'scene_cam_*_final_preview', '*.tonemap.jpg'))
104
+ self.depth_files = [r.replace("_final_preview", "_geometry_hdf5").replace(
105
+ ".tonemap.jpg", ".depth_meters.hdf5") for r in self.image_files]
106
+ self.transform = ToTensor()
107
+
108
+ def __getitem__(self, idx):
109
+ image_path = self.image_files[idx]
110
+ depth_path = self.depth_files[idx]
111
+
112
+ image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
113
+
114
+ # depth from hdf5
115
+ depth_fd = h5py.File(depth_path, "r")
116
+ # in meters (Euclidean distance)
117
+ distance_meters = np.array(depth_fd['dataset'])
118
+ depth = hypersim_distance_to_depth(
119
+ distance_meters) # in meters (planar depth)
120
+
121
+ # depth[depth > 8] = -1
122
+ depth = depth[..., None]
123
+
124
+ sample = dict(image=image, depth=depth)
125
+ sample = self.transform(sample)
126
+
127
+ if idx == 0:
128
+ print(sample["image"].shape)
129
+
130
+ return sample
131
+
132
+ def __len__(self):
133
+ return len(self.image_files)
134
+
135
+
136
+ def get_hypersim_loader(data_dir_root, batch_size=1, **kwargs):
137
+ dataset = HyperSim(data_dir_root)
138
+ return DataLoader(dataset, batch_size, **kwargs)
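
hypersim_distance_to_depth converts Hypersim's per-pixel ray distance into planar depth by rescaling with f / ||(x, y, f)||, so the two agree at the principal point and planar depth shrinks toward the corners. A small sanity-check sketch:

import numpy as np
from zoedepth.data.hypersim import hypersim_distance_to_depth

dist = np.full((768, 1024), 5.0, dtype=np.float32)  # constant 5 m ray distance
depth = hypersim_distance_to_depth(dist)
assert abs(depth[384, 512] - 5.0) < 1e-3            # image centre: depth == distance
assert depth[0, 0] < 5.0                            # corner: oblique ray, smaller planar depth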
zoedepth/data/ibims.py ADDED
@@ -0,0 +1,81 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+
27
+ import numpy as np
28
+ import torch
29
+ from PIL import Image
30
+ from torch.utils.data import DataLoader, Dataset
31
+ from torchvision import transforms as T
32
+
33
+
34
+ class iBims(Dataset):
35
+ def __init__(self, config):
36
+ root_folder = config.ibims_root
37
+ with open(os.path.join(root_folder, "imagelist.txt"), 'r') as f:
38
+ imglist = f.read().split()
39
+
40
+ samples = []
41
+ for basename in imglist:
42
+ img_path = os.path.join(root_folder, 'rgb', basename + ".png")
43
+ depth_path = os.path.join(root_folder, 'depth', basename + ".png")
44
+ valid_mask_path = os.path.join(
45
+ root_folder, 'mask_invalid', basename+".png")
46
+ transp_mask_path = os.path.join(
47
+ root_folder, 'mask_transp', basename+".png")
48
+
49
+ samples.append(
50
+ (img_path, depth_path, valid_mask_path, transp_mask_path))
51
+
52
+ self.samples = samples
53
+ # self.normalize = T.Normalize(
54
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
55
+ self.normalize = lambda x: x
56
+
57
+ def __getitem__(self, idx):
58
+ img_path, depth_path, valid_mask_path, transp_mask_path = self.samples[idx]
59
+
60
+ img = np.asarray(Image.open(img_path), dtype=np.float32) / 255.0
61
+ depth = np.asarray(Image.open(depth_path),
62
+ dtype=np.uint16).astype('float')*50.0/65535
63
+
64
+ mask_valid = np.asarray(Image.open(valid_mask_path))
65
+ mask_transp = np.asarray(Image.open(transp_mask_path))
66
+
67
+ # depth = depth * mask_valid * mask_transp
68
+ depth = np.where(mask_valid * mask_transp, depth, -1)
69
+
70
+ img = torch.from_numpy(img).permute(2, 0, 1)
71
+ img = self.normalize(img)
72
+ depth = torch.from_numpy(depth).unsqueeze(0)
73
+ return dict(image=img, depth=depth, image_path=img_path, depth_path=depth_path, dataset='ibims')
74
+
75
+ def __len__(self):
76
+ return len(self.samples)
77
+
78
+
79
+ def get_ibims_loader(config, batch_size=1, **kwargs):
80
+ dataloader = DataLoader(iBims(config), batch_size=batch_size, **kwargs)
81
+ return dataloader
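
For reference, the depth decode above maps raw uint16 PNG values linearly onto [0, 50] metres and stamps pixels rejected by either mask with the -1 sentinel. A tiny worked example:

import numpy as np

raw = np.array([[0, 32768, 65535]], dtype=np.uint16)
depth = raw.astype('float') * 50.0 / 65535  # -> [0.0, ~25.0, 50.0] metres
mask = np.array([[1, 0, 1]])                # e.g. mask_valid * mask_transp
depth = np.where(mask, depth, -1)           # [[ 0. -1. 50.]]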
zoedepth/data/marigold_nyu.py ADDED
@@ -0,0 +1,112 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+
27
+ import numpy as np
28
+ import torch
29
+ from PIL import Image
30
+ from torch.utils.data import DataLoader, Dataset
31
+ from torchvision import transforms
32
+ from random import choice
33
+
34
+
35
+ class ToTensor(object):
36
+ def __init__(self):
37
+ self.normalize = transforms.Normalize(
38
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
39
+ #self.normalize = lambda x : x
40
+
41
+ def __call__(self, sample):
42
+ image, depth = sample['image'], sample['depth']
43
+ image = self.to_tensor(image)
44
+ image = self.normalize(image)
45
+ depth = self.to_tensor(depth)
46
+
47
+ return {'image': image, 'depth': depth, 'dataset': "marigold_nyu"}
48
+
49
+ def to_tensor(self, pic):
50
+
51
+ if isinstance(pic, np.ndarray):
52
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
53
+ return img
54
+
55
+ # handle PIL Image
56
+ if pic.mode == 'I':
57
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
58
+ elif pic.mode == 'I;16':
59
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
60
+ else:
61
+ img = torch.ByteTensor(
62
+ torch.ByteStorage.from_buffer(pic.tobytes()))
63
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
64
+ if pic.mode == 'YCbCr':
65
+ nchannel = 3
66
+ elif pic.mode == 'I;16':
67
+ nchannel = 1
68
+ else:
69
+ nchannel = len(pic.mode)
70
+ img = img.view(pic.size[1], pic.size[0], nchannel)
71
+
72
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
73
+ if isinstance(img, torch.ByteTensor):
74
+ return img.float()
75
+ else:
76
+ return img
77
+
78
+
79
+ class MarigoldNYU(Dataset):
80
+ def __init__(self, nyu_dir_root, marigold_depth_root, debug_mode=False):
81
+ import glob
82
+ import os
83
+ import itertools
84
+
85
+ categories = os.listdir(nyu_dir_root)
86
+ if debug_mode:
87
+ categories = categories[:2]
88
+
89
+ self.image_files = list(itertools.chain(*[glob.glob(os.path.join(nyu_dir_root, c, "rgb_*.jpg")) for c in categories]))
90
+ self.nyu_depth_files = [os.path.join(nyu_dir_root, os.path.join(*r.split("/")[-2:])).replace("jpg", "png").replace("rgb", "sync_depth") for r in self.image_files]
91
+ self.marigold_depth_files = [os.path.join(marigold_depth_root, os.path.join(*r.split("/")[-2:])).replace("jpg", "npy") for r in self.image_files]
92
+
93
+ self.transform = ToTensor()  # note: not applied in __getitem__, which returns raw numpy arrays
94
+
95
+ def __getitem__(self, idx):
96
+ image_path = self.image_files[idx]
97
+ nyu_depth_path = self.nyu_depth_files[idx]
98
+ marigold_depth_path = self.marigold_depth_files[idx]
99
+
100
+ image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
101
+ nyu_depth = np.asarray(Image.open(nyu_depth_path), dtype=np.float32)
102
+ marigold_depth = np.load(marigold_depth_path)
103
+
104
+ return image, nyu_depth[..., np.newaxis], marigold_depth[..., np.newaxis], image_path, nyu_depth_path
105
+
106
+ def __len__(self):
107
+ return len(self.image_files)
108
+
109
+
110
+ def get_marigold_nyu_loader(nyu_dir_root, marigold_depth_root, batch_size=1, **kwargs):
111
+ dataset = MarigoldNYU(nyu_dir_root, marigold_depth_root)
112
+ return DataLoader(dataset, batch_size, **kwargs)
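
Usage sketch (both roots are placeholders). Unlike the loaders above, __getitem__ returns a raw tuple without applying self.transform, so PyTorch's default collate stacks the channels-last numpy arrays into tensors and gathers the paths into lists:

from zoedepth.data.marigold_nyu import get_marigold_nyu_loader

loader = get_marigold_nyu_loader("data/nyu", "data/marigold_depth", batch_size=2)
image, nyu_depth, marigold_depth, image_paths, nyu_depth_paths = next(iter(loader))
# image:          (2, H, W, 3) float32 in [0, 1], channels last
# nyu_depth:      (2, H, W, 1) raw (unscaled) sync_depth values
# marigold_depth: (2, H, W, 1) Marigold predictions loaded from .npy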
zoedepth/data/places365.py ADDED
@@ -0,0 +1,118 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+
27
+ import numpy as np
28
+ import torch
29
+ from PIL import Image
30
+ from torch.utils.data import DataLoader, Dataset
31
+ from torchvision import transforms
32
+ from random import choice
33
+
34
+
35
+ class ToTensor(object):
36
+ def __init__(self):
37
+ self.normalize = transforms.Normalize(
38
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
39
+ #self.normalize = lambda x : x
40
+
41
+ def __call__(self, sample):
42
+ image, depth = sample['image'], sample['depth']
43
+ image = self.to_tensor(image)
44
+ image = self.normalize(image)
45
+ depth = self.to_tensor(depth)
46
+
47
+ return {'image': image, 'depth': depth, 'dataset': "places365"}
48
+
49
+ def to_tensor(self, pic):
50
+
51
+ if isinstance(pic, np.ndarray):
52
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
53
+ return img
54
+
55
+ # handle PIL Image
56
+ if pic.mode == 'I':
57
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
58
+ elif pic.mode == 'I;16':
59
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
60
+ else:
61
+ img = torch.ByteTensor(
62
+ torch.ByteStorage.from_buffer(pic.tobytes()))
63
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
64
+ if pic.mode == 'YCbCr':
65
+ nchannel = 3
66
+ elif pic.mode == 'I;16':
67
+ nchannel = 1
68
+ else:
69
+ nchannel = len(pic.mode)
70
+ img = img.view(pic.size[1], pic.size[0], nchannel)
71
+
72
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
73
+ if isinstance(img, torch.ByteTensor):
74
+ return img.float()
75
+ else:
76
+ return img
77
+
78
+
79
+ class Places365(Dataset):
80
+ def __init__(self, data_dir_root, depth_dir_root, depth_masks_dir_root, randomize_masks=True, debug_mode=False):
81
+ import glob
82
+ import os
83
+ import itertools
84
+
85
+ categories = os.listdir(data_dir_root)
86
+ if debug_mode:
87
+ categories = categories[:2]
88
+
89
+ self.image_files = list(itertools.chain(*[glob.glob(os.path.join(data_dir_root, c, "*.jpg")) for c in categories]))
90
+ self.depth_files = [os.path.join(depth_dir_root, os.path.join(*r.split("/")[-2:])).replace("jpg", "npy") for r in self.image_files]
91
+ self.depth_masks_files = [os.path.join(depth_masks_dir_root, os.path.join(*r.split("/")[-2:])).replace("jpg", "npy") for r in self.image_files]
92
+
93
+ self.randomize_masks = randomize_masks
94
+
95
+ self.transform = ToTensor()  # note: not applied in __getitem__, which returns raw numpy arrays
96
+
97
+ def __getitem__(self, idx):
98
+ image_path = self.image_files[idx]
99
+ depth_path = self.depth_files[idx]
100
+
101
+ if not self.randomize_masks:
102
+ depth_masks_path = self.depth_masks_files[idx]
103
+ else:
104
+ depth_masks_path = choice(self.depth_masks_files)
105
+
106
+ image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
107
+ depth = np.load(depth_path)
108
+ depth_mask = 1 - np.load(depth_masks_path)
109
+
110
+ return image, depth[..., np.newaxis], depth_mask[..., np.newaxis], image_path, depth_path, depth_masks_path
111
+
112
+ def __len__(self):
113
+ return len(self.image_files)
114
+
115
+
116
+ def get_places365_loader(data_dir_root, depth_dir_root, depth_masks_dir_root, batch_size=1, **kwargs):
117
+ dataset = Places365(data_dir_root, depth_dir_root, depth_masks_dir_root)
118
+ return DataLoader(dataset, batch_size, **kwargs)
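
Usage sketch (roots are placeholders). With randomize_masks=True (the default), each depth map is paired with a mask drawn at random from the whole pool rather than its own, which is why the chosen mask path is returned alongside the sample:

from zoedepth.data.places365 import Places365

ds = Places365("data/places365", "data/places365_depth", "data/places365_masks")
image, depth, depth_mask, image_path, depth_path, mask_path = ds[0]
# depth_mask = 1 - stored mask, shape (H, W, 1)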
zoedepth/data/preprocess.py ADDED
@@ -0,0 +1,154 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import numpy as np
26
+ from dataclasses import dataclass
27
+ from typing import Tuple
28
+
29
+ # dataclass to store the crop parameters
30
+ @dataclass
31
+ class CropParams:
32
+ top: int
33
+ bottom: int
34
+ left: int
35
+ right: int
36
+
37
+
38
+
39
+ def get_border_params(rgb_image, tolerance=0.1, cut_off=20, value=0, level_diff_threshold=5, channel_axis=-1, min_border=5) -> CropParams:
40
+ gray_image = np.mean(rgb_image, axis=channel_axis)
41
+ h, w = gray_image.shape
42
+
43
+
44
+ def num_value_pixels(arr):
45
+ return np.sum(np.abs(arr - value) < level_diff_threshold)
46
+
47
+ def is_above_tolerance(arr, total_pixels):
48
+ return (num_value_pixels(arr) / total_pixels) > tolerance
49
+
50
+ # Crop the top border until the fraction of value pixels falls below the tolerance
51
+ top = min_border
52
+ while is_above_tolerance(gray_image[top, :], w) and top < h-1:
53
+ top += 1
54
+ if top > cut_off:
55
+ break
56
+
57
+ # Crop the bottom border until the fraction of value pixels falls below the tolerance
58
+ bottom = h - min_border
59
+ while is_above_tolerance(gray_image[bottom, :], w) and bottom > 0:
60
+ bottom -= 1
61
+ if h - bottom > cut_off:
62
+ break
63
+
64
+ # Crop the left border until the fraction of value pixels falls below the tolerance
65
+ left = min_border
66
+ while is_above_tolerance(gray_image[:, left], h) and left < w-1:
67
+ left += 1
68
+ if left > cut_off:
69
+ break
70
+
71
+ # Crop the right border until the fraction of value pixels falls below the tolerance
72
+ right = w - min_border
73
+ while is_above_tolerance(gray_image[:, right], h) and right > 0:
74
+ right -= 1
75
+ if w - right > cut_off:
76
+ break
77
+
78
+
79
+ return CropParams(top, bottom, left, right)
80
+
81
+
82
+ def get_white_border(rgb_image, value=255, **kwargs) -> CropParams:
83
+ """Crops the white border of the RGB image.
84
+
85
+ Args:
86
+ rgb_image: RGB image, shape (H, W, 3).
87
+ Returns:
88
+ Crop parameters.
89
+ """
90
+ if value == 255:
91
+ # assert range of values in rgb image is [0, 255]
92
+ assert np.max(rgb_image) <= 255 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 255]."
93
+ assert rgb_image.max() > 1, "RGB image values are not in range [0, 255]."
94
+ elif value == 1:
95
+ # assert range of values in rgb image is [0, 1]
96
+ assert np.max(rgb_image) <= 1 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 1]."
97
+
98
+ return get_border_params(rgb_image, value=value, **kwargs)
99
+
100
+ def get_black_border(rgb_image, **kwargs) -> CropParams:
101
+ """Crops the black border of the RGB image.
102
+
103
+ Args:
104
+ rgb_image: RGB image, shape (H, W, 3).
105
+
106
+ Returns:
107
+ Crop parameters.
108
+ """
109
+
110
+ return get_border_params(rgb_image, value=0, **kwargs)
111
+
112
+ def crop_image(image: np.ndarray, crop_params: CropParams) -> np.ndarray:
113
+ """Crops the image according to the crop parameters.
114
+
115
+ Args:
116
+ image: RGB or depth image, shape (H, W, 3) or (H, W).
117
+ crop_params: Crop parameters.
118
+
119
+ Returns:
120
+ Cropped image.
121
+ """
122
+ return image[crop_params.top:crop_params.bottom, crop_params.left:crop_params.right]
123
+
124
+ def crop_images(*images: np.ndarray, crop_params: CropParams) -> Tuple[np.ndarray]:
125
+ """Crops the images according to the crop parameters.
126
+
127
+ Args:
128
+ images: RGB or depth images, shape (H, W, 3) or (H, W).
129
+ crop_params: Crop parameters.
130
+
131
+ Returns:
132
+ Cropped images.
133
+ """
134
+ return tuple(crop_image(image, crop_params) for image in images)
135
+
136
+ def crop_black_or_white_border(rgb_image, *other_images: np.ndarray, tolerance=0.1, cut_off=20, level_diff_threshold=5) -> Tuple[np.ndarray]:
137
+ """Crops the white and black border of the RGB and depth images.
138
+
139
+ Args:
140
+ rgb_image: RGB image, shape (H, W, 3). This image is used to determine the border.
141
+ other_images: The other images to crop according to the border of the RGB image.
142
+ Returns:
143
+ Cropped RGB and other images.
144
+ """
145
+ # crop black border
146
+ crop_params = get_black_border(rgb_image, tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold)
147
+ cropped_images = crop_images(rgb_image, *other_images, crop_params=crop_params)
148
+
149
+ # crop white border
150
+ crop_params = get_white_border(cropped_images[0], tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold)
151
+ cropped_images = crop_images(*cropped_images, crop_params=crop_params)
152
+
153
+ return cropped_images
154
+
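
A worked sketch of the border cropping above on a synthetic image with a thin black frame. Each pass (black, then white) trims at least min_border (5 px) per side even when no border is detected, so a 100x100 input comes back as 80x80:

import numpy as np
from zoedepth.data.preprocess import crop_black_or_white_border

rgb = np.zeros((100, 100, 3), dtype=np.float32)
rgb[4:96, 4:96] = 128.0                        # 4 px black frame around grey content
depth = np.random.rand(100, 100).astype(np.float32)

cropped_rgb, cropped_depth = crop_black_or_white_border(rgb, depth)
print(cropped_rgb.shape, cropped_depth.shape)  # (80, 80, 3) (80, 80)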
zoedepth/data/sun_rgbd_loader.py ADDED
@@ -0,0 +1,106 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+
27
+ import numpy as np
28
+ import torch
29
+ from PIL import Image
30
+ from torch.utils.data import DataLoader, Dataset
31
+ from torchvision import transforms
32
+
33
+
34
+ class ToTensor(object):
35
+ def __init__(self):
36
+ # self.normalize = transforms.Normalize(
37
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
38
+ self.normalize = lambda x: x
39
+
40
+ def __call__(self, sample):
41
+ image, depth = sample['image'], sample['depth']
42
+ image = self.to_tensor(image)
43
+ image = self.normalize(image)
44
+ depth = self.to_tensor(depth)
45
+
46
+ return {'image': image, 'depth': depth, 'dataset': "sunrgbd"}
47
+
48
+ def to_tensor(self, pic):
49
+
50
+ if isinstance(pic, np.ndarray):
51
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
52
+ return img
53
+
54
+ # handle PIL Image
55
+ if pic.mode == 'I':
56
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
57
+ elif pic.mode == 'I;16':
58
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
59
+ else:
60
+ img = torch.ByteTensor(
61
+ torch.ByteStorage.from_buffer(pic.tobytes()))
62
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
63
+ if pic.mode == 'YCbCr':
64
+ nchannel = 3
65
+ elif pic.mode == 'I;16':
66
+ nchannel = 1
67
+ else:
68
+ nchannel = len(pic.mode)
69
+ img = img.view(pic.size[1], pic.size[0], nchannel)
70
+
71
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
72
+ if isinstance(img, torch.ByteTensor):
73
+ return img.float()
74
+ else:
75
+ return img
76
+
77
+
78
+ class SunRGBD(Dataset):
79
+ def __init__(self, data_dir_root):
80
+ # test_file_dirs = loadmat(train_test_file)['alltest'].squeeze()
81
+ # all_test = [t[0].replace("/n/fs/sun3d/data/", "") for t in test_file_dirs]
82
+ # self.all_test = [os.path.join(data_dir_root, t) for t in all_test]
83
+ import glob
84
+ self.image_files = glob.glob(
85
+ os.path.join(data_dir_root, 'rgb', 'rgb', '*'))
86
+ self.depth_files = [
87
+ r.replace("rgb/rgb", "gt/gt").replace("jpg", "png") for r in self.image_files]
88
+ self.transform = ToTensor()
89
+
90
+ def __getitem__(self, idx):
91
+ image_path = self.image_files[idx]
92
+ depth_path = self.depth_files[idx]
93
+
94
+ image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
95
+ depth = np.asarray(Image.open(depth_path), dtype='uint16') / 1000.0
96
+ depth[depth > 8] = -1
97
+ depth = depth[..., None]
98
+ return self.transform(dict(image=image, depth=depth))
99
+
100
+ def __len__(self):
101
+ return len(self.image_files)
102
+
103
+
104
+ def get_sunrgbd_loader(data_dir_root, batch_size=1, **kwargs):
105
+ dataset = SunRGBD(data_dir_root)
106
+ return DataLoader(dataset, batch_size, **kwargs)
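
The depth > 8 clamp above follows the sentinel convention shared by these evaluation loaders: invalid or out-of-range depths are set to -1, and metric code masks on positive depth. A minimal sketch:

import torch

depth = torch.tensor([[0.5, 3.2, -1.0, 7.9]])  # metres; -1 marks invalid pixels
valid = depth > 0
mean_depth = depth[valid].mean()               # 3.8667, sentinel excluded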
zoedepth/data/transforms.py ADDED
@@ -0,0 +1,481 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import math
26
+ import random
27
+
28
+ import cv2
29
+ import numpy as np
30
+
31
+
32
+ class RandomFliplr(object):
33
+ """Horizontal flip of the sample with given probability.
34
+ """
35
+
36
+ def __init__(self, probability=0.5):
37
+ """Init.
38
+
39
+ Args:
40
+ probability (float, optional): Flip probability. Defaults to 0.5.
41
+ """
42
+ self.__probability = probability
43
+
44
+ def __call__(self, sample):
45
+ prob = random.random()
46
+
47
+ if prob < self.__probability:
48
+ for k, v in sample.items():
49
+ if len(v.shape) >= 2:
50
+ sample[k] = np.fliplr(v).copy()
51
+
52
+ return sample
53
+
54
+
55
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
56
+ """Resize the sample to ensure the given size. Keeps aspect ratio.
57
+
58
+ Args:
59
+ sample (dict): sample
60
+ size (tuple): image size
61
+
62
+ Returns:
63
+ tuple: new size (if the sample is already large enough, the sample itself is returned unchanged)
64
+ """
65
+ shape = list(sample["disparity"].shape)
66
+
67
+ if shape[0] >= size[0] and shape[1] >= size[1]:
68
+ return sample
69
+
70
+ scale = [0, 0]
71
+ scale[0] = size[0] / shape[0]
72
+ scale[1] = size[1] / shape[1]
73
+
74
+ scale = max(scale)
75
+
76
+ shape[0] = math.ceil(scale * shape[0])
77
+ shape[1] = math.ceil(scale * shape[1])
78
+
79
+ # resize
80
+ sample["image"] = cv2.resize(
81
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
82
+ )
83
+
84
+ sample["disparity"] = cv2.resize(
85
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
86
+ )
87
+ sample["mask"] = cv2.resize(
88
+ sample["mask"].astype(np.float32),
89
+ tuple(shape[::-1]),
90
+ interpolation=cv2.INTER_NEAREST,
91
+ )
92
+ sample["mask"] = sample["mask"].astype(bool)
93
+
94
+ return tuple(shape)
95
+
96
+
97
+ class RandomCrop(object):
98
+ """Get a random crop of the sample with the given size (width, height).
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ width,
104
+ height,
105
+ resize_if_needed=False,
106
+ image_interpolation_method=cv2.INTER_AREA,
107
+ ):
108
+ """Init.
109
+
110
+ Args:
111
+ width (int): output width
112
+ height (int): output height
113
+ resize_if_needed (bool, optional): If True, sample might be upsampled to ensure
114
+ that a crop of size (width, height) is possible. Defaults to False.
115
+ """
116
+ self.__size = (height, width)
117
+ self.__resize_if_needed = resize_if_needed
118
+ self.__image_interpolation_method = image_interpolation_method
119
+
120
+ def __call__(self, sample):
121
+
122
+ shape = sample["disparity"].shape
123
+
124
+ if self.__size[0] > shape[0] or self.__size[1] > shape[1]:
125
+ if self.__resize_if_needed:
126
+ shape = apply_min_size(
127
+ sample, self.__size, self.__image_interpolation_method
128
+ )
129
+ else:
130
+ raise Exception(
131
+ "Output size {} bigger than input size {}.".format(
132
+ self.__size, shape
133
+ )
134
+ )
135
+
136
+ offset = (
137
+ np.random.randint(shape[0] - self.__size[0] + 1),
138
+ np.random.randint(shape[1] - self.__size[1] + 1),
139
+ )
140
+
141
+ for k, v in sample.items():
142
+ if k == "code" or k == "basis":
143
+ continue
144
+
145
+ if len(sample[k].shape) >= 2:
146
+ sample[k] = v[
147
+ offset[0]: offset[0] + self.__size[0],
148
+ offset[1]: offset[1] + self.__size[1],
149
+ ]
150
+
151
+ return sample
152
+
153
+
154
+ class Resize(object):
155
+ """Resize sample to given size (width, height).
156
+ """
157
+
158
+ def __init__(
159
+ self,
160
+ width,
161
+ height,
162
+ resize_target=True,
163
+ keep_aspect_ratio=False,
164
+ ensure_multiple_of=1,
165
+ resize_method="lower_bound",
166
+ image_interpolation_method=cv2.INTER_AREA,
167
+ letter_box=False,
168
+ ):
169
+ """Init.
170
+
171
+ Args:
172
+ width (int): desired output width
173
+ height (int): desired output height
174
+ resize_target (bool, optional):
175
+ True: Resize the full sample (image, mask, target).
176
+ False: Resize image only.
177
+ Defaults to True.
178
+ keep_aspect_ratio (bool, optional):
179
+ True: Keep the aspect ratio of the input sample.
180
+ Output sample might not have the given width and height, and
181
+ resize behaviour depends on the parameter 'resize_method'.
182
+ Defaults to False.
183
+ ensure_multiple_of (int, optional):
184
+ Output width and height is constrained to be multiple of this parameter.
185
+ Defaults to 1.
186
+ resize_method (str, optional):
187
+ "lower_bound": Output will be at least as large as the given size.
188
+ "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
189
+ "minimal": Scale as little as possible. (Output size might be smaller than given size.)
190
+ Defaults to "lower_bound".
191
+ """
192
+ self.__width = width
193
+ self.__height = height
194
+
195
+ self.__resize_target = resize_target
196
+ self.__keep_aspect_ratio = keep_aspect_ratio
197
+ self.__multiple_of = ensure_multiple_of
198
+ self.__resize_method = resize_method
199
+ self.__image_interpolation_method = image_interpolation_method
200
+ self.__letter_box = letter_box
201
+
202
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
203
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
204
+
205
+ if max_val is not None and y > max_val:
206
+ y = (np.floor(x / self.__multiple_of)
207
+ * self.__multiple_of).astype(int)
208
+
209
+ if y < min_val:
210
+ y = (np.ceil(x / self.__multiple_of)
211
+ * self.__multiple_of).astype(int)
212
+
213
+ return y
214
+
215
+ def get_size(self, width, height):
216
+ # determine new height and width
217
+ scale_height = self.__height / height
218
+ scale_width = self.__width / width
219
+
220
+ if self.__keep_aspect_ratio:
221
+ if self.__resize_method == "lower_bound":
222
+ # scale such that output size is lower bound
223
+ if scale_width > scale_height:
224
+ # fit width
225
+ scale_height = scale_width
226
+ else:
227
+ # fit height
228
+ scale_width = scale_height
229
+ elif self.__resize_method == "upper_bound":
230
+ # scale such that output size is upper bound
231
+ if scale_width < scale_height:
232
+ # fit width
233
+ scale_height = scale_width
234
+ else:
235
+ # fit height
236
+ scale_width = scale_height
237
+ elif self.__resize_method == "minimal":
238
+ # scale as little as possible
239
+ if abs(1 - scale_width) < abs(1 - scale_height):
240
+ # fit width
241
+ scale_height = scale_width
242
+ else:
243
+ # fit height
244
+ scale_width = scale_height
245
+ else:
246
+ raise ValueError(
247
+ f"resize_method {self.__resize_method} not implemented"
248
+ )
249
+
250
+ if self.__resize_method == "lower_bound":
251
+ new_height = self.constrain_to_multiple_of(
252
+ scale_height * height, min_val=self.__height
253
+ )
254
+ new_width = self.constrain_to_multiple_of(
255
+ scale_width * width, min_val=self.__width
256
+ )
257
+ elif self.__resize_method == "upper_bound":
258
+ new_height = self.constrain_to_multiple_of(
259
+ scale_height * height, max_val=self.__height
260
+ )
261
+ new_width = self.constrain_to_multiple_of(
262
+ scale_width * width, max_val=self.__width
263
+ )
264
+ elif self.__resize_method == "minimal":
265
+ new_height = self.constrain_to_multiple_of(scale_height * height)
266
+ new_width = self.constrain_to_multiple_of(scale_width * width)
267
+ else:
268
+ raise ValueError(
269
+ f"resize_method {self.__resize_method} not implemented")
270
+
271
+ return (new_width, new_height)
272
+
273
+ def make_letter_box(self, sample):
274
+ top = bottom = (self.__height - sample.shape[0]) // 2
275
+ left = right = (self.__width - sample.shape[1]) // 2
276
+ sample = cv2.copyMakeBorder(
277
+ sample, top, bottom, left, right, cv2.BORDER_CONSTANT, None, 0)
278
+ return sample
279
+
280
+ def __call__(self, sample):
281
+ width, height = self.get_size(
282
+ sample["image"].shape[1], sample["image"].shape[0]
283
+ )
284
+
285
+ # resize sample
286
+ sample["image"] = cv2.resize(
287
+ sample["image"],
288
+ (width, height),
289
+ interpolation=self.__image_interpolation_method,
290
+ )
291
+
292
+ if self.__letter_box:
293
+ sample["image"] = self.make_letter_box(sample["image"])
294
+
295
+ if self.__resize_target:
296
+ if "disparity" in sample:
297
+ sample["disparity"] = cv2.resize(
298
+ sample["disparity"],
299
+ (width, height),
300
+ interpolation=cv2.INTER_NEAREST,
301
+ )
302
+
303
+ if self.__letter_box:
304
+ sample["disparity"] = self.make_letter_box(
305
+ sample["disparity"])
306
+
307
+ if "depth" in sample:
308
+ sample["depth"] = cv2.resize(
309
+ sample["depth"], (width,
310
+ height), interpolation=cv2.INTER_NEAREST
311
+ )
312
+
313
+ if self.__letter_box:
314
+ sample["depth"] = self.make_letter_box(sample["depth"])
315
+
316
+ sample["mask"] = cv2.resize(
317
+ sample["mask"].astype(np.float32),
318
+ (width, height),
319
+ interpolation=cv2.INTER_NEAREST,
320
+ )
321
+
322
+ if self.__letter_box:
323
+ sample["mask"] = self.make_letter_box(sample["mask"])
324
+
325
+ sample["mask"] = sample["mask"].astype(bool)
326
+
327
+ return sample
328
+
329
+
330
+ class ResizeFixed(object):
331
+ def __init__(self, size):
332
+ self.__size = size
333
+
334
+ def __call__(self, sample):
335
+ sample["image"] = cv2.resize(
336
+ sample["image"], self.__size[::-1], interpolation=cv2.INTER_LINEAR
337
+ )
338
+
339
+ sample["disparity"] = cv2.resize(
340
+ sample["disparity"], self.__size[::-
341
+ 1], interpolation=cv2.INTER_NEAREST
342
+ )
343
+
344
+ sample["mask"] = cv2.resize(
345
+ sample["mask"].astype(np.float32),
346
+ self.__size[::-1],
347
+ interpolation=cv2.INTER_NEAREST,
348
+ )
349
+ sample["mask"] = sample["mask"].astype(bool)
350
+
351
+ return sample
352
+
353
+
354
+ class Rescale(object):
355
+ """Rescale target values to the interval [0, max_val].
356
+ If input is constant, values are set to max_val / 2.
357
+ """
358
+
359
+ def __init__(self, max_val=1.0, use_mask=True):
360
+ """Init.
361
+
362
+ Args:
363
+ max_val (float, optional): Max output value. Defaults to 1.0.
364
+ use_mask (bool, optional): Only operate on valid pixels (mask == True). Defaults to True.
365
+ """
366
+ self.__max_val = max_val
367
+ self.__use_mask = use_mask
368
+
369
+ def __call__(self, sample):
370
+ disp = sample["disparity"]
371
+
372
+ if self.__use_mask:
373
+ mask = sample["mask"]
374
+ else:
375
+ mask = np.ones_like(disp, dtype=bool)  # np.bool was removed in NumPy 1.24
376
+
377
+ if np.sum(mask) == 0:
378
+ return sample
379
+
380
+ min_val = np.min(disp[mask])
381
+ max_val = np.max(disp[mask])
382
+
383
+ if max_val > min_val:
384
+ sample["disparity"][mask] = (
385
+ (disp[mask] - min_val) / (max_val - min_val) * self.__max_val
386
+ )
387
+ else:
388
+ sample["disparity"][mask] = np.ones_like(
389
+ disp[mask]) * self.__max_val / 2.0
390
+
391
+ return sample
392
+
393
+
394
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
395
+ class NormalizeImage(object):
396
+ """Normalize image by the given mean and std.
397
+ """
398
+
399
+ def __init__(self, mean, std):
400
+ self.__mean = mean
401
+ self.__std = std
402
+
403
+ def __call__(self, sample):
404
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
405
+
406
+ return sample
407
+
408
+
409
+ class DepthToDisparity(object):
410
+ """Convert depth to disparity. Removes depth from sample.
411
+ """
412
+
413
+ def __init__(self, eps=1e-4):
414
+ self.__eps = eps
415
+
416
+ def __call__(self, sample):
417
+ assert "depth" in sample
418
+
419
+ sample["mask"][sample["depth"] < self.__eps] = False
420
+
421
+ sample["disparity"] = np.zeros_like(sample["depth"])
422
+ sample["disparity"][sample["depth"] >= self.__eps] = (
423
+ 1.0 / sample["depth"][sample["depth"] >= self.__eps]
424
+ )
425
+
426
+ del sample["depth"]
427
+
428
+ return sample
429
+
430
+
431
+ class DisparityToDepth(object):
432
+ """Convert disparity to depth. Removes disparity from sample.
433
+ """
434
+
435
+ def __init__(self, eps=1e-4):
436
+ self.__eps = eps
437
+
438
+ def __call__(self, sample):
439
+ assert "disparity" in sample
440
+
441
+ disp = np.abs(sample["disparity"])
442
+ sample["mask"][disp < self.__eps] = False
443
+
444
+ # print(sample["disparity"])
445
+ # print(sample["mask"].sum())
446
+ # exit()
447
+
448
+ sample["depth"] = np.zeros_like(disp)
449
+ sample["depth"][disp >= self.__eps] = (
450
+ 1.0 / disp[disp >= self.__eps]
451
+ )
452
+
453
+ del sample["disparity"]
454
+
455
+ return sample
456
+
457
+
458
+ class PrepareForNet(object):
459
+ """Prepare sample for usage as network input.
460
+ """
461
+
462
+ def __init__(self):
463
+ pass
464
+
465
+ def __call__(self, sample):
466
+ image = np.transpose(sample["image"], (2, 0, 1))
467
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
468
+
469
+ if "mask" in sample:
470
+ sample["mask"] = sample["mask"].astype(np.float32)
471
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
472
+
473
+ if "disparity" in sample:
474
+ disparity = sample["disparity"].astype(np.float32)
475
+ sample["disparity"] = np.ascontiguousarray(disparity)
476
+
477
+ if "depth" in sample:
478
+ depth = sample["depth"].astype(np.float32)
479
+ sample["depth"] = np.ascontiguousarray(depth)
480
+
481
+ return sample
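
A minimal composition sketch for the transforms above in the usual MiDaS-style order (the 384/32 settings are illustrative, not mandated by this file):

import numpy as np
from torchvision.transforms import Compose
from zoedepth.data.transforms import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(384, 384, resize_target=True, keep_aspect_ratio=True,
           ensure_multiple_of=32, resize_method="lower_bound"),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

sample = {
    "image": np.random.rand(480, 640, 3).astype(np.float32),
    "depth": np.random.rand(480, 640).astype(np.float32),
    "mask": np.ones((480, 640), dtype=bool),
}
out = transform(sample)
print(out["image"].shape)  # (3, 384, 512): both sides >= 384 and multiples of 32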
zoedepth/data/vkitti.py ADDED
@@ -0,0 +1,151 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import torch
26
+ from torch.utils.data import Dataset, DataLoader
27
+ from torchvision import transforms
28
+ import os
29
+
30
+ from PIL import Image
31
+ import numpy as np
32
+ import cv2
33
+
34
+
35
+ class ToTensor(object):
36
+ def __init__(self):
37
+ self.normalize = transforms.Normalize(
38
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
39
+ # self.resize = transforms.Resize((375, 1242))
40
+
41
+ def __call__(self, sample):
42
+ image, depth = sample['image'], sample['depth']
43
+
44
+ image = self.to_tensor(image)
45
+ image = self.normalize(image)
46
+ depth = self.to_tensor(depth)
47
+
48
+ # image = self.resize(image)
49
+
50
+ return {'image': image, 'depth': depth, 'dataset': "vkitti"}
51
+
52
+ def to_tensor(self, pic):
53
+
54
+ if isinstance(pic, np.ndarray):
55
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
56
+ return img
57
+
58
+ # handle PIL Image
59
+ if pic.mode == 'I':
60
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
61
+ elif pic.mode == 'I;16':
62
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
63
+ else:
64
+ img = torch.ByteTensor(
65
+ torch.ByteStorage.from_buffer(pic.tobytes()))
66
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
67
+ if pic.mode == 'YCbCr':
68
+ nchannel = 3
69
+ elif pic.mode == 'I;16':
70
+ nchannel = 1
71
+ else:
72
+ nchannel = len(pic.mode)
73
+ img = img.view(pic.size[1], pic.size[0], nchannel)
74
+
75
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
76
+ if isinstance(img, torch.ByteTensor):
77
+ return img.float()
78
+ else:
79
+ return img
80
+
81
+
82
+ class VKITTI(Dataset):
83
+ def __init__(self, data_dir_root, do_kb_crop=True):
84
+ import glob
85
+ # image and depth paths are of the form <data_dir_root>/{test_color, test_depth}/*.png
86
+ self.image_files = glob.glob(os.path.join(
87
+ data_dir_root, "test_color", '*.png'))
88
+ self.depth_files = [r.replace("test_color", "test_depth")
89
+ for r in self.image_files]
90
+ self.do_kb_crop = do_kb_crop
91
+ self.transform = ToTensor()
92
+
93
+ def __getitem__(self, idx):
94
+ image_path = self.image_files[idx]
95
+ depth_path = self.depth_files[idx]
96
+
97
+ image = Image.open(image_path)
98
+ # depth = Image.open(depth_path)  # superseded by the cv2 read below
99
+ depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR |
100
+ cv2.IMREAD_ANYDEPTH)
101
+ print("depth min max", depth.min(), depth.max())
102
+
103
+ # print(np.shape(image))
104
+ # print(np.shape(depth))
105
+
106
+ # depth[depth > 8] = -1
107
+
108
+ if self.do_kb_crop and False:  # crop disabled: depth is a numpy array here, so the PIL crops below would fail
109
+ height = image.height
110
+ width = image.width
111
+ top_margin = int(height - 352)
112
+ left_margin = int((width - 1216) / 2)
113
+ depth = depth.crop(
114
+ (left_margin, top_margin, left_margin + 1216, top_margin + 352))
115
+ image = image.crop(
116
+ (left_margin, top_margin, left_margin + 1216, top_margin + 352))
117
+ # uv = uv[:, top_margin:top_margin + 352, left_margin:left_margin + 1216]
118
+
119
+ image = np.asarray(image, dtype=np.float32) / 255.0
120
+ # depth = np.asarray(depth, dtype=np.uint16) /1.
121
+ depth = depth[..., None]
122
+ sample = dict(image=image, depth=depth)
123
+
124
+ # return sample
125
+ sample = self.transform(sample)
126
+
127
+ if idx == 0:
128
+ print(sample["image"].shape)
129
+
130
+ return sample
131
+
132
+ def __len__(self):
133
+ return len(self.image_files)
134
+
135
+
136
+ def get_vkitti_loader(data_dir_root, batch_size=1, **kwargs):
137
+ dataset = VKITTI(data_dir_root)
138
+ return DataLoader(dataset, batch_size, **kwargs)
139
+
140
+
141
+ if __name__ == "__main__":
142
+ loader = get_vkitti_loader(
143
+ data_dir_root="/home/bhatsf/shortcuts/datasets/vkitti_test")
144
+ print("Total files", len(loader.dataset))
145
+ for i, sample in enumerate(loader):
146
+ print(sample["image"].shape)
147
+ print(sample["depth"].shape)
148
+ print(sample["dataset"])
149
+ print(sample['depth'].min(), sample['depth'].max())
150
+ if i > 5:
151
+ break
zoedepth/data/vkitti2.py ADDED
@@ -0,0 +1,187 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+
27
+ import cv2
28
+ import numpy as np
29
+ import torch
30
+ from PIL import Image
31
+ from torch.utils.data import DataLoader, Dataset
32
+ from torchvision import transforms
33
+
34
+
35
+ class ToTensor(object):
36
+ def __init__(self):
37
+ # self.normalize = transforms.Normalize(
38
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
39
+ self.normalize = lambda x: x
40
+ # self.resize = transforms.Resize((375, 1242))
41
+
42
+ def __call__(self, sample):
43
+ image, depth = sample['image'], sample['depth']
44
+
45
+ image = self.to_tensor(image)
46
+ image = self.normalize(image)
47
+ depth = self.to_tensor(depth)
48
+
49
+ # image = self.resize(image)
50
+
51
+ return {'image': image, 'depth': depth, 'dataset': "vkitti"}
52
+
53
+ def to_tensor(self, pic):
54
+
55
+ if isinstance(pic, np.ndarray):
56
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
57
+ return img
58
+
59
+ # handle PIL Image
60
+ if pic.mode == 'I':
61
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
62
+ elif pic.mode == 'I;16':
63
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
64
+ else:
65
+ img = torch.ByteTensor(
66
+ torch.ByteStorage.from_buffer(pic.tobytes()))
67
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
68
+ if pic.mode == 'YCbCr':
69
+ nchannel = 3
70
+ elif pic.mode == 'I;16':
71
+ nchannel = 1
72
+ else:
73
+ nchannel = len(pic.mode)
74
+ img = img.view(pic.size[1], pic.size[0], nchannel)
75
+
76
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
77
+ if isinstance(img, torch.ByteTensor):
78
+ return img.float()
79
+ else:
80
+ return img
81
+
82
+
83
+ class VKITTI2(Dataset):
84
+ def __init__(self, data_dir_root, do_kb_crop=True, split="test"):
85
+ import glob
86
+
87
+ # image paths are of the form <data_dir_root>/rgb/<scene>/<variant>/frames/<rgb,depth>/Camera<0,1>/rgb_{}.jpg
88
+ self.image_files = glob.glob(os.path.join(
89
+ data_dir_root, "rgb", "**", "frames", "rgb", "Camera_0", '*.jpg'), recursive=True)
90
+ self.depth_files = [r.replace("/rgb/", "/depth/").replace(
91
+ "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files]
92
+ self.do_kb_crop = do_kb_crop
93
+ self.transform = ToTensor()
94
+
95
+ # If train test split is not created, then create one.
96
+ # Split is such that 8% of the frames from each scene are used for testing.
97
+ if not os.path.exists(os.path.join(data_dir_root, "train.txt")):
98
+ import random
99
+ scenes = set([os.path.basename(os.path.dirname(
100
+ os.path.dirname(os.path.dirname(f)))) for f in self.image_files])
101
+ train_files = []
102
+ test_files = []
103
+ for scene in scenes:
104
+ scene_files = [f for f in self.image_files if os.path.basename(
105
+ os.path.dirname(os.path.dirname(os.path.dirname(f)))) == scene]
106
+ random.shuffle(scene_files)
107
+ train_files.extend(scene_files[:int(len(scene_files) * 0.92)])
108
+ test_files.extend(scene_files[int(len(scene_files) * 0.92):])
109
+ with open(os.path.join(data_dir_root, "train.txt"), "w") as f:
110
+ f.write("\n".join(train_files))
111
+ with open(os.path.join(data_dir_root, "test.txt"), "w") as f:
112
+ f.write("\n".join(test_files))
113
+
114
+ if split == "train":
115
+ with open(os.path.join(data_dir_root, "train.txt"), "r") as f:
116
+ self.image_files = f.read().splitlines()
117
+ self.depth_files = [r.replace("/rgb/", "/depth/").replace(
118
+ "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files]
119
+ elif split == "test":
120
+ with open(os.path.join(data_dir_root, "test.txt"), "r") as f:
121
+ self.image_files = f.read().splitlines()
122
+ self.depth_files = [r.replace("/rgb/", "/depth/").replace(
123
+ "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files]
124
+
125
+ def __getitem__(self, idx):
126
+ image_path = self.image_files[idx]
127
+ depth_path = self.depth_files[idx]
128
+
129
+ image = Image.open(image_path)
130
+ # depth = Image.open(depth_path)
131
+ depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR |
132
+ cv2.IMREAD_ANYDEPTH) / 100.0 # cm to m
133
+ depth = Image.fromarray(depth)
134
+ # print("depth min max", depth.min(), depth.max())
135
+
136
+ # print(np.shape(image))
137
+ # print(np.shape(depth))
138
+
139
+ if self.do_kb_crop:
140
+ if idx == 0:
141
+ print("Using KB input crop")
142
+ height = image.height
143
+ width = image.width
144
+ top_margin = int(height - 352)
145
+ left_margin = int((width - 1216) / 2)
146
+ depth = depth.crop(
147
+ (left_margin, top_margin, left_margin + 1216, top_margin + 352))
148
+ image = image.crop(
149
+ (left_margin, top_margin, left_margin + 1216, top_margin + 352))
150
+ # uv = uv[:, top_margin:top_margin + 352, left_margin:left_margin + 1216]
151
+
152
+ image = np.asarray(image, dtype=np.float32) / 255.0
153
+ # depth = np.asarray(depth, dtype=np.uint16) /1.
154
+ depth = np.asarray(depth, dtype=np.float32) / 1.
155
+ depth[depth > 80] = -1
156
+
157
+ depth = depth[..., None]
158
+ sample = dict(image=image, depth=depth)
159
+
160
+ # return sample
161
+ sample = self.transform(sample)
162
+
163
+ if idx == 0:
164
+ print(sample["image"].shape)
165
+
166
+ return sample
167
+
168
+ def __len__(self):
169
+ return len(self.image_files)
170
+
171
+
172
+ def get_vkitti2_loader(data_dir_root, batch_size=1, **kwargs):
173
+ dataset = VKITTI2(data_dir_root)
174
+ return DataLoader(dataset, batch_size, **kwargs)
175
+
176
+
177
+ if __name__ == "__main__":
178
+ loader = get_vkitti2_loader(
179
+ data_dir_root="/home/bhatsf/shortcuts/datasets/vkitti2")
180
+ print("Total files", len(loader.dataset))
181
+ for i, sample in enumerate(loader):
182
+ print(sample["image"].shape)
183
+ print(sample["depth"].shape)
184
+ print(sample["dataset"])
185
+ print(sample['depth'].min(), sample['depth'].max())
186
+ if i > 5:
187
+ break
zoedepth/models/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+