huu-ontocord committed on
Commit d148ebf · verified · 1 parent: d4bd9ed

Update seed2_tokenizer.py

Files changed (1)
  1. seed2_tokenizer.py +372 -10
seed2_tokenizer.py CHANGED
@@ -20,6 +20,34 @@
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 """
+# Copyright (c) 2024 Black Forest Labs.
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+#
+# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
+#
+# Original file was released under Apache-2.0, with the full license text
+# available at https://github.com/black-forest-labs/flux/blob/main/LICENSE.
+#
+# This modified file is released under the same license.
+
+
+"""
+* Copyright (c) 2023, salesforce.com, inc.
+* All rights reserved.
+* SPDX-License-Identifier: BSD-3-Clause
+* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+* By Junnan Li
+* Based on huggingface code base
+* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+"""
+
+from dataclasses import dataclass
+
+import torch
+from einops import rearrange
+from torch import Tensor, nn
+from safetensors.torch import load_file as load_sft
 
 import torch.nn as nn
 import torch
@@ -77,16 +105,6 @@ from timm.models.registry import register_model
 from timm.models.layers import trunc_normal_, DropPath
 from timm.models.helpers import named_apply, adapt_input_conv
 
-"""
-* Copyright (c) 2023, salesforce.com, inc.
-* All rights reserved.
-* SPDX-License-Identifier: BSD-3-Clause
-* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-* By Junnan Li
-* Based on huggingface code base
-* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
-"""
-
 import math
 import os
 import warnings
@@ -124,6 +142,350 @@ from transformers.modeling_utils import (
 )
 from transformers.models.bert.configuration_bert import BertConfig
 
+
+
+@dataclass
+class AutoEncoderParams:
+    resolution: int
+    in_channels: int
+    downsample: int
+    ch: int
+    out_ch: int
+    ch_mult: list[int]
+    num_res_blocks: int
+    z_channels: int
+    scale_factor: float
+    shift_factor: float
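+    # `ch` sets the base width and `ch_mult` the per-level channel multipliers;
+    # `z_channels` is the latent channel count. `scale_factor` and `shift_factor`
+    # normalize latents in AutoEncoder.encode()/decode(). `resolution` feeds the
+    # Encoder/Decoder bookkeeping (e.g. Decoder.z_shape); `downsample` is stored
+    # here but not consumed by the modules below.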
+
+
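+# swish(x) = x * sigmoid(x) (i.e. SiLU), the nonlinearity used throughout the
+# blocks below.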
+def swish(x: Tensor) -> Tensor:
+    return x * torch.sigmoid(x)
+
+
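+# Single-head self-attention over spatial positions: 1x1 convs produce q/k/v,
+# the (h, w) grid is flattened into a sequence for scaled_dot_product_attention,
+# and forward() adds the projected result back residually.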
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+
+    def attention(self, h_: Tensor) -> Tensor:
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        b, c, h, w = q.shape
+        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
+        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
+        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
+        h_ = nn.functional.scaled_dot_product_attention(q, k, v)
+
+        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x + self.proj_out(self.attention(x))
+
+
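+# Pre-activation residual block: two GroupNorm -> swish -> 3x3 conv stages, with
+# a 1x1 `nin_shortcut` projection on the skip path when the channel count changes.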
+class ResnetBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+
+        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x):
+        h = x
+        h = self.norm1(h)
+        h = swish(h)
+        h = self.conv1(h)
+
+        h = self.norm2(h)
+        h = swish(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            x = self.nin_shortcut(x)
+
+        return x + h
+
+
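+# Stride-2 downsampling. torch's Conv2d only pads symmetrically, so the input is
+# padded by (0, 1, 0, 1) (right/bottom only) before the 3x3 conv, which halves
+# H and W exactly for even inputs.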
+class Downsample(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        # no asymmetric padding in torch conv, must do it ourselves
+        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+    def forward(self, x: Tensor):
+        pad = (0, 1, 0, 1)
+        x = nn.functional.pad(x, pad, mode="constant", value=0)
+        x = self.conv(x)
+        return x
+
+
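+# 2x nearest-neighbor upsampling followed by a 3x3 conv to smooth the result.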
+class Upsample(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x: Tensor):
+        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        x = self.conv(x)
+        return x
+
+
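+# Encoder: maps (B, in_channels, H, W) to (B, 2 * z_channels, H / f, W / f) with
+# f = 2 ** (len(ch_mult) - 1); the doubled channel dim packs mean and logvar for
+# DiagonalGaussian below.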
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        resolution: int,
+        in_channels: int,
+        ch: int,
+        ch_mult: list[int],
+        num_res_blocks: int,
+        z_channels: int,
+    ):
+        super().__init__()
+        self.ch = ch
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        # downsampling
+        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,) + tuple(ch_mult)
+        self.in_ch_mult = in_ch_mult
+        self.down = nn.ModuleList()
+        block_in = self.ch
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
+            for _ in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                block_in = block_out
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions - 1:
+                down.downsample = Downsample(block_in)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+        # end
+        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x: Tensor) -> Tensor:
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1])
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions - 1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h)
+        # end
+        h = self.norm_out(h)
+        h = swish(h)
+        h = self.conv_out(h)
+        return h
+
+
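+# Decoder: mirror of the Encoder, mapping a (B, z_channels, H', W') latent back
+# to a (B, out_ch, H' * ffactor, W' * ffactor) image, where self.ffactor is
+# 2 ** (len(ch_mult) - 1) (8 for the default config).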
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        ch: int,
+        out_ch: int,
+        ch_mult: list[int],
+        num_res_blocks: int,
+        in_channels: int,
+        resolution: int,
+        z_channels: int,
+    ):
+        super().__init__()
+        self.ch = ch
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.ffactor = 2 ** (self.num_resolutions - 1)
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        block_in = ch * ch_mult[self.num_resolutions - 1]
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.z_shape = (1, z_channels, curr_res, curr_res)
+
+        # z to block_in
+        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch * ch_mult[i_level]
+            for _ in range(self.num_res_blocks + 1):
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                block_in = block_out
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, z: Tensor) -> Tensor:
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid.block_1(h)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level].block[i_block](h)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        h = self.norm_out(h)
+        h = swish(h)
+        h = self.conv_out(h)
+        return h
+
+
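+# Splits the encoder output into mean and logvar along `chunk_dim`; when `sample`
+# is True it draws z = mean + std * eps (the VAE reparameterization trick),
+# otherwise it returns the mean deterministically.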
+class DiagonalGaussian(nn.Module):
+    def __init__(self, sample: bool = True, chunk_dim: int = 1):
+        super().__init__()
+        self.sample = sample
+        self.chunk_dim = chunk_dim
+
+    def forward(self, z: Tensor) -> Tensor:
+        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
+        if self.sample:
+            std = torch.exp(0.5 * logvar)
+            return mean + std * torch.randn_like(mean)
+        else:
+            return mean
+
+
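+# Full VAE: encode() samples a latent via DiagonalGaussian and normalizes it with
+# the checkpoint's scale/shift factors; decode() inverts that normalization before
+# running the Decoder, so encode/decode round-trip consistently.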
+class AutoEncoder(nn.Module):
+    def __init__(self, params: AutoEncoderParams):
+        super().__init__()
+        self.encoder = Encoder(
+            resolution=params.resolution,
+            in_channels=params.in_channels,
+            ch=params.ch,
+            ch_mult=params.ch_mult,
+            num_res_blocks=params.num_res_blocks,
+            z_channels=params.z_channels,
+        )
+        self.decoder = Decoder(
+            resolution=params.resolution,
+            in_channels=params.in_channels,
+            ch=params.ch,
+            out_ch=params.out_ch,
+            ch_mult=params.ch_mult,
+            num_res_blocks=params.num_res_blocks,
+            z_channels=params.z_channels,
+        )
+        self.reg = DiagonalGaussian()
+
+        self.scale_factor = params.scale_factor
+        self.shift_factor = params.shift_factor
+
+    def encode(self, x: Tensor) -> Tensor:
+        z = self.reg(self.encoder(x))
+        z = self.scale_factor * (z - self.shift_factor)
+        return z
+
+    def decode(self, z: Tensor) -> Tensor:
+        z = z / self.scale_factor + self.shift_factor
+        return self.decoder(z)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.decode(self.encode(x))
+
+
+def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
+    if len(missing) > 0 and len(unexpected) > 0:
+        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+        print("\n" + "-" * 79 + "\n")
+        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+    elif len(missing) > 0:
+        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+    elif len(unexpected) > 0:
+        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+
+
+def load_ae(local_path: str) -> tuple[AutoEncoder, AutoEncoderParams]:
+    ae_params = AutoEncoderParams(
+        resolution=256,
+        in_channels=3,
+        downsample=8,
+        ch=128,
+        out_ch=3,
+        ch_mult=[1, 2, 4, 4],
+        num_res_blocks=2,
+        z_channels=16,
+        scale_factor=0.3611,
+        shift_factor=0.1159,
+    )
+
+    # Loading the autoencoder
+    ae = AutoEncoder(ae_params)
+
+    if local_path is not None:
+        sd = load_sft(local_path)
+        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
+        print_load_warning(missing, unexpected)
+    return ae, ae_params
+
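+# A minimal usage sketch (the checkpoint path "ae.safetensors" below is
+# illustrative, not a file shipped with this commit):
+#     ae, ae_params = load_ae("ae.safetensors")
+#     img = torch.randn(1, 3, 256, 256)
+#     z = ae.encode(img)    # (1, 16, 32, 32) latent for these default params
+#     recon = ae.decode(z)  # back to (1, 3, 256, 256)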
 #torch.set_printoptions(profile="full")
 
 class DropPathEvaVit(nn.Module):