Mehdi Cherti committed · Commit 7060b15 · 1 parent: a45817e

update models

Files changed:
- clip_encoder.py +21 -0
- run.py +14 -1
clip_encoder.py CHANGED

@@ -62,3 +62,24 @@ class CLIPImageEncoder(nn.Module):
 
 
 
+class OpenCLIPImageEncoder(nn.Module):
+
+    def __init__(self, model="ViT-B/32", pretrained="openai"):
+        super().__init__()
+        model, _, preprocess = open_clip.create_model_and_transforms(model, pretrained=pretrained)
+        self.tokenizer = open_clip.get_tokenizer(model)
+        CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
+        CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
+        mean = torch.tensor(CLIP_MEAN).view(1, 3, 1, 1)
+        std = torch.tensor(CLIP_STD).view(1, 3, 1, 1)
+        self.register_buffer("mean", mean)
+        self.register_buffer("std", std)
+
+    def forward_image(self, x):
+        x = torch.nn.functional.interpolate(x, mode='bicubic', size=(224, 224))
+        x = (x-self.mean)/self.std
+        return self.model.encode_image(x)
+
+    def forward_text(self, texts):
+        toks = self.tokenizer.tokenize(texts, truncate=True).to(self.mean.device)
+        return self.model.encode_text(toks)
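Note: as committed, __init__ binds the created model to the local name `model` rather than `self.model`, so forward_image and forward_text would raise AttributeError, and get_tokenizer then receives the model object instead of the model name. open_clip tokenizers are also plain callables, not clip-style objects with a .tokenize(..., truncate=True) method (open_clip truncates to the context length by default). A minimal corrected sketch under those assumptions, not the committed code:

import torch
import torch.nn as nn
import open_clip

class OpenCLIPImageEncoder(nn.Module):

    def __init__(self, model="ViT-B/32", pretrained="openai"):
        super().__init__()
        # keep the created model on self so the forward methods can reach it;
        # the unused preprocess transform is dropped
        self.model, _, _ = open_clip.create_model_and_transforms(model, pretrained=pretrained)
        # get_tokenizer expects the model *name* and returns a callable
        self.tokenizer = open_clip.get_tokenizer(model)
        # CLIP's published normalization statistics, shaped for NCHW batches
        CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
        CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
        self.register_buffer("mean", torch.tensor(CLIP_MEAN).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(CLIP_STD).view(1, 3, 1, 1))

    def forward_image(self, x):
        # resize to the 224x224 resolution ViT-B/32 expects, then normalize
        x = torch.nn.functional.interpolate(x, mode='bicubic', size=(224, 224))
        x = (x - self.mean) / self.std
        return self.model.encode_image(x)

    def forward_text(self, texts):
        # tokenize (truncation to 77 tokens is open_clip's default),
        # move tokens to the same device as the buffers, then encode
        toks = self.tokenizer(texts).to(self.mean.device)
        return self.model.encode_text(toks)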
run.py CHANGED

@@ -237,7 +237,7 @@ def ddgan_laion2b_v2():
     return cfg
 
 def ddgan_ddb_v1():
-    cfg =
+    cfg = ddgan_sd_v10()
     return cfg
 
 def ddgan_sd_v11():
@@ -245,6 +245,17 @@ def ddgan_sd_v11():
     cfg['model']['image_size'] = 512
     return cfg
 
+def ddgan_ddb_v2():
+    cfg = ddgan_ddb_v1()
+    cfg['model']['num_timesteps'] = 1
+    return cfg
+
+def ddgan_ddb_v3():
+    cfg = ddgan_ddb_v1()
+    cfg['model']['num_channels_dae'] = 192
+    cfg['model']['num_timesteps'] = 2
+    return cfg
+
 models = [
     ddgan_cifar10_cond17, # cifar10, cross attn for discr
     ddgan_cifar10_cond18, # cifar10, xl encoder
@@ -286,6 +297,8 @@ models = [
     ddgan_sd_v11,
     ddgan_laion2b_v2,
     ddgan_ddb_v1,
+    ddgan_ddb_v2,
+    ddgan_ddb_v3
 ]
 
 def get_model(model_name):
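Each new config derives from an existing one and overrides a few keys: ddgan_ddb_v2 and ddgan_ddb_v3 both start from ddgan_ddb_v1, which now delegates to ddgan_sd_v10. The body of get_model is outside this diff; a hypothetical sketch of how such a name-to-config registry is typically resolved, with the lookup logic and error message as assumptions:

def get_model(model_name):
    # hypothetical sketch, not the committed implementation:
    # resolve a config-builder in `models` by its function name and call it
    for build in models:
        if build.__name__ == model_name:
            return build()
    raise ValueError(f"unknown model: {model_name}")

Under that sketch, get_model("ddgan_ddb_v3") would return a config with num_channels_dae=192 and num_timesteps=2 layered on top of the ddgan_sd_v10 base.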