Mariam-Elz committed on
Commit
37e231b
·
verified ·
1 Parent(s): cb33ff6

Upload imagedream/ldm/interface.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. imagedream/ldm/interface.py +206 -205
imagedream/ldm/interface.py CHANGED
@@ -1,205 +1,206 @@
1
- from typing import List
2
- from functools import partial
3
-
4
- import numpy as np
5
- import torch
6
- import torch.nn as nn
7
-
8
- from .modules.diffusionmodules.util import (
9
- make_beta_schedule,
10
- extract_into_tensor,
11
- enforce_zero_terminal_snr,
12
- noise_like,
13
- )
14
- from .util import exists, default, instantiate_from_config
15
- from .modules.distributions.distributions import DiagonalGaussianDistribution
16
-
17
-
18
- class DiffusionWrapper(nn.Module):
19
- def __init__(self, diffusion_model):
20
- super().__init__()
21
- self.diffusion_model = diffusion_model
22
-
23
- def forward(self, *args, **kwargs):
24
- return self.diffusion_model(*args, **kwargs)
25
-
26
-
27
- class LatentDiffusionInterface(nn.Module):
28
- """a simple interface class for LDM inference"""
29
-
30
- def __init__(
31
- self,
32
- unet_config,
33
- clip_config,
34
- vae_config,
35
- parameterization="eps",
36
- scale_factor=0.18215,
37
- beta_schedule="linear",
38
- timesteps=1000,
39
- linear_start=0.00085,
40
- linear_end=0.0120,
41
- cosine_s=8e-3,
42
- given_betas=None,
43
- zero_snr=False,
44
- *args,
45
- **kwargs,
46
- ):
47
- super().__init__()
48
-
49
- unet = instantiate_from_config(unet_config)
50
- self.model = DiffusionWrapper(unet)
51
- self.clip_model = instantiate_from_config(clip_config)
52
- self.vae_model = instantiate_from_config(vae_config)
53
-
54
- self.parameterization = parameterization
55
- self.scale_factor = scale_factor
56
- self.register_schedule(
57
- given_betas=given_betas,
58
- beta_schedule=beta_schedule,
59
- timesteps=timesteps,
60
- linear_start=linear_start,
61
- linear_end=linear_end,
62
- cosine_s=cosine_s,
63
- zero_snr=zero_snr
64
- )
65
-
66
- def register_schedule(
67
- self,
68
- given_betas=None,
69
- beta_schedule="linear",
70
- timesteps=1000,
71
- linear_start=1e-4,
72
- linear_end=2e-2,
73
- cosine_s=8e-3,
74
- zero_snr=False
75
- ):
76
- if exists(given_betas):
77
- betas = given_betas
78
- else:
79
- betas = make_beta_schedule(
80
- beta_schedule,
81
- timesteps,
82
- linear_start=linear_start,
83
- linear_end=linear_end,
84
- cosine_s=cosine_s,
85
- )
86
- if zero_snr:
87
- print("--- using zero snr---")
88
- betas = enforce_zero_terminal_snr(betas).numpy()
89
- alphas = 1.0 - betas
90
- alphas_cumprod = np.cumprod(alphas, axis=0)
91
- alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
92
-
93
- (timesteps,) = betas.shape
94
- self.num_timesteps = int(timesteps)
95
- self.linear_start = linear_start
96
- self.linear_end = linear_end
97
- assert (
98
- alphas_cumprod.shape[0] == self.num_timesteps
99
- ), "alphas have to be defined for each timestep"
100
-
101
- to_torch = partial(torch.tensor, dtype=torch.float32)
102
-
103
- self.register_buffer("betas", to_torch(betas))
104
- self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
105
- self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
106
-
107
- # calculations for diffusion q(x_t | x_{t-1}) and others
108
- self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
109
- self.register_buffer(
110
- "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
111
- )
112
- self.register_buffer(
113
- "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
114
- )
115
- self.register_buffer(
116
- "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
117
- )
118
- self.register_buffer(
119
- "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
120
- )
121
-
122
- # calculations for posterior q(x_{t-1} | x_t, x_0)
123
- self.v_posterior = 0
124
- posterior_variance = (1 - self.v_posterior) * betas * (
125
- 1.0 - alphas_cumprod_prev
126
- ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
127
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
128
- self.register_buffer("posterior_variance", to_torch(posterior_variance))
129
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
130
- self.register_buffer(
131
- "posterior_log_variance_clipped",
132
- to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
133
- )
134
- self.register_buffer(
135
- "posterior_mean_coef1",
136
- to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
137
- )
138
- self.register_buffer(
139
- "posterior_mean_coef2",
140
- to_torch(
141
- (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
142
- ),
143
- )
144
-
145
- def q_sample(self, x_start, t, noise=None):
146
- noise = default(noise, lambda: torch.randn_like(x_start))
147
- return (
148
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
149
- + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
150
- * noise
151
- )
152
-
153
- def get_v(self, x, noise, t):
154
- return (
155
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
156
- - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
157
- )
158
-
159
- def predict_start_from_noise(self, x_t, t, noise):
160
- return (
161
- extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
162
- - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
163
- * noise
164
- )
165
-
166
- def predict_start_from_z_and_v(self, x_t, t, v):
167
- return (
168
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
169
- - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
170
- )
171
-
172
- def predict_eps_from_z_and_v(self, x_t, t, v):
173
- return (
174
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
175
- + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
176
- * x_t
177
- )
178
-
179
- def apply_model(self, x_noisy, t, cond, **kwargs):
180
- assert isinstance(cond, dict), "cond has to be a dictionary"
181
- return self.model(x_noisy, t, **cond, **kwargs)
182
-
183
- def get_learned_conditioning(self, prompts: List[str]):
184
- return self.clip_model(prompts)
185
-
186
- def get_learned_image_conditioning(self, images):
187
- return self.clip_model.forward_image(images)
188
-
189
- def get_first_stage_encoding(self, encoder_posterior):
190
- if isinstance(encoder_posterior, DiagonalGaussianDistribution):
191
- z = encoder_posterior.sample()
192
- elif isinstance(encoder_posterior, torch.Tensor):
193
- z = encoder_posterior
194
- else:
195
- raise NotImplementedError(
196
- f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
197
- )
198
- return self.scale_factor * z
199
-
200
- def encode_first_stage(self, x):
201
- return self.vae_model.encode(x)
202
-
203
- def decode_first_stage(self, z):
204
- z = 1.0 / self.scale_factor * z
205
- return self.vae_model.decode(z)
 
 
1
+ from typing import List
2
+ from functools import partial
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .modules.diffusionmodules.util import (
9
+ make_beta_schedule,
10
+ extract_into_tensor,
11
+ enforce_zero_terminal_snr,
12
+ noise_like,
13
+ )
14
+ from .util import exists, default, instantiate_from_config
15
+ from .modules.distributions.distributions import DiagonalGaussianDistribution
16
+
17
+
18
class DiffusionWrapper(nn.Module):
    """Pass-through container around the denoising network.

    The wrapped network is stored under the attribute name
    ``diffusion_model``; calling the wrapper calls that network with exactly
    the arguments it was given.
    """

    def __init__(self, diffusion_model):
        """Store ``diffusion_model`` as the wrapped network."""
        super().__init__()
        # The attribute name becomes part of the parameter key path when the
        # wrapped object is an nn.Module, so it must stay "diffusion_model".
        self.diffusion_model = diffusion_model

    def forward(self, *args, **kwargs):
        """Forward every positional and keyword argument to the wrapped network."""
        denoiser = self.diffusion_model
        return denoiser(*args, **kwargs)
25
+
26
+
27
class LatentDiffusionInterface(nn.Module):
    """A simple interface class for LDM inference.

    Holds the three sub-models instantiated from ``unet_config``,
    ``clip_config`` and ``vae_config``, the precomputed noise-schedule
    buffers, and the closed-form conversions between ``x_0``, ``eps`` and
    ``v`` that are built on those buffers.
    """

    def __init__(
        self,
        unet_config,
        clip_config,
        vae_config,
        parameterization="eps",
        scale_factor=0.18215,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=0.00085,
        linear_end=0.0120,
        cosine_s=8e-3,
        given_betas=None,
        zero_snr=False,
        *args,
        **kwargs,
    ):
        """Instantiate the sub-models and register the noise schedule.

        Args:
            unet_config: Config handed to ``instantiate_from_config``; the
                result is wrapped in ``DiffusionWrapper`` as ``self.model``.
            clip_config: Config for ``self.clip_model``.
            vae_config: Config for ``self.vae_model``.
            parameterization: Stored on ``self.parameterization``; this class
                does not read it itself.
            scale_factor: Multiplier applied to latents in
                ``get_first_stage_encoding`` and undone in
                ``decode_first_stage``.
            beta_schedule, timesteps, linear_start, linear_end, cosine_s,
            given_betas, zero_snr: Forwarded unchanged to
                ``register_schedule``.
            *args, **kwargs: Accepted and ignored.
        """
        super().__init__()

        unet = instantiate_from_config(unet_config)
        self.model = DiffusionWrapper(unet)
        self.clip_model = instantiate_from_config(clip_config)
        self.vae_model = instantiate_from_config(vae_config)

        self.parameterization = parameterization
        self.scale_factor = scale_factor
        # Always passes the schedule bounds explicitly, so the (different)
        # defaults declared on register_schedule are not used from here.
        self.register_schedule(
            given_betas=given_betas,
            beta_schedule=beta_schedule,
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
            zero_snr=zero_snr,
        )

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        zero_snr=False,
    ):
        """Compute the diffusion schedule and register it as float32 buffers.

        Args:
            given_betas: Optional explicit 1-D beta sequence (array or plain
                sequence). When provided, ``beta_schedule``, ``timesteps``,
                ``linear_start``, ``linear_end`` and ``cosine_s`` are not used
                to build the betas.
            beta_schedule, timesteps, linear_start, linear_end, cosine_s:
                Passed to ``make_beta_schedule`` when ``given_betas`` is None.
            zero_snr: When true, the betas are passed through
                ``enforce_zero_terminal_snr`` and a notice is printed.

        Sets ``self.num_timesteps`` (taken from the length of the betas),
        ``self.linear_start``, ``self.linear_end``, ``self.v_posterior`` and
        the schedule buffers registered below.

        Raises:
            ValueError: If the betas are not one-dimensional.
        """
        if given_betas is not None:
            betas = given_betas
        else:
            betas = make_beta_schedule(
                beta_schedule,
                timesteps,
                linear_start=linear_start,
                linear_end=linear_end,
                cosine_s=cosine_s,
            )
        if zero_snr:
            print("--- using zero snr---")
            betas = enforce_zero_terminal_snr(betas).numpy()
        # Normalise to an ndarray so a plain list/tuple of betas works with
        # the arithmetic and the ``.shape`` unpacking below.
        betas = np.asarray(betas)

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )

        # A cumulative alpha of exactly 0 (e.g. the last step of a zero
        # terminal SNR schedule) would divide by zero below. Clamp it to a
        # small floor rather than adding the floor to every entry: entries
        # above the floor keep their exact value, and 1/a - 1 cannot turn
        # negative (NaN under sqrt) for entries very close to 1.
        eps = 1e-8
        safe_alphas_cumprod = np.maximum(alphas_cumprod, eps)
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / safe_alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / safe_alphas_cumprod - 1)),
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.v_posterior = 0
        posterior_variance = (1 - self.v_posterior) * betas * (
            1.0 - alphas_cumprod_prev
        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0
        # at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
        )
        self.register_buffer(
            "posterior_mean_coef1",
            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            to_torch(
                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
            ),
        )

    # NOTE(review): the methods below rely on ``extract_into_tensor`` from the
    # project util module; it presumably picks the schedule value at index
    # ``t`` and shapes it to broadcast against the given shape -- confirm there.

    def q_sample(self, x_start, t, noise=None):
        """Return ``sqrt(acp_t) * x_start + sqrt(1 - acp_t) * noise``.

        ``noise`` falls back to ``torch.randn_like(x_start)`` via ``default``.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def get_v(self, x, noise, t):
        """Return ``sqrt(acp_t) * noise - sqrt(1 - acp_t) * x``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def predict_start_from_noise(self, x_t, t, noise):
        """Return ``sqrt(1 / acp_t) * x_t - sqrt(1 / acp_t - 1) * noise``."""
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
            * noise
        )

    def predict_start_from_z_and_v(self, x_t, t, v):
        """Return ``sqrt(acp_t) * x_t - sqrt(1 - acp_t) * v``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def predict_eps_from_z_and_v(self, x_t, t, v):
        """Return ``sqrt(acp_t) * v + sqrt(1 - acp_t) * x_t``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
            * x_t
        )

    def apply_model(self, x_noisy, t, cond, **kwargs):
        """Call the wrapped model with ``cond`` expanded into keyword arguments.

        Raises:
            AssertionError: If ``cond`` is not a dict.
        """
        assert isinstance(cond, dict), "cond has to be a dictionary"
        return self.model(x_noisy, t, **cond, **kwargs)

    def get_learned_conditioning(self, prompts: List[str]):
        """Return ``self.clip_model(prompts)``."""
        return self.clip_model(prompts)

    def get_learned_image_conditioning(self, images):
        """Return ``self.clip_model.forward_image(images)``."""
        return self.clip_model.forward_image(images)

    def get_first_stage_encoding(self, encoder_posterior):
        """Turn an encoder output into a latent multiplied by ``scale_factor``.

        A ``DiagonalGaussianDistribution`` is sampled; a tensor is used as is.

        Raises:
            NotImplementedError: For any other input type.
        """
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z

    def encode_first_stage(self, x):
        """Return ``self.vae_model.encode(x)``; no scaling is applied here."""
        return self.vae_model.encode(x)

    def decode_first_stage(self, z):
        """Undo the ``scale_factor`` multiplication, then decode with the VAE."""
        z = 1.0 / self.scale_factor * z
        return self.vae_model.decode(z)