lcipolina commited on
Commit
d041481
1 Parent(s): e45ff8e

Upload glide_text2im/respace.py

Browse files
Files changed (1) hide show
  1. glide_text2im/respace.py +117 -0
glide_text2im/respace.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for changing sampling schedules of a trained model.
3
+
4
+ Simplified from: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
5
+ """
6
+
7
+ import numpy as np
8
+ import torch as th
9
+
10
+ from .gaussian_diffusion import GaussianDiffusion
11
+
12
+
13
+ def space_timesteps(num_timesteps, section_counts):
14
+ """
15
+ Create a list of timesteps to use from an original diffusion process,
16
+ given the number of timesteps we want to take from equally-sized portions
17
+ of the original process.
18
+
19
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
20
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
21
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
22
+
23
+ :param num_timesteps: the number of diffusion steps in the original
24
+ process to divide up.
25
+ :param section_counts: either a list of numbers, or a string containing
26
+ comma-separated numbers, indicating the step count
27
+ per section. As a special case, use "ddimN" where N
28
+ is a number of steps to use the striding from the
29
+ DDIM paper.
30
+ :return: a set of diffusion steps from the original process to use.
31
+ """
32
+ if isinstance(section_counts, str):
33
+ if section_counts.startswith("ddim"):
34
+ desired_count = int(section_counts[len("ddim") :])
35
+ for i in range(1, num_timesteps):
36
+ if len(range(0, num_timesteps, i)) == desired_count:
37
+ return set(range(0, num_timesteps, i))
38
+ raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
39
+ elif section_counts == "fast27":
40
+ steps = space_timesteps(num_timesteps, "10,10,3,2,2")
41
+ # Help reduce DDIM artifacts from noisiest timesteps.
42
+ steps.remove(num_timesteps - 1)
43
+ steps.add(num_timesteps - 3)
44
+ return steps
45
+ section_counts = [int(x) for x in section_counts.split(",")]
46
+ size_per = num_timesteps // len(section_counts)
47
+ extra = num_timesteps % len(section_counts)
48
+ start_idx = 0
49
+ all_steps = []
50
+ for i, section_count in enumerate(section_counts):
51
+ size = size_per + (1 if i < extra else 0)
52
+ if size < section_count:
53
+ raise ValueError(f"cannot divide section of {size} steps into {section_count}")
54
+ if section_count <= 1:
55
+ frac_stride = 1
56
+ else:
57
+ frac_stride = (size - 1) / (section_count - 1)
58
+ cur_idx = 0.0
59
+ taken_steps = []
60
+ for _ in range(section_count):
61
+ taken_steps.append(start_idx + round(cur_idx))
62
+ cur_idx += frac_stride
63
+ all_steps += taken_steps
64
+ start_idx += size
65
+ return set(all_steps)
66
+
67
+
68
+ class SpacedDiffusion(GaussianDiffusion):
69
+ """
70
+ A diffusion process which can skip steps in a base diffusion process.
71
+
72
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
73
+ original diffusion process to retain.
74
+ :param kwargs: the kwargs to create the base diffusion process.
75
+ """
76
+
77
+ def __init__(self, use_timesteps, **kwargs):
78
+ self.use_timesteps = set(use_timesteps)
79
+ self.timestep_map = []
80
+ self.original_num_steps = len(kwargs["betas"])
81
+
82
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
83
+ last_alpha_cumprod = 1.0
84
+ new_betas = []
85
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
86
+ if i in self.use_timesteps:
87
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
88
+ last_alpha_cumprod = alpha_cumprod
89
+ self.timestep_map.append(i)
90
+ kwargs["betas"] = np.array(new_betas)
91
+ super().__init__(**kwargs)
92
+
93
+ def p_mean_variance(self, model, *args, **kwargs):
94
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
95
+
96
+ def condition_mean(self, cond_fn, *args, **kwargs):
97
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
98
+
99
+ def condition_score(self, cond_fn, *args, **kwargs):
100
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
101
+
102
+ def _wrap_model(self, model):
103
+ if isinstance(model, _WrappedModel):
104
+ return model
105
+ return _WrappedModel(model, self.timestep_map, self.original_num_steps)
106
+
107
+
108
+ class _WrappedModel:
109
+ def __init__(self, model, timestep_map, original_num_steps):
110
+ self.model = model
111
+ self.timestep_map = timestep_map
112
+ self.original_num_steps = original_num_steps
113
+
114
+ def __call__(self, x, ts, **kwargs):
115
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
116
+ new_ts = map_tensor[ts]
117
+ return self.model(x, new_ts, **kwargs)