liuyuan-pal committed on
Commit
973977c
·
1 Parent(s): 8e2f608
Files changed (2) hide show
  1. app.py +12 -8
  2. ldm/models/diffusion/sync_dreamer.py +3 -7
app.py CHANGED
@@ -8,6 +8,7 @@ import os
8
  import fire
9
  from omegaconf import OmegaConf
10
 
 
11
  from ldm.util import add_margin, instantiate_from_config
12
  from sam_utils import sam_init, sam_out_nosave
13
 
@@ -19,12 +20,12 @@ _DESCRIPTION = '''
19
  <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2309.03453"><img src="https://img.shields.io/badge/2309.03453-f9f7f7?logo="></a>
20
  <a style="display:inline-block; margin-left: .5em" href='https://github.com/liuyuan-pal/SyncDreamer'><img src='https://img.shields.io/github/stars/liuyuan-pal/SyncDreamer?style=social' /></a>
21
  </div>
22
- Given a single-view image, SyncDreamer is able to generate multiview-consistent images, which enables direct 3D reconstruction with NeuS or NeRF without SDS loss
23
 
24
- Procedure:
25
- **Step 0**. Upload an image or select an example. ==> The foreground is masked out by SAM.
26
- **Step 1**. Select "Crop size" and click "Crop it". ==> The foreground object is centered and resized.
27
- **Step 2**. Select "Elevation angle "and click "Run generation". ==> Generate multiview images. (This costs about 2 min.)
28
  To reconstruct a NeRF or a 3D mesh from the generated images, please refer to our [github repository](https://github.com/liuyuan-pal/SyncDreamer).
29
  '''
30
  _USER_GUIDE0 = "Step0: Please upload an image in the block above (or choose an example shown in the left)."
@@ -74,8 +75,9 @@ def resize_inputs(image_input, crop_size):
74
  results = add_margin(ref_img_, size=256)
75
  return results
76
 
77
- def generate(model, batch_view_num, sample_num, cfg_scale, seed, image_input, elevation_input):
78
  if deployed:
 
79
  seed=int(seed)
80
  torch.random.manual_seed(seed)
81
  np.random.seed(seed)
@@ -97,7 +99,8 @@ def generate(model, batch_view_num, sample_num, cfg_scale, seed, image_input, el
97
  data[k] = torch.repeat_interleave(data[k], sample_num, dim=0)
98
 
99
  if deployed:
100
- x_sample = model.sample(data, cfg_scale, batch_view_num)
 
101
  else:
102
  x_sample = torch.zeros(sample_num, 16, 3, 256, 256)
103
 
@@ -219,6 +222,7 @@ def run_demo():
219
  with gr.Accordion('Advanced options', open=False):
220
  cfg_scale = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
221
  sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=True, info='How many instance (16 images per instance)')
 
222
  batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True)
223
  seed = gr.Number(6033, label='Random seed', interactive=True)
224
  run_btn = gr.Button('Run generation', variant='primary', interactive=True)
@@ -235,7 +239,7 @@ def run_demo():
235
  crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=False)\
236
  .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
237
 
238
- run_btn.click(partial(generate, model), inputs=[batch_view_num, sample_num, cfg_scale, seed, input_block, elevation], outputs=[output_block], queue=False)\
239
  .success(fn=partial(update_guide, _USER_GUIDE3), outputs=[guide_text], queue=False)
240
 
241
  demo.queue().launch(share=False, max_threads=80) # auth=("admin", os.environ['PASSWD'])
 
8
  import fire
9
  from omegaconf import OmegaConf
10
 
11
+ from ldm.models.diffusion.sync_dreamer import SyncDDIMSampler, SyncMultiviewDiffusion
12
  from ldm.util import add_margin, instantiate_from_config
13
  from sam_utils import sam_init, sam_out_nosave
14
 
 
20
  <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2309.03453"><img src="https://img.shields.io/badge/2309.03453-f9f7f7?logo="></a>
21
  <a style="display:inline-block; margin-left: .5em" href='https://github.com/liuyuan-pal/SyncDreamer'><img src='https://img.shields.io/github/stars/liuyuan-pal/SyncDreamer?style=social' /></a>
22
  </div>
23
+ Given a single-view image, SyncDreamer is able to generate multiview-consistent images, which enables direct 3D reconstruction with NeuS or NeRF without SDS loss </br>
24
 
