HaisuGuan committed on
Commit
830022c
·
1 Parent(s): d223e1c
Files changed (2) hide show
  1. config_web.yml +1 -1
  2. utils/sampling.py +2 -0
config_web.yml CHANGED
@@ -43,7 +43,7 @@ training:
43
  sampling:
44
  batch_size: 1
45
  last_only: True
46
- sampling_timesteps: 100
47
 
48
  optim:
49
  weight_decay: 0.01
 
43
  sampling:
44
  batch_size: 1
45
  last_only: True
46
+ sampling_timesteps: 50
47
 
48
  optim:
49
  weight_decay: 0.01
utils/sampling.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch
2
  from torchvision.transforms.functional import crop
 
3
 
4
 
5
  def compute_alpha(beta, t):
@@ -53,6 +54,7 @@ def generalized_steps_overlapping(x, x_cond, seq, model, b, eta=0., corners=None
53
  x_grid_mask[:, :, hi:hi + p_size, wi:wi + p_size] += 1
54
 
55
  for i, j in zip(reversed(seq), reversed(seq_next)):
 
56
  t = (torch.ones(n) * i).to(x.device)
57
  next_t = (torch.ones(n) * j).to(x.device)
58
  at = compute_alpha(b, t.long())
 
1
  import torch
2
  from torchvision.transforms.functional import crop
3
+ import tqdm as tqdm
4
 
5
 
6
  def compute_alpha(beta, t):
 
54
  x_grid_mask[:, :, hi:hi + p_size, wi:wi + p_size] += 1
55
 
56
  for i, j in zip(reversed(seq), reversed(seq_next)):
57
+ print(j)
58
  t = (torch.ones(n) * i).to(x.device)
59
  next_t = (torch.ones(n) * j).to(x.device)
60
  at = compute_alpha(b, t.long())