Spaces:
Runtime error
Runtime error
Mehdi Cherti
commited on
Commit
·
bc53ac3
1
Parent(s):
c81908d
support cond attn based discriminator
Browse files- pytorch_fid/fid_score.py +1 -1
- score_sde/models/discriminator.py +159 -0
pytorch_fid/fid_score.py
CHANGED
|
@@ -148,7 +148,7 @@ def get_activations(files, model, batch_size=50, dims=2048, device='cpu', resize
|
|
| 148 |
|
| 149 |
for batch in tqdm(dataloader):
|
| 150 |
batch = batch.to(device)
|
| 151 |
-
print(batch.shape, batch.min(), batch.max)
|
| 152 |
with torch.no_grad():
|
| 153 |
pred = model(batch)[0]
|
| 154 |
|
|
|
|
| 148 |
|
| 149 |
for batch in tqdm(dataloader):
|
| 150 |
batch = batch.to(device)
|
| 151 |
+
#print(batch.shape, batch.min(), batch.max)
|
| 152 |
with torch.no_grad():
|
| 153 |
pred = model(batch)[0]
|
| 154 |
|
score_sde/models/discriminator.py
CHANGED
|
@@ -167,6 +167,87 @@ class Discriminator_small(nn.Module):
|
|
| 167 |
|
| 168 |
return out
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
class Discriminator_large(nn.Module):
|
| 172 |
"""A time-dependent discriminator for large images (CelebA, LSUN)."""
|
|
@@ -239,3 +320,81 @@ class Discriminator_large(nn.Module):
|
|
| 239 |
out = self.end_linear(out) + (self.cond_proj(cond) * out).sum(dim=1, keepdim=True)
|
| 240 |
return out
|
| 241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
return out
|
| 169 |
|
| 170 |
+
class SmallCondAttnDiscriminator(nn.Module):
|
| 171 |
+
"""A time-dependent discriminator for small images (CIFAR10, StackMNIST)."""
|
| 172 |
+
|
| 173 |
+
def __init__(self, nc = 3, ngf = 64, t_emb_dim = 128, act=nn.LeakyReLU(0.2), cond_size=768):
|
| 174 |
+
super().__init__()
|
| 175 |
+
# Gaussian random feature embedding layer for time
|
| 176 |
+
self.act = act
|
| 177 |
+
self.cond_attn = layers.CondAttnBlock(ngf*8, cond_size, dim_head=64, heads=8, norm_context=False, cosine_sim_attn=False)
|
| 178 |
+
|
| 179 |
+
self.t_embed = TimestepEmbedding(
|
| 180 |
+
embedding_dim=t_emb_dim,
|
| 181 |
+
hidden_dim=t_emb_dim,
|
| 182 |
+
output_dim=t_emb_dim,
|
| 183 |
+
act=act,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# Encoding layers where the resolution decreases
|
| 189 |
+
self.start_conv = conv2d(nc,ngf*2,1, padding=0)
|
| 190 |
+
self.conv1 = DownConvBlock(ngf*2, ngf*2, t_emb_dim = t_emb_dim,act=act)
|
| 191 |
+
|
| 192 |
+
self.conv2 = DownConvBlock(ngf*2, ngf*4, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
self.conv3 = DownConvBlock(ngf*4, ngf*8, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
self.conv4 = DownConvBlock(ngf*8, ngf*8, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
self.final_conv = conv2d(ngf*8 + 1, ngf*8, 3,padding=1, init_scale=0.)
|
| 202 |
+
self.end_linear = dense(ngf*8, 1)
|
| 203 |
+
self.end_linear_cond = dense(ngf*8, 1)
|
| 204 |
+
#self.gn_cond = nn.GroupNorm(num_groups=32, num_channels=ngf*8, eps=1e-6)
|
| 205 |
+
|
| 206 |
+
self.stddev_group = 4
|
| 207 |
+
self.stddev_feat = 1
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def forward(self, x, t, x_t, cond=None):
|
| 211 |
+
t_embed = self.t_embed(t)
|
| 212 |
+
# if cond is not None:
|
| 213 |
+
# t_embed = t_embed + self.cond_proj(cond)
|
| 214 |
+
t_embed = self.act(t_embed)
|
| 215 |
+
input_x = torch.cat((x, x_t), dim = 1)
|
| 216 |
+
|
| 217 |
+
h0 = self.start_conv(input_x)
|
| 218 |
+
h1 = self.conv1(h0,t_embed)
|
| 219 |
+
|
| 220 |
+
h2 = self.conv2(h1,t_embed)
|
| 221 |
+
|
| 222 |
+
h3 = self.conv3(h2,t_embed)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
out = self.conv4(h3,t_embed)
|
| 226 |
+
|
| 227 |
+
batch, channel, height, width = out.shape
|
| 228 |
+
group = min(batch, self.stddev_group)
|
| 229 |
+
stddev = out.view(
|
| 230 |
+
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
| 231 |
+
)
|
| 232 |
+
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
| 233 |
+
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
| 234 |
+
stddev = stddev.repeat(group, 1, height, width)
|
| 235 |
+
out = torch.cat([out, stddev], 1)
|
| 236 |
+
|
| 237 |
+
out = self.final_conv(out)
|
| 238 |
+
out = self.act(out)
|
| 239 |
+
|
| 240 |
+
cond_pooled, cond, cond_mask = cond
|
| 241 |
+
|
| 242 |
+
out_cond = (self.cond_attn(out, cond, cond_mask))
|
| 243 |
+
|
| 244 |
+
out = out.view(out.shape[0], out.shape[1], -1).mean(2)
|
| 245 |
+
out_cond = out_cond.view(out_cond.shape[0], out_cond.shape[1], -1).mean(2)
|
| 246 |
+
out = self.end_linear(out) + self.end_linear_cond(out_cond)
|
| 247 |
+
return out
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
|
| 251 |
|
| 252 |
class Discriminator_large(nn.Module):
|
| 253 |
"""A time-dependent discriminator for large images (CelebA, LSUN)."""
|
|
|
|
| 320 |
out = self.end_linear(out) + (self.cond_proj(cond) * out).sum(dim=1, keepdim=True)
|
| 321 |
return out
|
| 322 |
|
| 323 |
+
|
| 324 |
+
class CondAttnDiscriminator(nn.Module):
|
| 325 |
+
"""A time-dependent discriminator for large images (CelebA, LSUN)."""
|
| 326 |
+
|
| 327 |
+
def __init__(self, nc = 1, ngf = 32, t_emb_dim = 128, act=nn.LeakyReLU(0.2), cond_size=768):
|
| 328 |
+
super().__init__()
|
| 329 |
+
# Gaussian random feature embedding layer for time
|
| 330 |
+
self.act = act
|
| 331 |
+
self.cond_attn = layers.CondAttnBlock(ngf*8, cond_size, dim_head=64, heads=8, norm_context=False, cosine_sim_attn=False)
|
| 332 |
+
|
| 333 |
+
self.t_embed = TimestepEmbedding(
|
| 334 |
+
embedding_dim=t_emb_dim,
|
| 335 |
+
hidden_dim=t_emb_dim,
|
| 336 |
+
output_dim=t_emb_dim,
|
| 337 |
+
act=act,
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
self.start_conv = conv2d(nc,ngf*2,1, padding=0)
|
| 341 |
+
self.conv1 = DownConvBlock(ngf*2, ngf*4, t_emb_dim = t_emb_dim, downsample = True, act=act)
|
| 342 |
+
|
| 343 |
+
self.conv2 = DownConvBlock(ngf*4, ngf*8, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
| 344 |
+
|
| 345 |
+
self.conv3 = DownConvBlock(ngf*8, ngf*8, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
self.conv4 = DownConvBlock(ngf*8, ngf*8, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
| 349 |
+
self.conv5 = DownConvBlock(ngf*8, ngf*8, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
| 350 |
+
self.conv6 = DownConvBlock(ngf*8, ngf*8, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
self.final_conv = conv2d(ngf*8 + 1, ngf*8, 3,padding=1)
|
| 354 |
+
self.end_linear = dense(ngf*8, 1)
|
| 355 |
+
self.end_linear_cond = dense(ngf*8, 1)
|
| 356 |
+
|
| 357 |
+
self.stddev_group = 4
|
| 358 |
+
self.stddev_feat = 1
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def forward(self, x, t, x_t, cond=None):
|
| 362 |
+
cond_pooled, cond, cond_mask = cond
|
| 363 |
+
|
| 364 |
+
t_embed = self.t_embed(t)
|
| 365 |
+
t_embed = self.act(t_embed)
|
| 366 |
+
|
| 367 |
+
input_x = torch.cat((x, x_t), dim = 1)
|
| 368 |
+
|
| 369 |
+
h = self.start_conv(input_x)
|
| 370 |
+
h = self.conv1(h,t_embed)
|
| 371 |
+
|
| 372 |
+
h = self.conv2(h,t_embed)
|
| 373 |
+
|
| 374 |
+
h = self.conv3(h,t_embed)
|
| 375 |
+
h = self.conv4(h,t_embed)
|
| 376 |
+
h = self.conv5(h,t_embed)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
out = self.conv6(h,t_embed)
|
| 380 |
+
|
| 381 |
+
batch, channel, height, width = out.shape
|
| 382 |
+
group = min(batch, self.stddev_group)
|
| 383 |
+
stddev = out.view(
|
| 384 |
+
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
| 385 |
+
)
|
| 386 |
+
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
| 387 |
+
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
| 388 |
+
stddev = stddev.repeat(group, 1, height, width)
|
| 389 |
+
out = torch.cat([out, stddev], 1)
|
| 390 |
+
|
| 391 |
+
out = self.final_conv(out)
|
| 392 |
+
out = self.act(out)
|
| 393 |
+
|
| 394 |
+
out_cond = self.cond_attn(out, cond, cond_mask)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
out = out.view(out.shape[0], out.shape[1], -1).mean(2)
|
| 398 |
+
out_cond = out_cond.view(out_cond.shape[0], out_cond.shape[1], -1).mean(2)
|
| 399 |
+
out = self.end_linear(out) + self.end_linear_cond(out_cond)
|
| 400 |
+
return out
|