lcipolina committed on
Commit
9741d33
1 Parent(s): 46cc6d4

Upload glide_text2im/clip/model_creation.py

Files changed (1)
  1. glide_text2im/clip/model_creation.py +117 -0
glide_text2im/clip/model_creation.py ADDED
@@ -0,0 +1,117 @@
+import os
+from functools import lru_cache
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import attr
+import numpy as np
+import torch
+import torch.nn as nn
+import yaml
+from glide_text2im.tokenizer.simple_tokenizer import SimpleTokenizer
+
+from .encoders import ImageEncoder, TextEncoder
+
+
+@lru_cache()
+def default_config_path() -> str:
+    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")
+
+
+@attr.s
+class CLIPModel:
+    config: Dict[str, Any] = attr.ib()
+    text_encoder: nn.Module = attr.ib()
+    image_encoder: nn.Module = attr.ib()
+    logit_scale: torch.Tensor = attr.ib()
+    device: torch.device = attr.ib()
+    tokenizer: SimpleTokenizer = attr.ib()
+
+    def encode_prompts(self, prompts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
+        tokens = []
+        lens = []
+        for prompt in prompts:
+            sub_tokens, sub_len = self.tokenizer.padded_tokens_and_len(
+                self.tokenizer.encode(prompt), self.text_encoder.max_text_len
+            )
+            tokens.append(sub_tokens)
+            lens.append(sub_len)
+        return (
+            torch.tensor(tokens).to(dtype=torch.long, device=self.device),
+            torch.tensor(lens).to(dtype=torch.long, device=self.device),
+        )
+
+    def text_embeddings(self, prompts: List[str]) -> torch.Tensor:
+        tokens, lens = self.encode_prompts(prompts)
+        z_t = self.text_encoder(tokens, lens)
+        return z_t / (torch.linalg.norm(z_t, dim=-1, keepdim=True) + 1e-12)
+
+    def image_embeddings(self, images: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        z_i = self.image_encoder((images + 1) * 127.5, t)
+        return z_i / (torch.linalg.norm(z_i, dim=-1, keepdim=True) + 1e-12)
+
+    def cond_fn(self, prompts: List[str], grad_scale: float) -> Callable[..., torch.Tensor]:
+        with torch.no_grad():
+            z_t = self.text_embeddings(prompts)
+
+        def cond_fn(x, t, grad_scale=grad_scale, **kwargs):
+            with torch.enable_grad():
+                x_var = x.detach().requires_grad_(True)
+                z_i = self.image_embeddings(x_var, t)
+                loss = torch.exp(self.logit_scale) * (z_t * z_i).sum()
+                grad = torch.autograd.grad(loss, x_var)[0].detach()
+            return grad * grad_scale
+
+        return cond_fn
+
+
+def create_clip_model(
+    config_path: Optional[str] = None,
+    device: Optional[torch.device] = None,
+    tokenizer: Optional[SimpleTokenizer] = None,
+) -> CLIPModel:
+    if config_path is None:
+        config_path = default_config_path()
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if tokenizer is None:
+        tokenizer = SimpleTokenizer()
+
+    with open(config_path, "r") as f:
+        config = yaml.load(f, Loader=yaml.SafeLoader)
+
+    text_encoder = TextEncoder(
+        n_bpe_vocab=config["n_vocab"],
+        max_text_len=config["max_text_len"],
+        n_embd=config["n_embd"],
+        n_head=config["n_head_text"],
+        n_xf_blocks=config["n_xf_blocks_text"],
+        n_head_state=config["n_head_state_text"],
+        device=device,
+    )
+
+    image_encoder = ImageEncoder(
+        image_size=config["image_size"],
+        patch_size=config["patch_size"],
+        n_embd=config["n_embd"],
+        n_head=config["n_head_image"],
+        n_xf_blocks=config["n_xf_blocks_image"],
+        n_head_state=config["n_head_state_image"],
+        n_timestep=config["n_timesteps"],
+        device=device,
+    )
+
+    logit_scale = torch.tensor(
+        np.log(config["logit_scale"]),
+        dtype=torch.float32,
+        device=device,
+        requires_grad=False,
+    )
+
+    return CLIPModel(
+        config=config,
+        text_encoder=text_encoder,
+        image_encoder=image_encoder,
+        logit_scale=logit_scale,
+        device=device,
+        tokenizer=tokenizer,
+    )
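
For context beyond the committed file, here is a minimal usage sketch of how this module is typically wired into CLIP-guided sampling. The sketch is not part of the commit: the prompt text, guidance scale, 64x64 image size, and the way a diffusion sampler consumes the returned gradient function are assumptions about the surrounding glide_text2im package, and the released encoder weights would still need to be loaded separately.

    import torch

    from glide_text2im.clip.model_creation import create_clip_model

    # Hypothetical usage sketch; values and shapes below are assumptions.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Builds the noise-aware CLIP (text + image encoders) from the config.yaml
    # next to this module. Pretrained weights are not loaded here; they would go
    # into clip_model.text_encoder / clip_model.image_encoder via load_state_dict.
    clip_model = create_clip_model(device=device)

    batch_size = 4
    prompts = ["an oil painting of a corgi"] * batch_size  # one prompt per sample
    guidance_scale = 3.0

    # cond_fn(x, t) returns grad_scale * d/dx [exp(logit_scale) * <z_t, z_i(x, t)>],
    # i.e. a gradient that pushes the noised images x toward the prompt embedding.
    cond_fn = clip_model.cond_fn(prompts, grad_scale=guidance_scale)

    # A diffusion sampler would call this at every step and add the gradient to
    # its predicted mean (classifier-style guidance). Shapes assume image_size=64
    # and inputs in [-1, 1], matching the (images + 1) * 127.5 rescale above.
    x = torch.zeros(batch_size, 3, 64, 64, device=device)        # placeholder noised batch
    t = torch.zeros(batch_size, dtype=torch.long, device=device)  # placeholder timesteps
    grad = cond_fn(x, t)
    assert grad.shape == x.shape

Note that the returned closure captures the prompt embeddings once under torch.no_grad(), so the per-step cost during sampling is a single image-encoder forward and backward pass.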