lcipolina committed on
Commit
dc419d1
1 Parent(s): 25c2b33

Upload glide_text2im/model_creation.py

Browse files
Files changed (1) hide show
  1. glide_text2im/model_creation.py +195 -0
glide_text2im/model_creation.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from glide_text2im.gaussian_diffusion import get_named_beta_schedule
2
+ from glide_text2im.respace import SpacedDiffusion, space_timesteps
3
+ from glide_text2im.text2im_model import (
4
+ InpaintText2ImUNet,
5
+ SuperResInpaintText2ImUnet,
6
+ SuperResText2ImUNet,
7
+ Text2ImUNet,
8
+ )
9
+ from glide_text2im.tokenizer.bpe import get_encoder
10
+
11
+
12
def model_and_diffusion_defaults():
    """
    Return the default keyword arguments for the base model and diffusion.

    The keys match the parameter names of ``create_model_and_diffusion``, so
    the result can be passed to it as ``**kwargs``. A new dict is built on
    every call, so callers may modify the returned dict freely.
    """
    return dict(
        image_size=64,
        num_channels=192,
        num_res_blocks=3,
        # Empty string means "derive from image_size" in create_model().
        channel_mult="",
        num_heads=1,
        num_head_channels=64,
        num_heads_upsample=-1,
        # Comma-separated list, parsed by create_model().
        attention_resolutions="32,16,8",
        dropout=0.1,
        text_ctx=128,
        xf_width=512,
        xf_layers=16,
        xf_heads=8,
        xf_final_ln=True,
        xf_padding=True,
        diffusion_steps=1000,
        noise_schedule="squaredcos_cap_v2",
        # Empty string makes create_gaussian_diffusion() use every step.
        timestep_respacing="",
        use_scale_shift_norm=True,
        resblock_updown=True,
        use_fp16=True,
        cache_text_emb=False,
        inpaint=False,
        super_res=False,
    )
39
+
40
+
41
def model_and_diffusion_defaults_upsampler():
    """
    Return the default keyword arguments for the upsampler configuration.

    Starts from ``model_and_diffusion_defaults()`` and overrides
    ``image_size``, ``num_res_blocks``, ``noise_schedule`` and ``super_res``.
    All other keys keep their base-model defaults.
    """
    result = model_and_diffusion_defaults()
    result.update(
        dict(
            image_size=256,
            num_res_blocks=2,
            noise_schedule="linear",
            super_res=True,
        )
    )
    return result
52
+
53
+
54
def create_model_and_diffusion(
    image_size,
    num_channels,
    num_res_blocks,
    channel_mult,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    text_ctx,
    xf_width,
    xf_layers,
    xf_heads,
    xf_final_ln,
    xf_padding,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_scale_shift_norm,
    resblock_updown,
    use_fp16,
    cache_text_emb,
    inpaint,
    super_res,
):
    """
    Build a model and its matching diffusion object from one set of options.

    The model-related arguments are forwarded to ``create_model``; the three
    diffusion-related ones (``diffusion_steps``, ``noise_schedule``,
    ``timestep_respacing``) are forwarded to ``create_gaussian_diffusion``.

    :return: a ``(model, diffusion)`` tuple.
    """
    model = create_model(
        image_size,
        num_channels,
        num_res_blocks,
        channel_mult=channel_mult,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
        text_ctx=text_ctx,
        xf_width=xf_width,
        xf_layers=xf_layers,
        xf_heads=xf_heads,
        xf_final_ln=xf_final_ln,
        xf_padding=xf_padding,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
        cache_text_emb=cache_text_emb,
        inpaint=inpaint,
        super_res=super_res,
    )
    diffusion = create_gaussian_diffusion(
        steps=diffusion_steps,
        noise_schedule=noise_schedule,
        timestep_respacing=timestep_respacing,
    )
    return model, diffusion
109
+
110
+
111
def create_model(
    image_size,
    num_channels,
    num_res_blocks,
    channel_mult,
    attention_resolutions,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    use_scale_shift_norm,
    dropout,
    text_ctx,
    xf_width,
    xf_layers,
    xf_heads,
    xf_final_ln,
    xf_padding,
    resblock_updown,
    use_fp16,
    cache_text_emb,
    inpaint,
    super_res,
):
    """
    Instantiate one of the text-conditional UNet classes.

    :param image_size: image resolution; selects the built-in channel
                       multipliers when ``channel_mult`` is empty and is
                       used to convert attention resolutions.
    :param channel_mult: either ``""`` (use the built-in table for
                         ``image_size`` 256, 128 or 64) or a comma-separated
                         string of integers.
    :param attention_resolutions: comma-separated string of integers; each
                                  value ``res`` is converted to
                                  ``image_size // res`` before being handed
                                  to the model.
    :param inpaint: together with ``super_res``, selects the model class.
    :param super_res: together with ``inpaint``, selects the model class.
    :return: an instance of the selected model class. The remaining
             arguments are passed through to its constructor unchanged.
    :raises ValueError: if ``channel_mult`` is empty and ``image_size`` has
                        no built-in table.
    :raises AssertionError: if an explicit ``channel_mult`` does not satisfy
                            ``2 ** (len(channel_mult) + 2) == image_size``.
    """
    if channel_mult == "":
        if image_size == 256:
            channel_mult = (1, 1, 2, 2, 4, 4)
        elif image_size == 128:
            channel_mult = (1, 1, 2, 3, 4)
        elif image_size == 64:
            channel_mult = (1, 2, 3, 4)
        else:
            raise ValueError(f"unsupported image size: {image_size}")
    else:
        channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
        # Every built-in table above already satisfies this relation, so
        # only user-supplied multipliers need checking. Kept as an assert
        # so callers see the same exception type; note that it is skipped
        # when Python runs with -O.
        assert 2 ** (len(channel_mult) + 2) == image_size

    # The model receives downsample rates, not resolutions.
    attention_ds = tuple(
        image_size // int(res) for res in attention_resolutions.split(",")
    )

    if inpaint and super_res:
        model_cls = SuperResInpaintText2ImUnet
    elif inpaint:
        model_cls = InpaintText2ImUNet
    elif super_res:
        model_cls = SuperResText2ImUNet
    else:
        model_cls = Text2ImUNet
    return model_cls(
        text_ctx=text_ctx,
        xf_width=xf_width,
        xf_layers=xf_layers,
        xf_heads=xf_heads,
        xf_final_ln=xf_final_ln,
        tokenizer=get_encoder(),
        xf_padding=xf_padding,
        in_channels=3,
        model_channels=num_channels,
        out_channels=6,
        num_res_blocks=num_res_blocks,
        attention_resolutions=attention_ds,
        dropout=dropout,
        channel_mult=channel_mult,
        use_fp16=use_fp16,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        cache_text_emb=cache_text_emb,
    )
182
+
183
+
184
def create_gaussian_diffusion(
    steps,
    noise_schedule,
    timestep_respacing,
):
    """
    Build a ``SpacedDiffusion`` for a named beta schedule.

    :param steps: number of diffusion steps, passed to the beta schedule
                  and to ``space_timesteps``.
    :param noise_schedule: schedule name passed to
                           ``get_named_beta_schedule``.
    :param timestep_respacing: respacing specification passed to
                               ``space_timesteps``; any falsy value (such as
                               ``""``) is replaced by ``[steps]``.
    :return: the constructed ``SpacedDiffusion`` object.
    """
    betas = get_named_beta_schedule(noise_schedule, steps)
    if not timestep_respacing:
        timestep_respacing = [steps]
    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
    )