feat: allow more configurations
src/dalle_mini/model/configuration.py
CHANGED
@@ -58,13 +58,14 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         tie_word_embeddings=False,  # different modalities and sizes
         do_sample=True,
         # transformer variants
-        head_scale=False,  # used in NormFormer
         ln_type="layernorm",  # layer normalization type, "rmsnorm", "layernorm"
-        ln_positions="
+        ln_positions="normformer",  # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "deepnet" (same as postln)
+        head_scale=True,  # used in NormFormer
         use_cosine_attention=False,  # used in Swin v2
         tau_init=0.05,  # used only in cosine attention (Swin v2)
         use_deepnet_scaling=False,  # used in Deepnet
-        use_glu=
+        use_glu=True,  # "GLU Variants Improve Transformer"
+        use_all_scale=True,  # use scale in layernorm even when seemingly unnecessary
         **kwargs,
     ):
         # text normalizer
@@ -83,11 +84,14 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
             "cogview",
             "deepnet",
         ], "ln_positions must be 'normformer', 'swinv2' or 'deepnet'"
+        if ln_positions == "deepnet":
+            ln_positions = "postln"
         self.ln_positions = ln_positions
         self.use_cosine_attention = use_cosine_attention
         self.tau_init = tau_init
         self.use_deepnet_scaling = use_deepnet_scaling
         self.use_glu = use_glu
+        self.use_all_scale = use_all_scale
 
         # common parameters
         self.encoder_vocab_size = encoder_vocab_size
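For reference, here is a minimal sketch of the new knobs from the caller's side. It assumes the import path follows the file layout above and that every constructor argument not shown keeps its default; it is an illustration, not part of the commit.

from dalle_mini.model.configuration import DalleBartConfig

config = DalleBartConfig(
    ln_type="layernorm",    # or "rmsnorm"
    ln_positions="deepnet", # remapped to "postln" by the new branch above
    head_scale=True,        # NormFormer head scaling (new default)
    use_glu=True,           # "GLU Variants Improve Transformer"
    use_all_scale=True,     # keep the layernorm scale even where seemingly unnecessary
)
assert config.ln_positions == "postln"  # "deepnet" is now an alias for post-layernorm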
src/dalle_mini/model/modeling.py
CHANGED
@@ -375,7 +375,10 @@ class GLU(nn.Module):
 
         if self.config.ln_positions in ["normformer", "cogview"]:
             x = norm(
-                self.config.ln_type,
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.use_all_scale,
             )(x)
         w = nn.Dense(
             self.ffn_dim,
@@ -397,7 +400,10 @@ class GLU(nn.Module):
         x = w * v
         if self.config.ln_positions in ["normformer"]:
             x = norm(
-                self.config.ln_type,
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.use_all_scale,
             )(x)
         x = nn.Dropout(rate=self.config.activation_dropout)(
             x, deterministic=deterministic
@@ -434,7 +440,10 @@ class FFN(nn.Module):
         )
         if self.config.ln_positions in ["normformer", "cogview"]:
             x = norm(
-                self.config.ln_type,
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.use_all_scale,
             )(x)
         x = nn.Dense(
             self.ffn_dim,
@@ -447,7 +456,10 @@ class FFN(nn.Module):
         x = ACT2FN[self.config.activation_function](x)
         if self.config.ln_positions in ["normformer"]:
             x = norm(
-                self.config.ln_type,
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.use_all_scale,
             )(x)
         x = nn.Dropout(rate=self.config.activation_dropout)(
             x, deterministic=deterministic
@@ -495,10 +507,13 @@ class FlaxBartEncoderLayer(nn.Module):
 
         embed_dim = self.config.d_model
         residual = hidden_states
-        if self.config.ln_positions in ["normformer"]:
-            hidden_states = norm(
-
-
+        if self.config.ln_positions in ["normformer", "cogview"]:
+            hidden_states = norm(
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.use_all_scale,
+            )(hidden_states)
         hidden_states, attn_weights = FlaxBartAttention(
             config=self.config,
             embed_dim=embed_dim,
@@ -509,7 +524,7 @@ class FlaxBartEncoderLayer(nn.Module):
             is_encoder=True,
         )(hidden_states=hidden_states, attention_mask=attention_mask)
 
-        if self.config.ln_positions in ["normformer", "swinv2"]:
+        if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
             hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                 hidden_states
             )
@@ -517,7 +532,7 @@ class FlaxBartEncoderLayer(nn.Module):
             hidden_states, deterministic=deterministic
         )
         hidden_states = residual * res_gain + hidden_states
-        if self.config.ln_positions in ["
+        if self.config.ln_positions in ["postln"]:
             hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                 hidden_states
             )
@@ -542,8 +557,12 @@ class FlaxBartEncoderLayer(nn.Module):
         )
         hidden_states = ff_block(hidden_states, deterministic=deterministic)
         hidden_states = residual * res_gain + hidden_states
-        if self.add_norm or self.config.ln_positions in ["
-            use_scale =
+        if self.add_norm or self.config.ln_positions in ["postln"]:
+            use_scale = (
+                self.use_scale
+                or self.config.ln_positions == "postln"
+                or self.config.use_all_scale
+            )
             hidden_states = norm(
                 self.config.ln_type,
                 dtype=self.dtype,
@@ -598,7 +617,7 @@ class FlaxBartDecoderLayer(nn.Module):
                 self.config.ln_type,
                 dtype=self.dtype,
                 epsilon=1e-05,
-                use_scale=
+                use_scale=self.config.use_all_scale,
             )(hidden_states)
         hidden_states, attn_weights = FlaxBartAttention(
             config=self.config,
@@ -623,7 +642,7 @@ class FlaxBartDecoderLayer(nn.Module):
             hidden_states, deterministic=deterministic
         )
         hidden_states = residual * res_gain + hidden_states
-        if self.config.ln_positions in ["
+        if self.config.ln_positions in ["postln"]:
             hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                 hidden_states
             )
@@ -637,7 +656,7 @@ class FlaxBartDecoderLayer(nn.Module):
                 self.config.ln_type,
                 dtype=self.dtype,
                 epsilon=1e-05,
-                use_scale=
+                use_scale=self.config.use_all_scale,
             )(hidden_states)
         hidden_states, cross_attn_weights = FlaxBartAttention(
             config=self.config,
@@ -660,7 +679,7 @@ class FlaxBartDecoderLayer(nn.Module):
             hidden_states, deterministic=deterministic
         )
         hidden_states = residual * res_gain + hidden_states
-        if self.config.ln_positions in ["
+        if self.config.ln_positions in ["postln"]:
             hidden_states = norm(
                 self.config.ln_type, dtype=self.dtype, epsilon=1e-05
             )(hidden_states)
@@ -686,8 +705,12 @@ class FlaxBartDecoderLayer(nn.Module):
         )
         hidden_states = ff_block(hidden_states, deterministic=deterministic)
         hidden_states = residual * res_gain + hidden_states
-        if self.add_norm or self.config.ln_positions in ["
-            use_scale =
+        if self.add_norm or self.config.ln_positions in ["postln"]:
+            use_scale = (
+                self.use_scale
+                or self.config.ln_positions == "postln"
+                or self.config.use_all_scale
+            )
             hidden_states = norm(
                 self.config.ln_type,
                 dtype=self.dtype,
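For context on the use_scale=self.config.use_all_scale arguments threaded through the norm(...) calls above, here is a hedged sketch of how such a factory can be written in Flax for the "layernorm" case. It illustrates the pattern only; it is not the repository's exact helper, which also supports a custom "rmsnorm".

import jax
import jax.numpy as jnp
import flax.linen as nn


def norm(ln_type, dtype=jnp.float32, epsilon=1e-05, use_scale=True):
    # Return a normalization module; use_scale toggles the learned gain,
    # which is what config.use_all_scale feeds in the diff above.
    if ln_type == "layernorm":
        return nn.LayerNorm(dtype=dtype, epsilon=epsilon, use_scale=use_scale)
    raise ValueError(f"ln_type not covered by this sketch: {ln_type}")


x = jnp.ones((2, 16, 64))
y, variables = norm("layernorm", use_scale=False).init_with_output(
    jax.random.PRNGKey(0), x
)
# With use_scale=False no scale parameter is created: before a Dense layer the
# gain can be absorbed into the Dense weights, which is why it can look
# unnecessary. Setting use_all_scale=True keeps the parameter anyway.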