Spaces:
Runtime error
Runtime error
Upload glide_text2im/respace.py
Browse files- glide_text2im/respace.py +117 -0
glide_text2im/respace.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utilities for changing sampling schedules of a trained model.
|
3 |
+
|
4 |
+
Simplified from: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch as th
|
9 |
+
|
10 |
+
from .gaussian_diffusion import GaussianDiffusion
|
11 |
+
|
12 |
+
|
13 |
+
def space_timesteps(num_timesteps, section_counts):
|
14 |
+
"""
|
15 |
+
Create a list of timesteps to use from an original diffusion process,
|
16 |
+
given the number of timesteps we want to take from equally-sized portions
|
17 |
+
of the original process.
|
18 |
+
|
19 |
+
For example, if there's 300 timesteps and the section counts are [10,15,20]
|
20 |
+
then the first 100 timesteps are strided to be 10 timesteps, the second 100
|
21 |
+
are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
22 |
+
|
23 |
+
:param num_timesteps: the number of diffusion steps in the original
|
24 |
+
process to divide up.
|
25 |
+
:param section_counts: either a list of numbers, or a string containing
|
26 |
+
comma-separated numbers, indicating the step count
|
27 |
+
per section. As a special case, use "ddimN" where N
|
28 |
+
is a number of steps to use the striding from the
|
29 |
+
DDIM paper.
|
30 |
+
:return: a set of diffusion steps from the original process to use.
|
31 |
+
"""
|
32 |
+
if isinstance(section_counts, str):
|
33 |
+
if section_counts.startswith("ddim"):
|
34 |
+
desired_count = int(section_counts[len("ddim") :])
|
35 |
+
for i in range(1, num_timesteps):
|
36 |
+
if len(range(0, num_timesteps, i)) == desired_count:
|
37 |
+
return set(range(0, num_timesteps, i))
|
38 |
+
raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
|
39 |
+
elif section_counts == "fast27":
|
40 |
+
steps = space_timesteps(num_timesteps, "10,10,3,2,2")
|
41 |
+
# Help reduce DDIM artifacts from noisiest timesteps.
|
42 |
+
steps.remove(num_timesteps - 1)
|
43 |
+
steps.add(num_timesteps - 3)
|
44 |
+
return steps
|
45 |
+
section_counts = [int(x) for x in section_counts.split(",")]
|
46 |
+
size_per = num_timesteps // len(section_counts)
|
47 |
+
extra = num_timesteps % len(section_counts)
|
48 |
+
start_idx = 0
|
49 |
+
all_steps = []
|
50 |
+
for i, section_count in enumerate(section_counts):
|
51 |
+
size = size_per + (1 if i < extra else 0)
|
52 |
+
if size < section_count:
|
53 |
+
raise ValueError(f"cannot divide section of {size} steps into {section_count}")
|
54 |
+
if section_count <= 1:
|
55 |
+
frac_stride = 1
|
56 |
+
else:
|
57 |
+
frac_stride = (size - 1) / (section_count - 1)
|
58 |
+
cur_idx = 0.0
|
59 |
+
taken_steps = []
|
60 |
+
for _ in range(section_count):
|
61 |
+
taken_steps.append(start_idx + round(cur_idx))
|
62 |
+
cur_idx += frac_stride
|
63 |
+
all_steps += taken_steps
|
64 |
+
start_idx += size
|
65 |
+
return set(all_steps)
|
66 |
+
|
67 |
+
|
68 |
+
class SpacedDiffusion(GaussianDiffusion):
|
69 |
+
"""
|
70 |
+
A diffusion process which can skip steps in a base diffusion process.
|
71 |
+
|
72 |
+
:param use_timesteps: a collection (sequence or set) of timesteps from the
|
73 |
+
original diffusion process to retain.
|
74 |
+
:param kwargs: the kwargs to create the base diffusion process.
|
75 |
+
"""
|
76 |
+
|
77 |
+
def __init__(self, use_timesteps, **kwargs):
|
78 |
+
self.use_timesteps = set(use_timesteps)
|
79 |
+
self.timestep_map = []
|
80 |
+
self.original_num_steps = len(kwargs["betas"])
|
81 |
+
|
82 |
+
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
|
83 |
+
last_alpha_cumprod = 1.0
|
84 |
+
new_betas = []
|
85 |
+
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
86 |
+
if i in self.use_timesteps:
|
87 |
+
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
88 |
+
last_alpha_cumprod = alpha_cumprod
|
89 |
+
self.timestep_map.append(i)
|
90 |
+
kwargs["betas"] = np.array(new_betas)
|
91 |
+
super().__init__(**kwargs)
|
92 |
+
|
93 |
+
def p_mean_variance(self, model, *args, **kwargs):
|
94 |
+
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
95 |
+
|
96 |
+
def condition_mean(self, cond_fn, *args, **kwargs):
|
97 |
+
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
98 |
+
|
99 |
+
def condition_score(self, cond_fn, *args, **kwargs):
|
100 |
+
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
|
101 |
+
|
102 |
+
def _wrap_model(self, model):
|
103 |
+
if isinstance(model, _WrappedModel):
|
104 |
+
return model
|
105 |
+
return _WrappedModel(model, self.timestep_map, self.original_num_steps)
|
106 |
+
|
107 |
+
|
108 |
+
class _WrappedModel:
|
109 |
+
def __init__(self, model, timestep_map, original_num_steps):
|
110 |
+
self.model = model
|
111 |
+
self.timestep_map = timestep_map
|
112 |
+
self.original_num_steps = original_num_steps
|
113 |
+
|
114 |
+
def __call__(self, x, ts, **kwargs):
|
115 |
+
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
116 |
+
new_ts = map_tensor[ts]
|
117 |
+
return self.model(x, new_ts, **kwargs)
|