Upload notebooks/text2im.ipynb
notebooks/text2im.ipynb ADDED  +251 -0
@@ -0,0 +1,251 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Run this line in Colab to install the package if it is\n",
+    "# not already installed.\n",
+    "!pip install git+https://github.com/openai/glide-text2im"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from PIL import Image\n",
+    "from IPython.display import display\n",
+    "import torch as th\n",
+    "\n",
+    "from glide_text2im.download import load_checkpoint\n",
+    "from glide_text2im.model_creation import (\n",
+    "    create_model_and_diffusion,\n",
+    "    model_and_diffusion_defaults,\n",
+    "    model_and_diffusion_defaults_upsampler\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# This notebook supports both CPU and GPU.\n",
+    "# On CPU, generating one sample may take on the order of 20 minutes.\n",
+    "# On a GPU, it should be under a minute.\n",
+    "\n",
+    "has_cuda = th.cuda.is_available()\n",
+    "device = th.device('cpu' if not has_cuda else 'cuda')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create base model.\n",
+    "options = model_and_diffusion_defaults()\n",
+    "options['use_fp16'] = has_cuda\n",
+    "options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling\n",
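+    "# ('100' respaces the full 1000-step training schedule down to\n",
+    "# 100 evenly spaced steps, trading a little quality for speed.)\n",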
+    "model, diffusion = create_model_and_diffusion(**options)\n",
+    "model.eval()\n",
+    "if has_cuda:\n",
+    "    model.convert_to_fp16()\n",
+    "model.to(device)\n",
+    "model.load_state_dict(load_checkpoint('base', device))\n",
+    "print('total base parameters', sum(x.numel() for x in model.parameters()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create upsampler model.\n",
+    "options_up = model_and_diffusion_defaults_upsampler()\n",
+    "options_up['use_fp16'] = has_cuda\n",
+    "options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling\n",
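+    "# ('fast27' is a preset 27-step schedule; the upsampler is strongly\n",
+    "# conditioned on the low-res image, so few steps usually suffice.)\n",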
+    "model_up, diffusion_up = create_model_and_diffusion(**options_up)\n",
+    "model_up.eval()\n",
+    "if has_cuda:\n",
+    "    model_up.convert_to_fp16()\n",
+    "model_up.to(device)\n",
+    "model_up.load_state_dict(load_checkpoint('upsample', device))\n",
+    "print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def show_images(batch: th.Tensor):\n",
+    "    \"\"\" Display a batch of images inline. \"\"\"\n",
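+    "    # Map [-1, 1] floats to [0, 255] uint8, then tile the batch\n",
+    "    # horizontally: (N, C, H, W) -> (H, N*W, C).\n",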
+    "    scaled = ((batch + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()\n",
+    "    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])\n",
+    "    display(Image.fromarray(reshaped.numpy()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sampling parameters\n",
+    "prompt = \"an oil painting of a corgi\"\n",
+    "batch_size = 1\n",
+    "guidance_scale = 3.0\n",
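+    "# (A scale of 1.0 is the plain conditional model; larger values push\n",
+    "# samples toward the prompt at some cost in diversity.)\n",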
+    "\n",
+    "# Tune this parameter to control the sharpness of 256x256 images.\n",
+    "# A value of 1.0 is sharper, but sometimes results in grainy artifacts.\n",
+    "upsample_temp = 0.997"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "##############################\n",
+    "# Sample from the base model #\n",
+    "##############################\n",
+    "\n",
+    "# Create the text tokens to feed to the model.\n",
+    "tokens = model.tokenizer.encode(prompt)\n",
+    "tokens, mask = model.tokenizer.padded_tokens_and_mask(\n",
+    "    tokens, options['text_ctx']\n",
+    ")\n",
+    "\n",
+    "# Create the classifier-free guidance tokens (empty)\n",
+    "full_batch_size = batch_size * 2\n",
+    "uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(\n",
+    "    [], options['text_ctx']\n",
+    ")\n",
+    "\n",
+    "# Pack the tokens together into model kwargs.\n",
+    "model_kwargs = dict(\n",
+    "    tokens=th.tensor(\n",
+    "        [tokens] * batch_size + [uncond_tokens] * batch_size, device=device\n",
+    "    ),\n",
+    "    mask=th.tensor(\n",
+    "        [mask] * batch_size + [uncond_mask] * batch_size,\n",
+    "        dtype=th.bool,\n",
+    "        device=device,\n",
+    "    ),\n",
+    ")\n",
+    "\n",
+    "# Create a classifier-free guidance sampling function\n",
+    "def model_fn(x_t, ts, **kwargs):\n",
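+    "    # p_sample_loop evolves a [conditional | unconditional] batch whose\n",
+    "    # halves share the same noisy image, so the model runs once on a\n",
+    "    # duplicated first half instead of twice.\n",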
+    "    half = x_t[: len(x_t) // 2]\n",
+    "    combined = th.cat([half, half], dim=0)\n",
+    "    model_out = model(combined, ts, **kwargs)\n",
+    "    eps, rest = model_out[:, :3], model_out[:, 3:]\n",
+    "    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)\n",
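+    "    # Classifier-free guidance: eps = uncond + scale * (cond - uncond).\n",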
+    "    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)\n",
+    "    eps = th.cat([half_eps, half_eps], dim=0)\n",
+    "    return th.cat([eps, rest], dim=1)\n",
+    "\n",
+    "# Sample from the base model.\n",
+    "model.del_cache()\n",
+    "samples = diffusion.p_sample_loop(\n",
+    "    model_fn,\n",
+    "    (full_batch_size, 3, options[\"image_size\"], options[\"image_size\"]),\n",
+    "    device=device,\n",
+    "    clip_denoised=True,\n",
+    "    progress=True,\n",
+    "    model_kwargs=model_kwargs,\n",
+    "    cond_fn=None,\n",
+    ")[:batch_size]\n",
+    "model.del_cache()\n",
+    "\n",
+    "# Show the output\n",
+    "show_images(samples)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "##############################\n",
+    "# Upsample the 64x64 samples #\n",
+    "##############################\n",
+    "\n",
+    "tokens = model_up.tokenizer.encode(prompt)\n",
+    "tokens, mask = model_up.tokenizer.padded_tokens_and_mask(\n",
+    "    tokens, options_up['text_ctx']\n",
+    ")\n",
+    "\n",
+    "# Create the model conditioning dict.\n",
+    "model_kwargs = dict(\n",
+    "    # Low-res image to upsample.\n",
+    "    low_res=((samples+1)*127.5).round()/127.5 - 1,\n",
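+    "    # (Round-tripped through the 8-bit grid so the conditioning matches\n",
+    "    # the quantized images the upsampler saw during training.)\n",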
+    "\n",
+    "    # Text tokens\n",
+    "    tokens=th.tensor(\n",
+    "        [tokens] * batch_size, device=device\n",
+    "    ),\n",
+    "    mask=th.tensor(\n",
+    "        [mask] * batch_size,\n",
+    "        dtype=th.bool,\n",
+    "        device=device,\n",
+    "    ),\n",
+    ")\n",
+    "\n",
+    "# Sample from the upsampler model.\n",
+    "model_up.del_cache()\n",
+    "up_shape = (batch_size, 3, options_up[\"image_size\"], options_up[\"image_size\"])\n",
+    "up_samples = diffusion_up.ddim_sample_loop(\n",
+    "    model_up,\n",
+    "    up_shape,\n",
+    "    noise=th.randn(up_shape, device=device) * upsample_temp,\n",
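+    "    # (Scaling the initial noise below 1.0 lowers the sampling\n",
+    "    # 'temperature', favoring smoother, less grainy upsamples.)\n",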
+    "    device=device,\n",
+    "    clip_denoised=True,\n",
+    "    progress=True,\n",
+    "    model_kwargs=model_kwargs,\n",
+    "    cond_fn=None,\n",
+    ")[:batch_size]\n",
+    "model_up.del_cache()\n",
+    "\n",
+    "# Show the output\n",
+    "show_images(up_samples)"
+   ]
+  },
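+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional: save the upsampled batch to disk. A minimal sketch reusing\n",
+    "# the [-1, 1] -> uint8 mapping from show_images; the filename\n",
+    "# 'glide_sample.png' is an arbitrary choice.\n",
+    "scaled = ((up_samples + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()\n",
+    "reshaped = scaled.permute(2, 0, 3, 1).reshape([up_samples.shape[2], -1, 3])\n",
+    "Image.fromarray(reshaped.numpy()).save('glide_sample.png')"
+   ]
+  }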
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "e7d6e62d90e7e85f9a0faa7f0b1d576302d7ae6108e9fe361594f8e1c8b05781"
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.3"
+  },
+  "accelerator": "GPU"
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}