small fix with torch.finfo
modeling_lsg_roberta.py  +92 -163  CHANGED
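Note: the hunks below consistently use `torch.finfo(dtype).min` as the additive attention-mask value. For reference, a minimal standalone sketch of that idiom (illustrative only; `masked_softmax` and its arguments are not names from this repo):

```python
# Standalone sketch (illustrative, not from this repo) of the additive-mask idiom:
# positions to ignore get torch.finfo(dtype).min added to their attention scores,
# so they receive ~0 probability after the softmax.
import torch

def masked_softmax(scores: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
    # scores: (..., q, k) attention logits; keep: broadcastable, 1 = attend, 0 = ignore
    mask_value = torch.finfo(scores.dtype).min
    additive_mask = (1.0 - keep.clamp(0, 1)) * mask_value
    return torch.softmax(scores + additive_mask, dim=-1)

scores = torch.randn(2, 4, 4, dtype=torch.float16)
keep = torch.tensor([1., 1., 1., 0.]).view(1, 1, 4)
probs = masked_softmax(scores, keep)  # the last key position gets ~0 weight in every row
```

Taking the minimum of the score dtype keeps the mask finite in half precision, where a hard-coded large negative constant such as -1e9 would overflow to -inf in float16.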
@@ -55,7 +55,8 @@ class LSGRobertaConfig(RobertaConfig):
 
         if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'],
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
         if self.sparsity_type in ["stride", "block_stride"]:
@@ -71,7 +72,7 @@ class LSGRobertaConfig(RobertaConfig):
             self.num_global_tokens = 1
         elif self.num_global_tokens > 512:
             logger.warning(
-                "[WARNING CONFIG]: num_global_tokens > 512 is not
+                "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
             )
             self.num_global_tokens = 512
 
@@ -79,6 +80,16 @@ class LSGRobertaConfig(RobertaConfig):
         assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
         assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
 
+        if self.mask_first_token and not pool_with_global:
+            logger.warning(
+                "[WARNING CONFIG]: pool_with_global==False is not compatible with mask_first_token==True. Setting pool_with_global to True.")
+            self.pool_with_global = True
+
+        if hasattr(self, "position_embedding_type"):
+            if self.position_embedding_type != "absolute":
+                logger.warning(
+                    "[WARNING CONFIG]: LSG Attention is not compatible with relative positional embedding and will skip its computation. Set position_embedding_type='absolute' to remove this warning.")
+
 
 class BaseSelfAttention(nn.Module):
 
@@ -187,7 +198,7 @@ class CausalAttentionProduct(nn.Module):
                 diagonal=-1
                 )
             causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
-            attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
+            attention_scores[..., -causal_shape[0]:, -causal_shape[1] + 1:] = causal_mask[:, 1:]
 
         del attention_mask
 
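For context, a self-contained sketch of the usual additive causal-mask construction; the hunk above only changes which slice of the precomputed scores the mask is written into (the helper below is hypothetical, not part of this module):

```python
# Sketch of an additive causal mask: strictly upper-triangular entries get the
# dtype minimum, so query i can only attend to keys j <= i after the softmax.
import torch

def causal_additive_mask(t: int, dtype=torch.float32) -> torch.Tensor:
    full = torch.full((t, t), torch.finfo(dtype).min, dtype=dtype)
    return torch.triu(full, diagonal=1)  # zeros on and below the diagonal

scores = torch.randn(1, 1, 5, 5)
probs = torch.softmax(scores + causal_additive_mask(5), dim=-1)
# Row i now assigns ~0 probability to keys j > i.
```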
@@ -436,39 +447,13 @@ class LSGRobertaEmbeddings(RobertaEmbeddings):
         return embeddings
 
 
-class LSGRobertaSelfOutput(RobertaSelfOutput):
-
-    def __init__(self, config):
-        super().__init__(config)
-
-
 class LSGAttention(RobertaAttention):
 
     def __init__(self, config):
 
-
+        super().__init__(config)
 
         self.self = LSGSelfAttention(config)
-        self.output = LSGRobertaSelfOutput(config)
-        self.pruned_heads = set()
-
-
-class LSGRobertaIntermediate(RobertaIntermediate):
-
-    def __init__(self, config):
-        super().__init__(config)
-
-
-class LSGRobertaOutput(RobertaOutput):
-
-    def __init__(self, config):
-        super().__init__(config)
-
-
-class LSGRobertaPooler(RobertaPooler):
-
-    def __init__(self, config):
-        super().__init__(config)
 
 
 class LSGSelfAttention(BaseSelfAttention):
@@ -561,7 +546,8 @@ class LSGSelfAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)
 
-        mask = (1. - mask.clamp(0, 1))
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
     def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -626,7 +612,8 @@ class LSGSelfAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8
 
-        mask = (1. - mask.clamp(0, 1))
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min
 
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
 
@@ -726,9 +713,7 @@ class LSGSelfAttention(BaseSelfAttention):
                 attention_mask=attention_mask,
                 output_attentions=output_attentions
             )
-
-        #if head_mask is not None:
-        # outputs = (outputs[0] * head_mask[:, :, :1, :1], ) + outputs[1:]
+
         return outputs
 
     def causal_forward(
@@ -898,30 +883,87 @@ class LSGRobertaLayer(RobertaLayer):
 
     def __init__(self, config):
 
-
+        super().__init__(config)
 
-        self.chunk_size_feed_forward = config.chunk_size_feed_forward
-        self.seq_len_dim = 1
         self.attention = LSGAttention(config)
-        self.is_decoder = config.is_decoder
-        self.add_cross_attention = config.add_cross_attention
         if self.add_cross_attention:
             assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
             self.crossattention = LSGAttention(config)
-        self.intermediate = LSGRobertaIntermediate(config)
-        self.output = LSGRobertaOutput(config)
 
 
 class LSGRobertaEncoder(RobertaEncoder):
 
     def __init__(self, config):
 
-
+        super().__init__(config)
 
-        self.config = config
         self.layer = nn.ModuleList([LSGRobertaLayer(config) for _ in range(config.num_hidden_layers)])
-        self.gradient_checkpointing = False
 
+        assert hasattr(config, "num_global_tokens")
+        self.num_global_tokens = config.num_global_tokens
+        self.pad_idx = config.pad_token_id
+
+        assert hasattr(config, "block_size") and hasattr(config, "adaptive")
+        self.block_size = config.block_size
+        self.adaptive = config.adaptive
+        self.mask_first_token = config.mask_first_token
+        self.pool_with_global = config.pool_with_global
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+
+        mask_value = torch.finfo(attention_mask.dtype).min
+        n, _, __, t = attention_mask.size()
+
+        if not (self.config.is_decoder and encoder_hidden_states is not None):
+            b = self.block_size * 2
+            pad = t % self.block_size
+
+            # Check if t is multiple of block_size and pad
+            if self.adaptive and t > b and pad > 0:
+                pad_length = self.block_size - pad
+                hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
+                attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=mask_value)
+
+            if self.mask_first_token:
+                attention_mask[..., 0] = mask_value
+
+        encoder_outputs = super().forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+        sequence_output = encoder_outputs[0]
+        if self.pool_with_global:
+            sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
+
+        # Adapt sequence to initial shape
+        sequence_output = sequence_output[..., self.num_global_tokens: t + self.num_global_tokens, :]
+
+        if not return_dict:
+            return (sequence_output, ) + encoder_outputs[1:]
+
+        encoder_outputs.last_hidden_state = sequence_output
+        return encoder_outputs
 
 class LSGRobertaPreTrainedModel(RobertaPreTrainedModel):
     """
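The new `LSGRobertaEncoder.forward` above pads the hidden states and the 4D additive attention mask so the sequence length is a multiple of `block_size`, then trims the output back to the original length. A minimal sketch of that padding step under the same additive-mask convention (`pad_to_block_size` is a hypothetical helper, not part of the file):

```python
# Sketch: pad a (batch, seq, dim) tensor and its 4D additive mask so the sequence
# length becomes a multiple of block_size.
import torch

def pad_to_block_size(hidden_states, attention_mask, block_size):
    t = hidden_states.size(1)
    pad_length = -t % block_size
    if pad_length == 0:
        return hidden_states, attention_mask
    mask_value = torch.finfo(attention_mask.dtype).min
    # F.pad appends along the last dimension, so swap seq to last before padding.
    hidden_states = torch.nn.functional.pad(
        hidden_states.transpose(-1, -2), (0, pad_length), value=0.
    ).transpose(-1, -2)
    # Padded key positions are masked out with the dtype minimum.
    attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=mask_value)
    return hidden_states, attention_mask

h = torch.randn(2, 10, 8)
m = torch.zeros(2, 1, 1, 10)                      # additive mask: 0 = attend
h2, m2 = pad_to_block_size(h, m, block_size=4)    # sequence length becomes 12
```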
@@ -945,23 +987,13 @@ class LSGRobertaModel(LSGRobertaPreTrainedModel, RobertaModel):
     config_class = LSGRobertaConfig
 
 
-    def __init__(self, config, add_pooling_layer=
+    def __init__(self, config, add_pooling_layer=True):
 
         LSGRobertaPreTrainedModel.__init__(self, config)
 
-        assert hasattr(config, "num_global_tokens")
-        self.num_global_tokens = config.num_global_tokens
-        self.pad_idx = config.pad_token_id
-
-        assert hasattr(config, "block_size") and hasattr(config, "adaptive")
-        self.block_size = config.block_size
-        self.adaptive = config.adaptive
-        self.mask_first_token = config.mask_first_token
-        self.pool_with_global = config.pool_with_global
-
         self.embeddings = LSGRobertaEmbeddings(config)
         self.encoder = LSGRobertaEncoder(config)
-        self.pooler =
+        self.pooler = RobertaPooler(config) if add_pooling_layer else None
 
         if config.add_cross_attention:
             logger.warning(
@@ -971,95 +1003,6 @@ class LSGRobertaModel(LSGRobertaPreTrainedModel, RobertaModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None
-    ):
-
-        inputs_ = input_ids if input_ids is not None else inputs_embeds
-        n, t = inputs_.size()[:2]
-
-        if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
-        if self.mask_first_token:
-            attention_mask[:,0] = 0
-
-        b = self.block_size * 2
-        pad = t % self.block_size
-
-        # Check if t is multiple of block_size and pad
-        if self.adaptive and t > b and pad > 0:
-            pad_length = self.block_size - pad
-            if input_ids is not None:
-                input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
-            else:
-                inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
-
-            attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
-
-            if token_type_ids is not None:
-                token_type_ids = torch.nn.functional.pad(token_type_ids, (0, pad_length), value=0)
-            if position_ids is not None:
-                position_ids = torch.nn.functional.pad(position_ids, (0, pad_length), value=0)
-
-        n, t_ = attention_mask.size()
-
-        encoder_outputs = super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict
-        )
-
-        context = encoder_outputs[0]
-        if self.pool_with_global:
-            context[:, self.num_global_tokens] = context[:, 0]
-
-        diff = t - t_
-        n, _, d = context.size()
-        context = context[..., self.num_global_tokens:, :]
-
-        # Adapt sequence to initial shape
-        if diff < 0:
-            context = context[:, :t]
-
-        encoder_outputs.last_hidden_state = context
-        sequence_output = encoder_outputs[0]
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
-
-        if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
-
-        return BaseModelOutputWithPoolingAndCrossAttentions(
-            last_hidden_state=sequence_output,
-            pooler_output=pooled_output,
-            past_key_values=encoder_outputs.past_key_values,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-            cross_attentions=encoder_outputs.cross_attentions,
-        )
-
     def get_extended_attention_mask(self, attention_mask, input_shape, device=None):
 
         # Do not rely on original triangular mask from BERT/RoBERTa for causalLM
@@ -1092,7 +1035,7 @@ class LSGRobertaForCausalLM(LSGRobertaPreTrainedModel, RobertaForCausalLM):
             logger.warning("If you want to use `LSGRobertaLMHeadModel` as a standalone, add `is_decoder=True.`")
 
         self.roberta = LSGRobertaModel(config, add_pooling_layer=False)
-        self.lm_head =
+        self.lm_head = RobertaLMHead(config)
 
         # The LM head weights require special treatment only when they are tied with the word embeddings
         self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
@@ -1122,7 +1065,7 @@ class LSGRobertaForMaskedLM(LSGRobertaPreTrainedModel, RobertaForMaskedLM):
             )
 
         self.roberta = LSGRobertaModel(config, add_pooling_layer=False)
-        self.lm_head =
+        self.lm_head = RobertaLMHead(config)
 
         # The LM head weights require special treatment only when they are tied with the word embeddings
         self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
@@ -1131,13 +1074,6 @@
         self.post_init()
 
 
-class LSGRobertaLMHead(RobertaLMHead):
-    """LSG Head for masked language modeling."""
-
-    def __init__(self, config):
-        super().__init__(config)
-
-
 class LSGRobertaForSequenceClassification(LSGRobertaPreTrainedModel, RobertaForSequenceClassification):
     """
     This class overrides :class:`~transformers.RobertaForSequenceClassification`. Please check the superclass for the
@@ -1154,19 +1090,12 @@ class LSGRobertaForSequenceClassification(LSGRobertaPreTrainedModel, RobertaForSequenceClassification):
         self.config = config
 
         self.roberta = LSGRobertaModel(config, add_pooling_layer=False)
-        self.classifier =
+        self.classifier = RobertaClassificationHead(config)
 
         # Initialize weights and apply final processing
         self.post_init()
 
 
-class LSGRobertaClassificationHead(RobertaClassificationHead):
-    """Head for sentence-level classification tasks."""
-
-    def __init__(self, config):
-        super().__init__(config)
-
-
 class LSGRobertaForMultipleChoice(LSGRobertaPreTrainedModel, RobertaForMultipleChoice):
     """
     This class overrides :class:`~transformers.RobertaForMultipleChoice`. Please check the superclass for the
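The last few hunks drop the pass-through `LSGRobertaLMHead` / `LSGRobertaClassificationHead` (and related) subclasses and instantiate the stock `transformers` modules directly. A hedged sketch of that usage, assuming a recent `transformers` release:

```python
# Sketch: the stock RoBERTa heads are instantiated directly from the config,
# which is what the diff does instead of keeping pass-through LSG* subclasses.
from transformers import RobertaConfig
from transformers.models.roberta.modeling_roberta import (
    RobertaClassificationHead,
    RobertaLMHead,
    RobertaPooler,
)

config = RobertaConfig(hidden_size=64, num_labels=3)
lm_head = RobertaLMHead(config)                 # hidden states -> vocabulary logits
classifier = RobertaClassificationHead(config)  # <s> token -> num_labels logits
pooler = RobertaPooler(config)                  # dense + tanh over the first token
```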