fix
Browse files- config_web.yml +1 -1
- utils/sampling.py +2 -0
config_web.yml
CHANGED
@@ -43,7 +43,7 @@ training:
|
|
43 |
sampling:
|
44 |
batch_size: 1
|
45 |
last_only: True
|
46 |
-
sampling_timesteps:
|
47 |
|
48 |
optim:
|
49 |
weight_decay: 0.01
|
|
|
43 |
sampling:
|
44 |
batch_size: 1
|
45 |
last_only: True
|
46 |
+
sampling_timesteps: 50
|
47 |
|
48 |
optim:
|
49 |
weight_decay: 0.01
|
utils/sampling.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import torch
|
2 |
from torchvision.transforms.functional import crop
|
|
|
3 |
|
4 |
|
5 |
def compute_alpha(beta, t):
|
@@ -53,6 +54,7 @@ def generalized_steps_overlapping(x, x_cond, seq, model, b, eta=0., corners=None
|
|
53 |
x_grid_mask[:, :, hi:hi + p_size, wi:wi + p_size] += 1
|
54 |
|
55 |
for i, j in zip(reversed(seq), reversed(seq_next)):
|
|
|
56 |
t = (torch.ones(n) * i).to(x.device)
|
57 |
next_t = (torch.ones(n) * j).to(x.device)
|
58 |
at = compute_alpha(b, t.long())
|
|
|
1 |
import torch
|
2 |
from torchvision.transforms.functional import crop
|
3 |
+
import tqdm as tqdm
|
4 |
|
5 |
|
6 |
def compute_alpha(beta, t):
|
|
|
54 |
x_grid_mask[:, :, hi:hi + p_size, wi:wi + p_size] += 1
|
55 |
|
56 |
for i, j in zip(reversed(seq), reversed(seq_next)):
|
57 |
+
print(j)
|
58 |
t = (torch.ones(n) * i).to(x.device)
|
59 |
next_t = (torch.ones(n) * j).to(x.device)
|
60 |
at = compute_alpha(b, t.long())
|