lcipolina committed
Commit 1a593a8
1 Parent(s): e87036f

Upload notebooks/text2im.ipynb

Files changed (1)
  1. notebooks/text2im.ipynb +251 -0
notebooks/text2im.ipynb ADDED
@@ -0,0 +1,251 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Run this line in Colab to install the package if it is\n",
+ "# not already installed.\n",
+ "!pip install git+https://github.com/openai/glide-text2im"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from PIL import Image\n",
+ "from IPython.display import display\n",
+ "import torch as th\n",
+ "\n",
+ "from glide_text2im.download import load_checkpoint\n",
+ "from glide_text2im.model_creation import (\n",
+ " create_model_and_diffusion,\n",
+ " model_and_diffusion_defaults,\n",
+ " model_and_diffusion_defaults_upsampler\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# This notebook supports both CPU and GPU.\n",
+ "# On CPU, generating one sample may take on the order of 20 minutes.\n",
+ "# On a GPU, it should be under a minute.\n",
+ "\n",
+ "has_cuda = th.cuda.is_available()\n",
+ "device = th.device('cpu' if not has_cuda else 'cuda')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create base model.\n",
+ "options = model_and_diffusion_defaults()\n",
+ "options['use_fp16'] = has_cuda\n",
+ "options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling\n",
+ "model, diffusion = create_model_and_diffusion(**options)\n",
+ "model.eval()\n",
+ "if has_cuda:\n",
+ " model.convert_to_fp16()\n",
+ "model.to(device)\n",
+ "model.load_state_dict(load_checkpoint('base', device))\n",
+ "print('total base parameters', sum(x.numel() for x in model.parameters()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create upsampler model.\n",
+ "options_up = model_and_diffusion_defaults_upsampler()\n",
+ "options_up['use_fp16'] = has_cuda\n",
+ "options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling\n",
+ "model_up, diffusion_up = create_model_and_diffusion(**options_up)\n",
+ "model_up.eval()\n",
+ "if has_cuda:\n",
+ " model_up.convert_to_fp16()\n",
+ "model_up.to(device)\n",
+ "model_up.load_state_dict(load_checkpoint('upsample', device))\n",
+ "print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def show_images(batch: th.Tensor):\n",
+ " \"\"\" Display a batch of images inline. \"\"\"\n",
+ " scaled = ((batch + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()\n",
+ " reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])\n",
+ " display(Image.fromarray(reshaped.numpy()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sampling parameters\n",
+ "prompt = \"an oil painting of a corgi\"\n",
+ "batch_size = 1\n",
+ "guidance_scale = 3.0\n",
+ "\n",
+ "# Tune this parameter to control the sharpness of 256x256 images.\n",
+ "# A value of 1.0 is sharper, but sometimes results in grainy artifacts.\n",
+ "upsample_temp = 0.997"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "##############################\n",
+ "# Sample from the base model #\n",
+ "##############################\n",
+ "\n",
+ "# Create the text tokens to feed to the model.\n",
+ "tokens = model.tokenizer.encode(prompt)\n",
+ "tokens, mask = model.tokenizer.padded_tokens_and_mask(\n",
+ " tokens, options['text_ctx']\n",
+ ")\n",
+ "\n",
+ "# Create the classifier-free guidance tokens (empty)\n",
+ "full_batch_size = batch_size * 2\n",
+ "uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(\n",
+ " [], options['text_ctx']\n",
+ ")\n",
+ "\n",
+ "# Pack the tokens together into model kwargs.\n",
+ "model_kwargs = dict(\n",
+ " tokens=th.tensor(\n",
+ " [tokens] * batch_size + [uncond_tokens] * batch_size, device=device\n",
+ " ),\n",
+ " mask=th.tensor(\n",
+ " [mask] * batch_size + [uncond_mask] * batch_size,\n",
+ " dtype=th.bool,\n",
+ " device=device,\n",
+ " ),\n",
+ ")\n",
+ "\n",
+ "# Create a classifier-free guidance sampling function\n",
+ "def model_fn(x_t, ts, **kwargs):\n",
+ " half = x_t[: len(x_t) // 2]\n",
+ " combined = th.cat([half, half], dim=0)\n",
+ " model_out = model(combined, ts, **kwargs)\n",
+ " eps, rest = model_out[:, :3], model_out[:, 3:]\n",
+ " cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)\n",
+ " half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)\n",
+ " eps = th.cat([half_eps, half_eps], dim=0)\n",
+ " return th.cat([eps, rest], dim=1)\n",
+ "\n",
+ "# Sample from the base model.\n",
+ "model.del_cache()\n",
+ "samples = diffusion.p_sample_loop(\n",
+ " model_fn,\n",
+ " (full_batch_size, 3, options[\"image_size\"], options[\"image_size\"]),\n",
+ " device=device,\n",
+ " clip_denoised=True,\n",
+ " progress=True,\n",
+ " model_kwargs=model_kwargs,\n",
+ " cond_fn=None,\n",
+ ")[:batch_size]\n",
+ "model.del_cache()\n",
+ "\n",
+ "# Show the output\n",
+ "show_images(samples)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "##############################\n",
+ "# Upsample the 64x64 samples #\n",
+ "##############################\n",
+ "\n",
+ "tokens = model_up.tokenizer.encode(prompt)\n",
+ "tokens, mask = model_up.tokenizer.padded_tokens_and_mask(\n",
+ " tokens, options_up['text_ctx']\n",
+ ")\n",
+ "\n",
+ "# Create the model conditioning dict.\n",
+ "model_kwargs = dict(\n",
+ " # Low-res image to upsample.\n",
+ " low_res=((samples+1)*127.5).round()/127.5 - 1,\n",
+ "\n",
+ " # Text tokens\n",
+ " tokens=th.tensor(\n",
+ " [tokens] * batch_size, device=device\n",
+ " ),\n",
+ " mask=th.tensor(\n",
+ " [mask] * batch_size,\n",
+ " dtype=th.bool,\n",
+ " device=device,\n",
+ " ),\n",
+ ")\n",
+ "\n",
+ "# Sample from the upsampler model.\n",
+ "model_up.del_cache()\n",
+ "up_shape = (batch_size, 3, options_up[\"image_size\"], options_up[\"image_size\"])\n",
+ "up_samples = diffusion_up.ddim_sample_loop(\n",
+ " model_up,\n",
+ " up_shape,\n",
+ " noise=th.randn(up_shape, device=device) * upsample_temp,\n",
+ " device=device,\n",
+ " clip_denoised=True,\n",
+ " progress=True,\n",
+ " model_kwargs=model_kwargs,\n",
+ " cond_fn=None,\n",
+ ")[:batch_size]\n",
+ "model_up.del_cache()\n",
+ "\n",
+ "# Show the output\n",
+ "show_images(up_samples)"
+ ]
+ }
+ ],
+ "metadata": {
+ "interpreter": {
+ "hash": "e7d6e62d90e7e85f9a0faa7f0b1d576302d7ae6108e9fe361594f8e1c8b05781"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.3"
+ },
+ "accelerator": "GPU"
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }