Upload glide_text2im/clip/model_creation.py
glide_text2im/clip/model_creation.py
ADDED
@@ -0,0 +1,117 @@
import os
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple

import attr
import numpy as np
import torch
import torch.nn as nn
import yaml
from glide_text2im.tokenizer.simple_tokenizer import SimpleTokenizer

from .encoders import ImageEncoder, TextEncoder


@lru_cache()
def default_config_path() -> str:
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")


@attr.s
class CLIPModel:
    config: Dict[str, Any] = attr.ib()
    text_encoder: nn.Module = attr.ib()
    image_encoder: nn.Module = attr.ib()
    logit_scale: torch.Tensor = attr.ib()
    device: torch.device = attr.ib()
    tokenizer: SimpleTokenizer = attr.ib()

    def encode_prompts(self, prompts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        tokens = []
        lens = []
        for prompt in prompts:
            sub_tokens, sub_len = self.tokenizer.padded_tokens_and_len(
                self.tokenizer.encode(prompt), self.text_encoder.max_text_len
            )
            tokens.append(sub_tokens)
            lens.append(sub_len)
        return (
            torch.tensor(tokens).to(dtype=torch.long, device=self.device),
            torch.tensor(lens).to(dtype=torch.long, device=self.device),
        )

    def text_embeddings(self, prompts: List[str]) -> torch.Tensor:
        tokens, lens = self.encode_prompts(prompts)
        z_t = self.text_encoder(tokens, lens)
        return z_t / (torch.linalg.norm(z_t, dim=-1, keepdim=True) + 1e-12)

    def image_embeddings(self, images: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        z_i = self.image_encoder((images + 1) * 127.5, t)
        return z_i / (torch.linalg.norm(z_i, dim=-1, keepdim=True) + 1e-12)

    def cond_fn(self, prompts: List[str], grad_scale: float) -> Callable[..., torch.Tensor]:
        with torch.no_grad():
            z_t = self.text_embeddings(prompts)

        def cond_fn(x, t, grad_scale=grad_scale, **kwargs):
            with torch.enable_grad():
                x_var = x.detach().requires_grad_(True)
                z_i = self.image_embeddings(x_var, t)
                loss = torch.exp(self.logit_scale) * (z_t * z_i).sum()
                grad = torch.autograd.grad(loss, x_var)[0].detach()
            return grad * grad_scale

        return cond_fn


def create_clip_model(
    config_path: Optional[str] = None,
    device: Optional[torch.device] = None,
    tokenizer: Optional[SimpleTokenizer] = None,
) -> CLIPModel:
    if config_path is None:
        config_path = default_config_path()
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if tokenizer is None:
        tokenizer = SimpleTokenizer()

    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)

    text_encoder = TextEncoder(
        n_bpe_vocab=config["n_vocab"],
        max_text_len=config["max_text_len"],
        n_embd=config["n_embd"],
        n_head=config["n_head_text"],
        n_xf_blocks=config["n_xf_blocks_text"],
        n_head_state=config["n_head_state_text"],
        device=device,
    )

    image_encoder = ImageEncoder(
        image_size=config["image_size"],
        patch_size=config["patch_size"],
        n_embd=config["n_embd"],
        n_head=config["n_head_image"],
        n_xf_blocks=config["n_xf_blocks_image"],
        n_head_state=config["n_head_state_image"],
        n_timestep=config["n_timesteps"],
        device=device,
    )

    logit_scale = torch.tensor(
        np.log(config["logit_scale"]),
        dtype=torch.float32,
        device=device,
        requires_grad=False,
    )

    return CLIPModel(
        config=config,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        logit_scale=logit_scale,
        device=device,
        tokenizer=tokenizer,
    )
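For anyone adapting this Space, here is a minimal usage sketch; it is not part of the uploaded file. It assumes the pretrained CLIP encoder weights are loaded separately (for example via a checkpoint helper such as glide_text2im.download.load_checkpoint, if available in this environment), and the prompt and grad_scale value are illustrative only.

import torch

from glide_text2im.clip.model_creation import create_clip_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the CLIP text/image encoders from the bundled config.yaml.
clip_model = create_clip_model(device=device)

# Assumption: pretrained weights are loaded into the encoders at this point, e.g.
# clip_model.image_encoder.load_state_dict(load_checkpoint("clip/image-enc", device))
# clip_model.text_encoder.load_state_dict(load_checkpoint("clip/text-enc", device))

# CLIPModel.cond_fn returns a callable cond_fn(x, t, **kwargs) that yields the
# gradient of the scaled text-image similarity with respect to the noised image;
# a guided-diffusion sampler can consume it as its cond_fn argument.
cond_fn = clip_model.cond_fn(["an oil painting of a corgi"], grad_scale=3.0)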