25
+ Procedure: </br>
26
+ **Step 0**. Upload an image or select an example. ==> The foreground is masked out by SAM. </br>
27
+ **Step 1**. Select "Crop size" and click "Crop it". ==> The foreground object is centered and resized. </br>
28
+ **Step 2**. Select "Elevation angle "and click "Run generation". ==> Generate multiview images. (This costs about 2 min.) </br>
29
  To reconstruct a NeRF or a 3D mesh from the generated images, please refer to our [github repository](https://github.com/liuyuan-pal/SyncDreamer).
30
  '''
31
  _USER_GUIDE0 = "Step0: Please upload an image in the block above (or choose an example shown in the left)."
 
75
  results = add_margin(ref_img_, size=256)
76
  return results
77
 
78
+ def generate(model, sample_steps, batch_view_num, sample_num, cfg_scale, seed, image_input, elevation_input):
79
  if deployed:
80
+ assert isinstance(model, SyncMultiviewDiffusion)
81
  seed=int(seed)
82
  torch.random.manual_seed(seed)
83
  np.random.seed(seed)
 
99
  data[k] = torch.repeat_interleave(data[k], sample_num, dim=0)
100
 
101
  if deployed:
102
+ sampler = SyncDDIMSampler(model, sample_steps)
103
+ x_sample = model.sample(sampler, data, cfg_scale, batch_view_num)
104
  else:
105
  x_sample = torch.zeros(sample_num, 16, 3, 256, 256)
106
 
 
222
  with gr.Accordion('Advanced options', open=False):
223
  cfg_scale = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
224
  sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=True, info='How many instance (16 images per instance)')
225
+ sample_steps = gr.Slider(40, 400, 200, step=10, label='Sample steps', interactive=True)
226
  batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True)
227
  seed = gr.Number(6033, label='Random seed', interactive=True)
228
  run_btn = gr.Button('Run generation', variant='primary', interactive=True)
 
239
  crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=False)\
240
  .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
241
 
242
+ run_btn.click(partial(generate, model), inputs=[sample_steps, batch_view_num, sample_num, cfg_scale, seed, input_block, elevation], outputs=[output_block], queue=False)\
243
  .success(fn=partial(update_guide, _USER_GUIDE3), outputs=[guide_text], queue=False)
244
 
245
  demo.queue().launch(share=False, max_threads=80) # auth=("admin", os.environ['PASSWD'])
ldm/models/diffusion/sync_dreamer.py CHANGED
@@ -468,13 +468,9 @@ class SyncMultiviewDiffusion(pl.LightningModule):
468
  x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise
469
  return x_noisy, noise
470
 
471
- def sample(self, batch, cfg_scale, batch_view_num, use_ddim=True,
472
- return_inter_results=False, inter_interval=50, inter_view_interval=2):
473
  _, clip_embed, input_info = self.prepare(batch)
474
- if use_ddim:
475
- x_sample, inter = self.ddim.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval, batch_view_num=batch_view_num)
476
- else:
477
- raise NotImplementedError
478
 
479
  N = x_sample.shape[1]
480
  x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1)
@@ -540,7 +536,7 @@ class SyncMultiviewDiffusion(pl.LightningModule):
540
  return [opt], scheduler
541
 
542
  class SyncDDIMSampler:
543
- def __init__(self, model: SyncMultiviewDiffusion, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., latent_size=32):
544
  super().__init__()
545
  self.model = model
546
  self.ddpm_num_timesteps = model.num_timesteps
 
468
  x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise
469
  return x_noisy, noise
470
 
471
+ def sample(self, sampler, batch, cfg_scale, batch_view_num, return_inter_results=False, inter_interval=50, inter_view_interval=2):
 
472
  _, clip_embed, input_info = self.prepare(batch)
473
+ x_sample, inter = sampler.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval, batch_view_num=batch_view_num)
 
 
 
474
 
475
  N = x_sample.shape[1]
476
  x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1)
 
536
  return [opt], scheduler
537
 
538
  class SyncDDIMSampler:
539
+ def __init__(self, model: SyncMultiviewDiffusion, ddim_num_steps, ddim_discretize="uniform", ddim_eta=1.0, latent_size=32):
540
  super().__init__()
541
  self.model = model
542
  self.ddpm_num_timesteps = model.num_timesteps