CharlieAmalet committed
Commit 5baac47 · verified · 1 Parent(s): ea9a6b2

Create marigold_depth_estimation.py

Files changed (1)
  1. marigold_depth_estimation.py +619 -0

marigold_depth_estimation.py ADDED
@@ -0,0 +1,619 @@
# Copyright 2024 Bingxin Ke, ETH Zurich and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------


import math
from typing import Dict, Union

import matplotlib
import numpy as np
import torch
from PIL import Image
from scipy.optimize import minimize
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.utils import BaseOutput, check_min_version


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")


class MarigoldDepthOutput(BaseOutput):
    """
    Output class for Marigold monocular depth prediction pipeline.
    Args:
        depth_np (`np.ndarray`):
            Predicted depth map, with depth values in the range of [0, 1].
        depth_colored (`None` or `PIL.Image.Image`):
            Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
        uncertainty (`None` or `np.ndarray`):
            Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
    """

    depth_np: np.ndarray
    depth_colored: Union[None, Image.Image]
    uncertainty: Union[None, np.ndarray]


class MarigoldPipeline(DiffusionPipeline):
    """
    Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the depth latent, conditioned on image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
            to and from latent representations.
        scheduler (`DDIMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
    """

    rgb_latent_scale_factor = 0.18215
    depth_latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: DDIMScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )

        self.empty_text_embed = None

    @torch.no_grad()
    def __call__(
        self,
        input_image: Image.Image,
        denoising_steps: int = 10,
        ensemble_size: int = 10,
        processing_res: int = 768,
        match_input_res: bool = True,
        batch_size: int = 0,
        color_map: str = "Spectral",
        show_progress_bar: bool = True,
        ensemble_kwargs: Dict = None,
    ) -> MarigoldDepthOutput:
        """
        Function invoked when calling the pipeline.
        Args:
            input_image (`Image`):
                Input RGB (or gray-scale) image.
            processing_res (`int`, *optional*, defaults to `768`):
                Maximum resolution of processing.
                If set to 0: will not resize at all.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize depth prediction to match input resolution.
                Only valid if `processing_res` is not 0.
            denoising_steps (`int`, *optional*, defaults to `10`):
                Number of diffusion denoising steps (DDIM) during inference.
            ensemble_size (`int`, *optional*, defaults to `10`):
                Number of predictions to be ensembled.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `ensemble_size`.
                If set to 0, the script will automatically decide the proper batch size.
            show_progress_bar (`bool`, *optional*, defaults to `True`):
                Display a progress bar of diffusion denoising.
            color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
                Colormap used to colorize the depth map.
            ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                Arguments for detailed ensembling settings.
        Returns:
            `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
            - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
            - **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and
              values in [0, 1]. None if `color_map` is `None`
            - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
              coming from ensembling. None if `ensemble_size = 1`
        """

        device = self.device
        input_size = input_image.size

        if not match_input_res:
            assert (
                processing_res is not None
            ), "Value error: `processing_res` must be set when `match_input_res` is disabled."
        assert processing_res >= 0
        assert denoising_steps >= 1
        assert ensemble_size >= 1

        # ----------------- Image Preprocess -----------------
        # Resize image
        if processing_res > 0:
            input_image = self.resize_max_res(
                input_image, max_edge_resolution=processing_res
            )
        # Convert the image to RGB, to 1. remove the alpha channel, 2. convert B&W to 3-channel
        input_image = input_image.convert("RGB")
        image = np.asarray(input_image)

        # Normalize rgb values
        rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]
        rgb_norm = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
        rgb_norm = rgb_norm.to(device)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        # ----------------- Predicting depth -----------------
        # Batch repeated input image
        duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size > 0:
            _bs = batch_size
        else:
            _bs = self._find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )

        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # Predict depth maps (batched)
        depth_pred_ls = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader
        for batch in iterable:
            (batched_img,) = batch
            depth_pred_raw = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                show_pbar=show_progress_bar,
            )
            depth_pred_ls.append(depth_pred_raw.detach().clone())
        depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze()
        torch.cuda.empty_cache()  # clear vram cache for ensembling

        # ----------------- Test-time ensembling -----------------
        if ensemble_size > 1:
            depth_pred, pred_uncert = self.ensemble_depths(
                depth_preds, **(ensemble_kwargs or {})
            )
        else:
            depth_pred = depth_preds
            pred_uncert = None

        # ----------------- Post processing -----------------
        # Scale prediction to [0, 1]
        min_d = torch.min(depth_pred)
        max_d = torch.max(depth_pred)
        depth_pred = (depth_pred - min_d) / (max_d - min_d)

        # Convert to numpy
        depth_pred = depth_pred.cpu().numpy().astype(np.float32)

        # Resize back to original resolution
        if match_input_res:
            pred_img = Image.fromarray(depth_pred)
            pred_img = pred_img.resize(input_size)
            depth_pred = np.asarray(pred_img)

        # Clip output range
        depth_pred = depth_pred.clip(0, 1)

        # Colorize
        if color_map is not None:
            depth_colored = self.colorize_depth_maps(
                depth_pred, 0, 1, cmap=color_map
            ).squeeze()  # [3, H, W], value in (0, 1)
            depth_colored = (depth_colored * 255).astype(np.uint8)
            depth_colored_hwc = self.chw2hwc(depth_colored)
            depth_colored_img = Image.fromarray(depth_colored_hwc)
        else:
            depth_colored_img = None
        return MarigoldDepthOutput(
            depth_np=depth_pred,
            depth_colored=depth_colored_img,
            uncertainty=pred_uncert,
        )

    def _encode_empty_text(self):
        """
        Encode text embedding for empty prompt.
        """
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

    @torch.no_grad()
    def single_infer(
        self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool
    ) -> torch.Tensor:
        """
        Perform an individual depth prediction without ensembling.
        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image.
            num_inference_steps (`int`):
                Number of diffusion denoising steps (DDIM) during inference.
            show_pbar (`bool`):
                Display a progress bar of diffusion denoising.
        Returns:
            `torch.Tensor`: Predicted depth map.
        """
        device = rgb_in.device

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self._encode_rgb(rgb_in)

        # Initial depth map (noise)
        depth_latent = torch.randn(
            rgb_latent.shape, device=device, dtype=self.dtype
        )  # [B, 4, h, w]

        # Batched empty text embedding
        if self.empty_text_embed is None:
            self._encode_empty_text()
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        )  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)

        for i, t in iterable:
            unet_input = torch.cat(
                [rgb_latent, depth_latent], dim=1
            )  # this order is important

            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4, h, w]

            # compute the previous noisy sample x_t -> x_t-1
            depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
        torch.cuda.empty_cache()
        depth = self._decode_depth(depth_latent)

        # clip prediction
        depth = torch.clip(depth, -1.0, 1.0)
        # shift to [0, 1]
        depth = (depth + 1.0) / 2.0

        return depth

    def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.
        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.
        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.rgb_latent_scale_factor
        return rgb_latent

    def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode depth latent into depth map.
        Args:
            depth_latent (`torch.Tensor`):
                Depth latent to be decoded.
        Returns:
            `torch.Tensor`: Decoded depth map.
        """
        # scale latent
        depth_latent = depth_latent / self.depth_latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(depth_latent)
        stacked = self.vae.decoder(z)
        # mean of output channels
        depth_mean = stacked.mean(dim=1, keepdim=True)
        return depth_mean

    @staticmethod
    def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
        """
        Resize image to limit maximum edge length while keeping aspect ratio.
        Args:
            img (`Image.Image`):
                Image to be resized.
            max_edge_resolution (`int`):
                Maximum edge length (pixel).
        Returns:
            `Image.Image`: Resized image.
        """
        original_width, original_height = img.size
        downscale_factor = min(
            max_edge_resolution / original_width, max_edge_resolution / original_height
        )

        new_width = int(original_width * downscale_factor)
        new_height = int(original_height * downscale_factor)

        resized_img = img.resize((new_width, new_height))
        return resized_img

    @staticmethod
    def colorize_depth_maps(
        depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
    ):
        """
        Colorize depth maps.
        """
        assert len(depth_map.shape) >= 2, "Invalid dimension"

        if isinstance(depth_map, torch.Tensor):
            depth = depth_map.detach().clone().squeeze().numpy()
        elif isinstance(depth_map, np.ndarray):
            depth = depth_map.copy().squeeze()
        # reshape to [ (B,) H, W ]
        if depth.ndim < 3:
            depth = depth[np.newaxis, :, :]

        # colorize
        cm = matplotlib.colormaps[cmap]
        depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
        img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]  # value from 0 to 1
        img_colored_np = np.rollaxis(img_colored_np, 3, 1)

        if valid_mask is not None:
            if isinstance(depth_map, torch.Tensor):
                valid_mask = valid_mask.detach().numpy()
            valid_mask = valid_mask.squeeze()  # [H, W] or [B, H, W]
            if valid_mask.ndim < 3:
                valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
            else:
                valid_mask = valid_mask[:, np.newaxis, :, :]
            valid_mask = np.repeat(valid_mask, 3, axis=1)
            img_colored_np[~valid_mask] = 0

        if isinstance(depth_map, torch.Tensor):
            img_colored = torch.from_numpy(img_colored_np).float()
        elif isinstance(depth_map, np.ndarray):
            img_colored = img_colored_np

        return img_colored

    @staticmethod
    def chw2hwc(chw):
        assert 3 == len(chw.shape)
        if isinstance(chw, torch.Tensor):
            hwc = torch.permute(chw, (1, 2, 0))
        elif isinstance(chw, np.ndarray):
            hwc = np.moveaxis(chw, 0, -1)
        return hwc

    @staticmethod
    def _find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
        """
        Automatically search for suitable operating batch size.
        Args:
            ensemble_size (`int`):
                Number of predictions to be ensembled.
            input_res (`int`):
                Operating resolution of the input image.
        Returns:
            `int`: Operating batch size.
        """
        # Search table for suggested max. inference batch size
        bs_search_table = [
            # tested on A100-PCIE-80GB
            {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
            {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
            # tested on A100-PCIE-40GB
            {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
            {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
            {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
            {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
            # tested on RTX3090, RTX4090
            {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
            {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
            {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
            {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
            {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
            {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
            # tested on GTX1080Ti
            {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
            {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
            {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
            {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
            {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
        ]

        if not torch.cuda.is_available():
            return 1

        total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
        filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
        for settings in sorted(
            filtered_bs_search_table,
            key=lambda k: (k["res"], -k["total_vram"]),
        ):
            if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
                bs = settings["bs"]
                if bs > ensemble_size:
                    bs = ensemble_size
                elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
                    bs = math.ceil(ensemble_size / 2)
                return bs

        return 1

    @staticmethod
    def ensemble_depths(
        input_images: torch.Tensor,
        regularizer_strength: float = 0.02,
        max_iter: int = 2,
        tol: float = 1e-3,
        reduction: str = "median",
        max_res: int = None,
    ):
        """
        Ensemble multiple affine-invariant depth images (up to scale and shift)
        by jointly estimating a per-prediction scale and shift that aligns them.
        """

        def inter_distances(tensors: torch.Tensor):
            """
            Calculate the distance between each pair of depth maps.
            """
            distances = []
            for i, j in torch.combinations(torch.arange(tensors.shape[0])):
                arr1 = tensors[i : i + 1]
                arr2 = tensors[j : j + 1]
                distances.append(arr1 - arr2)
            dist = torch.concatenate(distances, dim=0)
            return dist

        device = input_images.device
        dtype = input_images.dtype
        np_dtype = np.float32

        original_input = input_images.clone()
        n_img = input_images.shape[0]
        ori_shape = input_images.shape

        if max_res is not None:
            scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
            if scale_factor < 1:
                downscaler = torch.nn.Upsample(
                    scale_factor=scale_factor, mode="nearest"
                )
                input_images = downscaler(torch.from_numpy(input_images)).numpy()

        # init guess
        _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
        _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
        s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
        t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
        x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)

        input_images = input_images.to(device)

        # objective function
        def closure(x):
            l = len(x)
            s = x[: int(l / 2)]
            t = x[int(l / 2) :]
            s = torch.from_numpy(s).to(dtype=dtype).to(device)
            t = torch.from_numpy(t).to(dtype=dtype).to(device)

            transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
            dists = inter_distances(transformed_arrays)
            sqrt_dist = torch.sqrt(torch.mean(dists**2))

            if "mean" == reduction:
                pred = torch.mean(transformed_arrays, dim=0)
            elif "median" == reduction:
                pred = torch.median(transformed_arrays, dim=0).values
            else:
                raise ValueError

            near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
            far_err = torch.sqrt((1 - torch.max(pred)) ** 2)

            err = sqrt_dist + (near_err + far_err) * regularizer_strength
            err = err.detach().cpu().numpy().astype(np_dtype)
            return err

        res = minimize(
            closure,
            x,
            method="BFGS",
            tol=tol,
            options={"maxiter": max_iter, "disp": False},
        )
        x = res.x
        l = len(x)
        s = x[: int(l / 2)]
        t = x[int(l / 2) :]

        # Prediction
        s = torch.from_numpy(s).to(dtype=dtype).to(device)
        t = torch.from_numpy(t).to(dtype=dtype).to(device)
        transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
        if "mean" == reduction:
            aligned_images = torch.mean(transformed_arrays, dim=0)
            std = torch.std(transformed_arrays, dim=0)
            uncertainty = std
        elif "median" == reduction:
            aligned_images = torch.median(transformed_arrays, dim=0).values
            # MAD (median absolute deviation) as uncertainty indicator
            abs_dev = torch.abs(transformed_arrays - aligned_images)
            mad = torch.median(abs_dev, dim=0).values
            uncertainty = mad
        else:
            raise ValueError(f"Unknown reduction method: {reduction}")

        # Scale and shift to [0, 1]
        _min = torch.min(aligned_images)
        _max = torch.max(aligned_images)
        aligned_images = (aligned_images - _min) / (_max - _min)
        uncertainty /= _max - _min

        return aligned_images, uncertainty
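
For reference, a minimal usage sketch (not part of the committed file): it loads this module as a diffusers `custom_pipeline` and runs it on a single image, then saves the outputs exposed by `MarigoldDepthOutput`. The checkpoint name `prs-eth/marigold-v1-0`, the input path `input.jpg`, and the output filenames are illustrative assumptions; `custom_pipeline` can also point at a local copy of this file.

import numpy as np
import torch
from PIL import Image

from diffusers import DiffusionPipeline

# Load the community pipeline defined above (checkpoint name is an assumption).
pipe = DiffusionPipeline.from_pretrained(
    "prs-eth/marigold-v1-0",
    custom_pipeline="marigold_depth_estimation",
)
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

image = Image.open("input.jpg")  # hypothetical input path

output = pipe(
    image,
    denoising_steps=10,    # DDIM steps per prediction
    ensemble_size=10,      # number of predictions to ensemble
    processing_res=768,    # longest edge during inference
    match_input_res=True,  # resize prediction back to the input size
)

np.save("depth.npy", output.depth_np)  # raw depth in [0, 1], float32
if output.depth_colored is not None:
    output.depth_colored.save("depth_colored.png")
if output.uncertainty is not None:
    print("mean ensemble uncertainty (MAD):", float(output.uncertainty.mean()))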