ccdv committed · Commit b3c2a0d · 0 Parent(s)
.gitattributes ADDED
@@ -0,0 +1,28 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,145 @@
1
+ ---
2
+ language: en
3
+ tags:
4
+ - long context
5
+ ---
6
+
7
+ # LSG model
8
+ **Transformers >= 4.18.0**\
9
+ **This model relies on a custom modeling file; you need to add trust_remote_code=True**\
10
+ **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
11
+
12
+ * [Usage](#usage)
13
+ * [Parameters](#parameters)
14
+ * [Sparse selection type](#sparse-selection-type)
15
+ * [Tasks](#tasks)
16
+ * [Training global tokens](#training-global-tokens)
17
+
18
+
19
+ This model is adapted from [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) without additional pretraining yet. It uses the same number of parameters/layers and the same tokenizer.
20
+
21
+ This model can handle long sequences faster and more efficiently than Longformer or BigBird (from Transformers), relying on Local + Sparse + Global attention (LSG).
22
+
23
+
24
+ The model requires sequences whose length is a multiple of the block size. The model is "adaptive" and automatically pads the sequences if needed (adaptive=True in config). It is however recommended to truncate the inputs with the tokenizer (truncation=True) and optionally to pad them to a multiple of the block size (pad_to_multiple_of=...); see the encoding sketch at the end of the Usage section below. \
25
+
26
+
27
+ Encoder-decoder and causal masking are supported but have not been tested extensively.\
28
+ Implemented in PyTorch.
29
+
30
+ ![attn](attn.png)
31
+
32
+ ## Usage
33
+ The model relies on a custom modeling file; you need to add trust_remote_code=True to use it.
34
+
35
+ ```python
36
+ from transformers import AutoModel, AutoTokenizer
37
+
38
+ model = AutoModel.from_pretrained("ccdv/lsg-distilbert-base-uncased-4096", trust_remote_code=True)
39
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-distilbert-base-uncased-4096")
40
+ ```
41
+
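+ Following the recommendation above, here is a minimal sketch (reusing the `model` and `tokenizer` loaded just above; the input text is arbitrary) of encoding a long input with truncation and padding to a multiple of the block size (128 by default):
+
+ ```python
+ import torch
+
+ text = "Replace this with any long input you want to encode. " * 300
+
+ # Truncation caps the input at the model max length (4096);
+ # padding rounds it up to a multiple of the block size
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, pad_to_multiple_of=128)
+
+ with torch.no_grad():
+     outputs = model(**inputs)
+
+ # (batch, sequence length, hidden size); the sequence length is a multiple of 128
+ print(outputs.last_hidden_state.shape)
+ ```
+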
42
+ ## Parameters
43
+ You can change various parameters such as:
44
+ * the number of global tokens (num_global_tokens=1)
45
+ * local block size (block_size=128)
46
+ * sparse block size (sparse_block_size=128)
47
+ * sparsity factor (sparsity_factor=2)
48
+ * see config.json file
49
+
50
+ Default parameters work well in practice. If you are short on memory, reduce the block sizes, increase the sparsity factor, and remove dropout from the attention score matrix.
51
+
52
+ ```python
53
+ model = AutoModel.from_pretrained("ccdv/lsg-distilbert-base-uncased-4096",
54
+ trust_remote_code=True,
55
+ num_global_tokens=16,
56
+ block_size=64,
57
+ sparse_block_size=64,
58
+ sparsity_factor=4,
59
+ attention_probs_dropout_prob=0.0
60
+ )
61
+ ```
62
+
63
+ ## Sparse selection type
64
+
65
+ There are 5 different sparse selection patterns. The best type is task-dependent; a loading example is shown after the list below. \
66
+ Note that for sequences with length < 2*block_size, the type has no effect.
67
+
68
+ * sparsity_type="norm", select highest norm tokens
69
+ * Works best for a small sparsity_factor (2 to 4)
70
+ * Additional parameters:
71
+ * None
72
+ * sparsity_type="pooling", use average pooling to merge tokens
73
+ * Works best for a small sparsity_factor (2 to 4)
74
+ * Additional parameters:
75
+ * None
76
+ * sparsity_type="lsh", use the LSH algorithm to cluster similar tokens
77
+ * Works best for a large sparsity_factor (4+)
78
+ * LSH relies on random projections, so inference may differ slightly with different seeds
79
+ * Additional parameters:
80
+ * lsh_num_pre_rounds=1, pre-merge tokens n times before computing centroids
81
+ * sparsity_type="stride", use a striding mechanism per head
82
+ * Each head will use different tokens strided by sparsity_factor
83
+ * Not recommended if sparsity_factor > num_heads
84
+ * sparsity_type="block_stride", use a striding mechanism per head
85
+ * Each head will use blocks of tokens strided by sparsity_factor
86
+ * Not recommended if sparsity_factor > num_heads
87
+
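+ As mentioned above, a minimal sketch of switching the selection pattern at loading time (here the "lsh" type; any of the 5 types listed above can be passed the same way, together with its parameters):
+
+ ```python
+ from transformers import AutoModel
+
+ model = AutoModel.from_pretrained("ccdv/lsg-distilbert-base-uncased-4096",
+     trust_remote_code=True,
+     sparsity_type="lsh",    # "norm", "pooling", "lsh", "stride" or "block_stride"
+     sparsity_factor=4,      # "lsh" works best with a large factor (4+)
+     lsh_num_pre_rounds=1    # only used by the "lsh" type
+ )
+ ```
+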
88
+
89
+ ## Tasks
90
+ Fill mask example:
91
+ ```python
92
+ from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer
93
+
94
+ model = AutoModelForMaskedLM.from_pretrained("ccdv/lsg-distilbert-base-uncased-4096", trust_remote_code=True)
95
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-distilbert-base-uncased-4096")
96
+
97
+ SENTENCES = ["Paris is the [MASK] of France.", "The goal of life is [MASK]."]
98
+ pipeline = FillMaskPipeline(model, tokenizer)
99
+ output = pipeline(SENTENCES, top_k=1)
100
+
101
+ output = [o[0]["sequence"] for o in output]
102
+ > ['Paris is the capital of France.', 'The goal of life is happiness.']
103
+ ```
104
+
105
+
106
+ Classification example:
107
+ ```python
108
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
109
+
110
+ model = AutoModelForSequenceClassification.from_pretrained("ccdv/lsg-distilbert-base-uncased-4096",
111
+ trust_remote_code=True,
112
+ pool_with_global=True, # pool with a global token instead of first token
113
+ )
114
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-distilbert-base-uncased-4096")
115
+
116
+ SENTENCE = "This is a test for sequence classification. " * 300
117
+ token_ids = tokenizer(
118
+ SENTENCE,
119
+ return_tensors="pt",
120
+ #pad_to_multiple_of=... # Optional
121
+ truncation=True
122
+ )
123
+ output = model(**token_ids)
124
+
125
+ > SequenceClassifierOutput(loss=None, logits=tensor([[-0.3051, -0.1762]], grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)
126
+ ```
127
+
128
+ ## Training global tokens
129
+ To train global tokens and the classification head only:
130
+ ```python
131
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
132
+
133
+ model = AutoModelForSequenceClassification.from_pretrained("ccdv/lsg-distilbert-base-uncased-4096",
134
+ trust_remote_code=True,
135
+ pool_with_global=True, # pool with a global token instead of first token
136
+ num_global_tokens=16
137
+ )
138
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-distilbert-base-uncased-4096")
139
+
140
+ for name, param in model.named_parameters():
141
+     if "global_embeddings" in name or "classifier" in name:
142
+         param.requires_grad = True
143
+     else:
144
+         param.requires_grad = False
145
+ ```
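+
+ A quick, optional sanity check that the loop above froze what it should (only the global embeddings and the classification head remain trainable):
+
+ ```python
+ trainable = [name for name, param in model.named_parameters() if param.requires_grad]
+ print(trainable)  # names of the parameters that will be updated
+ print(sum(p.numel() for p in model.parameters() if p.requires_grad), "trainable parameters")
+ ```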
attn.png ADDED
config.json ADDED
@@ -0,0 +1,42 @@
1
+ {
2
+ "_name_or_path": "ccdv/lsg-distilbert-base-uncased-4096",
3
+ "activation": "gelu",
4
+ "adaptive": true,
5
+ "architectures": [
6
+ "LSGDistilBertForMaskedLM"
7
+ ],
8
+ "attention_dropout": 0.1,
9
+ "auto_map": {
10
+ "AutoConfig": "modeling_lsg_distilbert.LSGDistilBertConfig",
11
+ "AutoModel": "modeling_lsg_distilbert.LSGDistilBertModel",
12
+ "AutoModelForMaskedLM": "modeling_lsg_distilbert.LSGDistilBertForMaskedLM",
13
+ "AutoModelForMultipleChoice": "modeling_lsg_distilbert.LSGDistilBertForMultipleChoice",
14
+ "AutoModelForQuestionAnswering": "modeling_lsg_distilbert.LSGDistilBertForQuestionAnswering",
15
+ "AutoModelForSequenceClassification": "modeling_lsg_distilbert.LSGDistilBertForSequenceClassification",
16
+ "AutoModelForTokenClassification": "modeling_lsg_distilbert.LSGDistilBertForTokenClassification"
17
+ },
18
+ "base_model_prefix": "lsg",
19
+ "block_size": 128,
20
+ "dim": 768,
21
+ "dropout": 0.1,
22
+ "hidden_dim": 3072,
23
+ "initializer_range": 0.02,
24
+ "lsh_num_pre_rounds": 1,
25
+ "max_position_embeddings": 4096,
26
+ "model_type": "distilbert",
27
+ "n_heads": 12,
28
+ "n_layers": 6,
29
+ "num_global_tokens": 1,
30
+ "pad_token_id": 0,
31
+ "pool_with_global": true,
32
+ "qa_dropout": 0.1,
33
+ "seq_classif_dropout": 0.2,
34
+ "sinusoidal_pos_embds": false,
35
+ "sparse_block_size": 128,
36
+ "sparsity_factor": 2,
37
+ "sparsity_type": "norm",
38
+ "tie_weights_": true,
39
+ "torch_dtype": "float32",
40
+ "transformers_version": "4.19.2",
41
+ "vocab_size": 30522
42
+ }
modeling_lsg_distilbert.py ADDED
@@ -0,0 +1,1113 @@
1
+ from logging import warn
2
+ from transformers.models.distilbert.modeling_distilbert import *
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers.models.distilbert.configuration_distilbert import DistilBertConfig
6
+ import sys
7
+
8
+
9
+ AUTO_MAP = {
10
+ "AutoModel": "modeling_lsg_distilbert.LSGDistilBertModel",
11
+ "AutoModelForMaskedLM": "modeling_lsg_distilbert.LSGDistilBertForMaskedLM",
12
+ "AutoModelForMultipleChoice": "modeling_lsg_distilbert.LSGDistilBertForMultipleChoice",
13
+ "AutoModelForQuestionAnswering": "modeling_lsg_distilbert.LSGDistilBertForQuestionAnswering",
14
+ "AutoModelForSequenceClassification": "modeling_lsg_distilbert.LSGDistilBertForSequenceClassification",
15
+ "AutoModelForTokenClassification": "modeling_lsg_distilbert.LSGDistilBertForTokenClassification"
16
+ }
17
+
18
+
19
+ class LSGDistilBertConfig(DistilBertConfig):
20
+
21
+ base_model_prefix = "lsg"
22
+ model_type = "distilbert"
23
+
24
+ def __init__(
25
+ self,
26
+ adaptive=True,
27
+ base_model_prefix="lsg",
28
+ block_size=128,
29
+ lsh_num_pre_rounds=1,
30
+ num_global_tokens=1,
31
+ pool_with_global=True,
32
+ sparse_block_size=128,
33
+ sparsity_factor=2,
34
+ sparsity_type="norm",
35
+ **kwargs
36
+ ):
37
+ """Constructs LSGDistilBertConfig."""
38
+ super().__init__(**kwargs)
39
+
40
+ self.adaptive = adaptive
41
+ self.auto_map = AUTO_MAP
42
+ self.base_model_prefix = base_model_prefix
43
+ self.block_size = block_size
44
+ self.lsh_num_pre_rounds = lsh_num_pre_rounds
45
+ self.num_global_tokens = num_global_tokens
46
+ self.pool_with_global = pool_with_global
47
+ self.sparse_block_size = sparse_block_size
48
+ self.sparsity_factor = sparsity_factor
49
+ self.sparsity_type = sparsity_type
50
+
51
+ if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
52
+ logger.warning(
53
+ "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
54
+ self.sparsity_type = None
55
+
56
+ if self.sparsity_type in ["stride", "block_stride"]:
57
+ if self.sparsity_factor > self.n_heads:
58
+ logger.warning(
59
+ "[WARNING CONFIG]: sparsity_factor > n_heads is not recommended for stride/block_stride sparsity"
60
+ )
61
+
62
+ if self.num_global_tokens < 1:
63
+ logger.warning(
64
+ "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
65
+ )
66
+ self.num_global_tokens = 1
67
+ elif self.num_global_tokens > 512:
68
+ logger.warning(
69
+ "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
70
+ )
71
+ self.num_global_tokens = 512
72
+
73
+ if self.sparsity_factor > 0:
74
+ assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
75
+ assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
76
+
77
+
78
+ class LSGEmbeddings(Embeddings):
79
+
80
+ def __init__(self, config):
81
+
82
+ super().__init__(config)
83
+ self.num_global_tokens = config.num_global_tokens
84
+
85
+ # Hardcoded but partially trained
86
+ self.global_embeddings = nn.Embedding(512, embedding_dim=config.dim, )
87
+
88
+ self.block_size = config.block_size
89
+
90
+ def forward(self, input_ids, inputs_embeds=None):
91
+ """
92
+ Parameters:
93
+ input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.
94
+ Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
95
+ embeddings)
96
+ """
97
+ bs, seq_length = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]
98
+
99
+ # Setting the position-ids to the registered buffer in constructor, it helps
100
+ # when tracing the model without passing position-ids, solves
101
+ # issues similar to issue #5664
102
+ if hasattr(self, "position_ids"):
103
+ position_ids = self.position_ids[:, :seq_length]
104
+ else:
105
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)
106
+ position_ids = position_ids.unsqueeze(0).expand(bs, seq_length) # (bs, max_seq_length)
107
+
108
+ word_embeddings = self.word_embeddings(input_ids) if input_ids is not None else inputs_embeds
109
+ position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim)
110
+ word_embeddings = word_embeddings + position_embeddings # (bs, max_seq_length, dim)
111
+
112
+ #if self.num_global_tokens < 0:
113
+ n, t, d = word_embeddings.size()
114
+
115
+ # Add global_tokens
116
+ indexes = torch.arange(self.num_global_tokens, device=word_embeddings.device).reshape(1, -1)
117
+ global_embeddings = self.global_embeddings(indexes)
118
+ word_embeddings = torch.cat([global_embeddings.expand(n, -1, d), word_embeddings], dim=-2)
119
+
120
+ word_embeddings = self.LayerNorm(word_embeddings) # (bs, max_seq_length, dim)
121
+ word_embeddings = self.dropout(word_embeddings) # (bs, max_seq_length, dim)
122
+ return word_embeddings
123
+
124
+
125
+ class BaseSelfAttention(nn.Module):
126
+
127
+ def init_modules(self, config):
128
+ if config.dim % config.n_heads != 0 and not hasattr(
129
+ config, "embedding_size"
130
+ ):
131
+ raise ValueError(
132
+ "The hidden size (%d) is not a multiple of the number of attention "
133
+ "heads (%d)" % (config.dim, config.n_heads)
134
+ )
135
+
136
+ self.n_heads = config.n_heads
137
+ self.attention_head_size = int(config.dim / config.n_heads)
138
+ self.all_head_size = self.n_heads * self.attention_head_size
139
+
140
+ self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
141
+ self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
142
+ self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
143
+ self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
144
+
145
+ self.dropout = nn.Dropout(config.attention_dropout)
146
+
147
+ def transpose_for_scores(self, x):
148
+ new_x_shape = x.size()[:-1] + (
149
+ self.n_heads,
150
+ self.attention_head_size,
151
+ )
152
+ x = x.view(*new_x_shape)
153
+ return x.permute(0, 2, 1, 3)
154
+
155
+ def reshape_output(self, context_layer):
156
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
157
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
158
+ return context_layer.view(*new_context_layer_shape)
159
+
160
+ def project_QKV(self, hidden_states):
161
+
162
+ query_layer = self.transpose_for_scores(self.q_lin(hidden_states))
163
+ key_layer = self.transpose_for_scores(self.k_lin(hidden_states))
164
+ value_layer = self.transpose_for_scores(self.v_lin(hidden_states))
165
+ return query_layer, key_layer, value_layer
166
+
167
+
168
+ class BaseAttentionProduct(nn.Module):
169
+
170
+ def __init__(self, config):
171
+ """
172
+ Compute attention: softmax(Q @ K.T) @ V
173
+ """
174
+ super().__init__()
175
+ self.dropout = nn.Dropout(config.attention_dropout)
176
+
177
+ def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
178
+
179
+ d = query_layer.shape[-1]
180
+
181
+ # Take the dot product between "query" and "key" to get the raw attention scores.
182
+ attention_scores = query_layer @ key_layer.transpose(-1, -2) / math.sqrt(d)
183
+
184
+ del query_layer
185
+ del key_layer
186
+
187
+ if attention_mask is not None:
188
+ # Apply the attention mask (precomputed for all layers in the model forward() function)
189
+ attention_scores = attention_scores + attention_mask
190
+ del attention_mask
191
+
192
+ # Normalize the attention scores to probabilities.
193
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
194
+
195
+ # This is actually dropping out entire tokens to attend to, which might
196
+ # seem a bit unusual, but is taken from the original Transformer paper.
197
+ context_layer = self.dropout(attention_probs) @ value_layer
198
+
199
+ return context_layer
200
+
201
+
202
+ class CausalAttentionProduct(nn.Module):
203
+
204
+ def __init__(self, config):
205
+ """
206
+ Compute attention: softmax(Q @ K.T) @ V
207
+ """
208
+ super().__init__()
209
+ self.dropout = nn.Dropout(config.attention_dropout)
210
+ self.block_size = config.block_size
211
+
212
+ def forward(self, query_layer, key_layer, value_layer, attention_mask=None, causal_shape=None):
213
+
214
+ d = query_layer.shape[-1]
215
+
216
+ # Take the dot product between "query" and "key" to get the raw attention scores.
217
+ attention_scores = query_layer @ key_layer.transpose(-1, -2) / math.sqrt(d)
218
+
219
+ del query_layer
220
+ del key_layer
221
+
222
+ if attention_mask is not None:
223
+ # Apply the attention mask (precomputed for all layers in the model forward() function)
224
+ attention_scores = attention_scores + attention_mask
225
+
226
+ # Add causal mask
227
+ causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
228
+ causal_mask = torch.tril(torch.ones(*causal_shape, device=attention_mask.device), diagonal=-1).T * (-10000)
229
+ attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
230
+
231
+ del attention_mask
232
+
233
+ # Normalize the attention scores to probabilities.
234
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
235
+
236
+ # This is actually dropping out entire tokens to attend to, which might
237
+ # seem a bit unusual, but is taken from the original Transformer paper.
238
+ context_layer = self.dropout(attention_probs) @ value_layer
239
+
240
+ return context_layer
241
+
242
+
243
+ class LSGAttentionProduct(nn.Module):
244
+
245
+ def __init__(self, config, block_size=None, sparse_block_size=None, sparsity_factor=4, is_causal=False):
246
+ """
247
+ Compute block or overlapping blocks attention products
248
+ """
249
+ super().__init__()
250
+
251
+ self.block_size = block_size
252
+ self.sparse_block_size = sparse_block_size
253
+ self.sparsity_factor = sparsity_factor
254
+ self.is_causal = is_causal
255
+
256
+ if self.block_size is None:
257
+ self.block_size = config.block_size
258
+
259
+ if self.sparse_block_size is None:
260
+ self.sparse_block_size = config.sparse_block_size
261
+
262
+ # Shape of blocks
263
+ self.local_shapes = (self.block_size*3, self.block_size)
264
+ if self.sparse_block_size and self.sparsity_factor > 0:
265
+ self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
266
+
267
+ if is_causal:
268
+ self.attention = CausalAttentionProduct(config)
269
+ else:
270
+ self.attention = BaseAttentionProduct(config)
271
+
272
+ def build_lsg_inputs(self, hidden_states, sparse_hidden_states, global_hidden_states, is_attn_mask=False):
273
+
274
+ # Build local tokens
275
+ local_hidden_states = self.reshape_to_local_block(hidden_states, is_attn_mask)
276
+ del hidden_states
277
+
278
+ # Build sparse tokens
279
+ if sparse_hidden_states is not None:
280
+ sparse_hidden_states = self.reshape_to_sparse_block(sparse_hidden_states, is_attn_mask)
281
+
282
+ return self.cat_global_sparse_local_tokens(global_hidden_states, sparse_hidden_states, local_hidden_states)
283
+
284
+ def forward(
285
+ self,
286
+ query_layer,
287
+ key_layer,
288
+ value_layer,
289
+ attention_mask=None,
290
+ sparse_key=None,
291
+ sparse_value=None,
292
+ sparse_mask=None,
293
+ global_key=None,
294
+ global_value=None,
295
+ global_mask=None
296
+ ):
297
+
298
+ # Input batch, heads, length, dim
299
+ n, h, t, d = query_layer.size()
300
+ n_blocks = t // self.block_size
301
+ assert t % self.block_size == 0
302
+
303
+ key_layer = self.build_lsg_inputs(
304
+ key_layer,
305
+ sparse_key,
306
+ global_key
307
+ )
308
+ del sparse_key
309
+ del global_key
310
+
311
+ value_layer = self.build_lsg_inputs(
312
+ value_layer,
313
+ sparse_value,
314
+ global_value
315
+ )
316
+ del sparse_value
317
+ del global_value
318
+
319
+ attention_mask = self.build_lsg_inputs(
320
+ attention_mask,
321
+ sparse_mask,
322
+ global_mask.transpose(-1, -2),
323
+ is_attn_mask=True
324
+ ).transpose(-1, -2)
325
+ del sparse_mask
326
+ del global_mask
327
+
328
+ # expect (..., t, d) shape
329
+ # Compute attention
330
+ context_layer = self.attention(
331
+ query_layer=self.chunk(query_layer, n_blocks),
332
+ key_layer=key_layer,
333
+ value_layer=value_layer,
334
+ attention_mask=attention_mask
335
+ )
336
+
337
+ return context_layer.reshape(n, h, -1, d)
338
+
339
+ def reshape_to_local_block(self, hidden_states, is_attn_mask=False):
340
+
341
+ size, step = self.local_shapes
342
+ s = (size - step) // 2
343
+
344
+ # Pad before block reshaping
345
+ if is_attn_mask:
346
+ pad_value = -10000
347
+ hidden_states = hidden_states.transpose(-1, -2)
348
+ else:
349
+ pad_value = 0
350
+
351
+ hidden_states = torch.nn.functional.pad(
352
+ hidden_states.transpose(-1, -2),
353
+ pad=(s, s),
354
+ value=pad_value
355
+ ).transpose(-1, -2)
356
+
357
+ # Make blocks
358
+ hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
359
+
360
+ # Skip third block if causal
361
+ if self.is_causal:
362
+ return hidden_states[..., :size*2//3, :]
363
+
364
+ return hidden_states
365
+
366
+ def reshape_to_sparse_block(self, hidden_states, is_attn_mask=False):
367
+
368
+ size, step = self.sparse_shapes
369
+
370
+ # In case of odd case
371
+ odd_offset = (step % 2)
372
+
373
+ # n, h, t, d*2 + 1
374
+ size = size*2
375
+ s = (size - step) // 2 + odd_offset
376
+
377
+ # Pad before block reshaping
378
+ if is_attn_mask:
379
+ pad_value = -10000
380
+ hidden_states = hidden_states.transpose(-1, -2)
381
+ else:
382
+ pad_value = 0
383
+
384
+ hidden_states = torch.nn.functional.pad(
385
+ hidden_states.transpose(-1, -2),
386
+ pad=(s, s),
387
+ value=pad_value
388
+ ).transpose(-1, -2)
389
+
390
+ # Make blocks
391
+ hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
392
+
393
+ # Fix case where block_size == sparsity_factor
394
+ if odd_offset:
395
+ hidden_states = hidden_states[..., :-1, :, :]
396
+
397
+ # Indexes for selection
398
+ u = (size - self.block_size * 3 // self.sparsity_factor) // 2 + odd_offset
399
+ s = self.sparse_block_size
400
+
401
+ # Skip right block if causal
402
+ if self.is_causal:
403
+ return hidden_states[..., u-s:u, :]
404
+
405
+ u_ = u + odd_offset
406
+ return torch.cat([hidden_states[..., u-s:u, :], hidden_states[..., -u_:-u_+s, :]], dim=-2)
407
+
408
+ def cat_global_sparse_local_tokens(self, x_global, x_sparse=None, x_local=None, dim=-2):
409
+
410
+ n, h, b, t, d = x_local.size()
411
+ x_global = x_global.unsqueeze(-3).expand(-1, -1, b, -1, -1)
412
+ if x_sparse is not None:
413
+ return torch.cat([x_global, x_sparse, x_local], dim=dim)
414
+ return torch.cat([x_global, x_local], dim=dim)
415
+
416
+ def chunk(self, x, n_blocks):
417
+
418
+ t, d = x.size()[-2:]
419
+ return x.reshape(*x.size()[:-2], n_blocks, -1, d)
420
+
421
+
422
+ class LSGSelfAttention(BaseSelfAttention):
423
+ '''
424
+ Compute local attention with overlapping blocks
425
+ Use global attention for tokens with highest norm
426
+ '''
427
+ def __init__(self, config):
428
+ super().__init__()
429
+
430
+ self.init_modules(config)
431
+
432
+ self.block_size = config.block_size
433
+ self.sparse_block_size = config.sparse_block_size
434
+ self.num_global_tokens = config.num_global_tokens
435
+ self.sparsity_factor = config.sparsity_factor
436
+ self.is_causal = config.is_decoder
437
+ self.is_decoder = config.is_decoder
438
+
439
+ self.attention = LSGAttentionProduct(
440
+ config,
441
+ block_size=config.block_size,
442
+ sparse_block_size=config.sparse_block_size,
443
+ sparsity_factor=self.sparsity_factor,
444
+ is_causal=self.is_causal
445
+ )
446
+
447
+ if self.is_causal:
448
+ self.causal_attention = CausalAttentionProduct(config)
449
+ self.full_attention = BaseAttentionProduct(config)
450
+
451
+ sparse_functions = {
452
+ "norm": self.get_sparse_tokens_with_norm,
453
+ "pooling": self.get_sparse_tokens_with_pooling,
454
+ "lsh": self.get_sparse_tokens_with_lsh,
455
+ "stride": self.get_sparse_tokens_with_stride,
456
+ "block_stride": self.get_sparse_tokens_with_block_stride,
457
+ }
458
+
459
+ self.sparsity_type = config.sparsity_type
460
+ self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
461
+
462
+ if config.sparsity_type == "lsh":
463
+ self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
464
+
465
+ def get_sparse_tokens_with_norm(self, keys, values, mask):
466
+
467
+ if self.sparsity_factor == 1:
468
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
469
+
470
+ with torch.no_grad():
471
+
472
+ block_size = min(self.block_size, self.sparse_block_size)
473
+ key_norm = keys.detach().norm(dim=-1, keepdim=True)
474
+ key_norm = key_norm * ~mask.transpose(-1, -2).bool()
475
+ key_norm = self.chunk(key_norm, block_size)
476
+
477
+ n, h, b, t, d = key_norm.size()
478
+
479
+ idx = key_norm.argsort(dim=-2)
480
+ del key_norm
481
+ idx += (torch.arange(b, device=keys.device)*t).reshape(1, 1, b, 1, 1)
482
+
483
+ split = (t - block_size // self.sparsity_factor, block_size // self.sparsity_factor)
484
+ sparse_idx = idx.split(split, -2)[-1].reshape(n, h, -1, 1)
485
+
486
+ d = keys.size()[-1]
487
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
488
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
489
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
490
+
491
+ return keys, values, mask
492
+
493
+ def get_sparse_tokens_with_pooling(self, keys, values, mask):
494
+
495
+ if self.sparsity_factor == 1:
496
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
497
+
498
+ keys = self.chunk(keys, self.sparsity_factor)
499
+ values = self.chunk(values, self.sparsity_factor)
500
+
501
+ n, h, b, t, d = keys.size()
502
+ mask = mask.reshape(n, 1, b, 1, t)
503
+ mask = ~mask.transpose(-1, -2).bool()
504
+
505
+ keys = keys * mask
506
+ values = values * mask
507
+
508
+ mask = mask.sum(dim=-2)
509
+ keys = keys.sum(dim=-2) / (mask + 1e-6)
510
+ values = values.sum(dim=-2) / (mask + 1e-6)
511
+
512
+ mask = - (1. - mask.clamp(0, 1)) * 1e4
513
+ return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
514
+
515
+ def get_sparse_tokens_with_stride(self, keys, values, mask):
516
+
517
+ if self.sparsity_factor == 1:
518
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
519
+
520
+ n, h, t, d = keys.size()
521
+ sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
522
+ sparse_idx = sparse_idx.reshape(1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1)
523
+ sparse_idx = sparse_idx.expand(n, h, -1, 1)
524
+
525
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
526
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
527
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
528
+
529
+ return keys, values, mask
530
+
531
+ def get_sparse_tokens_with_block_stride(self, keys, values, mask):
532
+
533
+ if self.sparsity_factor == 1:
534
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
535
+
536
+ n, h, t, d = keys.size()
537
+
538
+ t, b = self.block_size, t // self.block_size
539
+ sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device)
540
+ sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + torch.arange(h, device=keys.device).reshape(1, h, 1, 1, 1) * (t // self.sparsity_factor)
541
+ sparse_idx = (sparse_idx % t)
542
+ sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
543
+ sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
544
+
545
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
546
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
547
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
548
+
549
+ return keys, values, mask
550
+
551
+ def get_sparse_tokens_with_lsh(self, keys, values, mask):
552
+
553
+ if self.sparsity_factor == 1:
554
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
555
+
556
+ block_size = min(self.block_size, self.sparse_block_size)
557
+ keys = self.chunk(keys, block_size)
558
+ values = self.chunk(values, block_size)
559
+
560
+ n, h, b, t, d = keys.size()
561
+ mask = mask.reshape(n, 1, b, 1, t)
562
+ mask = ~mask.transpose(-1, -2).bool()
563
+
564
+ keys = keys * mask
565
+ values = values * mask
566
+ mask = mask.expand(-1, h, -1, -1, -1).float()
567
+
568
+ extra_factor = 1
569
+
570
+ for _ in range(self.lsh_num_pre_rounds):
571
+ keys, values, mask = self.lsh_round(keys, values, mask, t*extra_factor)
572
+
573
+ keys, values, mask = self.lsh_round(keys, values, mask, t//self.sparsity_factor)
574
+ keys /= mask + 1e-8
575
+ values /= mask + 1e-8
576
+
577
+ mask = -10000 * (1. - mask.clamp(0, 1))
578
+
579
+ return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
580
+
581
+ def lsh_round(self, keys, values, mask, output_size):
582
+
583
+ with torch.no_grad():
584
+
585
+ n_hashes = output_size // 2
586
+ n, h, b, t, d = keys.size()
587
+ binary_mask = mask.clamp(0, 1)
588
+
589
+ indexes = (torch.nn.functional.normalize(keys, dim=-1) * binary_mask) @ torch.randn(1, h, 1, d, n_hashes, device=keys.device)
590
+ indexes = torch.cat([indexes, -indexes], dim=-1).argmax(dim=-1, keepdim=True)
591
+
592
+ n, h, b, t, d = keys.size()
593
+
594
+ x_ = torch.zeros(n, h, b, output_size, d, device=keys.device)
595
+ mask_ = torch.zeros(n, h, b, output_size, 1, device=keys.device)
596
+ keys = torch.scatter_add(x_, dim=-2, index=indexes.expand(-1, -1, -1, -1, d), src=keys)
597
+ values = torch.scatter_add(x_, dim=-2, index=indexes.expand(-1, -1, -1, -1, d), src=values)
598
+ mask = torch.scatter_add(mask_, dim=-2, index=indexes, src=mask)
599
+
600
+ return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
601
+
602
+ def forward(
603
+ self,
604
+ hidden_states,
605
+ attention_mask=None,
606
+ head_mask=None,
607
+ encoder_hidden_states=None,
608
+ encoder_attention_mask=None,
609
+ past_key_value=None,
610
+ output_attentions=False,
611
+ ):
612
+
613
+ query_layer = self.q_lin(hidden_states)
614
+
615
+ # If this is instantiated as a cross-attention module, the keys
616
+ # and values come from an encoder; the attention mask needs to be
617
+ # such that the encoder's padding tokens are not attended to.
618
+ is_cross_attention = encoder_hidden_states is not None
619
+
620
+ if is_cross_attention and past_key_value is not None:
621
+ # reuse k,v, cross_attentions
622
+ key_layer = past_key_value[0]
623
+ value_layer = past_key_value[1]
624
+ attention_mask = encoder_attention_mask
625
+ elif is_cross_attention:
626
+ key_layer = self.transpose_for_scores(self.k_lin(encoder_hidden_states))
627
+ value_layer = self.transpose_for_scores(self.v_lin(encoder_hidden_states))
628
+ attention_mask = encoder_attention_mask
629
+ elif past_key_value is not None:
630
+ key_layer = self.transpose_for_scores(self.k_lin(hidden_states))
631
+ value_layer = self.transpose_for_scores(self.v_lin(hidden_states))
632
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
633
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
634
+ else:
635
+ key_layer = self.transpose_for_scores(self.k_lin(hidden_states))
636
+ value_layer = self.transpose_for_scores(self.v_lin(hidden_states))
637
+
638
+ query_layer = self.transpose_for_scores(query_layer)
639
+
640
+ if self.is_decoder:
641
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
642
+ # Further calls to cross_attention layer can then reuse all cross-attention
643
+ # key/value_states (first "if" case)
644
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
645
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
646
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
647
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
648
+ past_key_value = (key_layer, value_layer)
649
+
650
+ if is_cross_attention:
651
+ outputs = self.cross_attention_forward(
652
+ query_layer=query_layer,
653
+ key_layer=key_layer,
654
+ value_layer=value_layer,
655
+ attention_mask=attention_mask,
656
+ output_attentions=output_attentions
657
+ )
658
+ else:
659
+ outputs = self.causal_forward(
660
+ query_layer,
661
+ key_layer,
662
+ value_layer,
663
+ attention_mask=attention_mask,
664
+ output_attentions=output_attentions,
665
+ )
666
+
667
+ outputs = outputs + ((key_layer, value_layer),)
668
+
669
+ else:
670
+ outputs = self.not_causal_forward(
671
+ query_layer,
672
+ key_layer,
673
+ value_layer,
674
+ attention_mask=attention_mask,
675
+ output_attentions=output_attentions
676
+ )
677
+
678
+ #if head_mask is not None:
679
+ # outputs = (outputs[0] * head_mask[:, :, :1, :1], ) + outputs[1:]
680
+ return (self.out_lin(outputs[0]),) + outputs[1:]
681
+
682
+ def causal_forward(
683
+ self,
684
+ query_layer,
685
+ key_layer,
686
+ value_layer,
687
+ attention_mask=None,
688
+ output_attentions=False,
689
+ ):
690
+
691
+ n, h, t, d = key_layer.size()
692
+
693
+ # Cat global mask
694
+ attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)
695
+
696
+ # Split input into global tokens and other tokens
697
+ split = (self.num_global_tokens, t - self.num_global_tokens)
698
+ global_query, query_layer = query_layer.split(split, dim=-2)
699
+
700
+ # Use normal causal attention if local attention covers every token
701
+ if t <= 2 * self.block_size + self.num_global_tokens:
702
+ context_layer = self.causal_attention(
703
+ query_layer=query_layer,
704
+ key_layer=key_layer,
705
+ value_layer=value_layer,
706
+ attention_mask=attention_mask,
707
+ causal_shape=(t - self.num_global_tokens, t - self.num_global_tokens)
708
+ )
709
+
710
+ context_layer = torch.cat([global_query, context_layer], dim=-2)
711
+ return (self.reshape_output(context_layer), )
712
+
713
+ # Split K Q M on global and non global
714
+ global_key, key_layer = key_layer.split(split, dim=-2)
715
+ global_value, value_layer = value_layer.split(split, dim=-2)
716
+ global_mask, attention_mask = attention_mask.split(split, dim=-1)
717
+
718
+ n, h, t, d = key_layer.size()
719
+
720
+ # Get sparse idx
721
+ sparse_key, sparse_value, sparse_mask = (None, None, None)
722
+ if self.sparse_block_size and self.sparsity_factor > 0:
723
+ sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
724
+
725
+ # Expand masks on heads
726
+ attention_mask = attention_mask.expand(-1, h, -1, -1)
727
+ global_mask = global_mask.expand(-1, h, -1, -1)
728
+
729
+ # Compute dot product attention
730
+ context_layer = self.attention(
731
+ query_layer,
732
+ key_layer,
733
+ value_layer,
734
+ attention_mask,
735
+ sparse_key=sparse_key,
736
+ sparse_value=sparse_value,
737
+ sparse_mask=sparse_mask,
738
+ global_key=global_key,
739
+ global_value=global_value,
740
+ global_mask=global_mask
741
+ )
742
+
743
+ # Merge pseudo global (causal) and local-sparse tokens
744
+ context_layer = torch.cat([global_query, context_layer], dim=-2)
745
+ context_layer = self.reshape_output(context_layer)
746
+
747
+ return (context_layer,)
748
+
749
+ def not_causal_forward(
750
+ self,
751
+ query_layer,
752
+ key_layer,
753
+ value_layer,
754
+ attention_mask=None,
755
+ output_attentions=False,
756
+ ):
757
+
758
+ n, h, t, d = query_layer.size()
759
+
760
+ # Cat global mask
761
+ attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)
762
+
763
+ # Use normal attention if local attention covers every token
764
+ if t <= 2 * self.block_size + self.num_global_tokens:
765
+ context_layer = self.full_attention(
766
+ query_layer=query_layer,
767
+ key_layer=key_layer,
768
+ value_layer=value_layer,
769
+ attention_mask=attention_mask
770
+ )
771
+ return (self.reshape_output(context_layer), )
772
+
773
+ # Split input into global tokens and other tokens
774
+ split = (self.num_global_tokens, t - self.num_global_tokens)
775
+ global_query, query_layer = query_layer.split(split, dim=-2)
776
+
777
+ # Get global_attention
778
+ bos = self.full_attention(
779
+ query_layer=global_query,
780
+ key_layer=key_layer,
781
+ value_layer=value_layer,
782
+ attention_mask=attention_mask
783
+ )
784
+
785
+ # Split K Q M on global and non global
786
+ global_key, key_layer = key_layer.split(split, dim=-2)
787
+ global_value, value_layer = value_layer.split(split, dim=-2)
788
+ global_mask, attention_mask = attention_mask.split(split, dim=-1)
789
+
790
+ n, h, t, d = key_layer.size()
791
+
792
+ # Get sparse idx
793
+ sparse_key, sparse_value, sparse_mask = (None, None, None)
794
+
795
+ if self.sparse_block_size and self.sparsity_factor > 0:
796
+ sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
797
+
798
+ # Expand masks on heads
799
+ attention_mask = attention_mask.expand(-1, h, -1, -1)
800
+ global_mask = global_mask.expand(-1, h, -1, -1)
801
+
802
+ # Compute dot product attention
803
+ context_layer = self.attention(
804
+ query_layer,
805
+ key_layer,
806
+ value_layer,
807
+ attention_mask,
808
+ sparse_key=sparse_key,
809
+ sparse_value=sparse_value,
810
+ sparse_mask=sparse_mask,
811
+ global_key=global_key,
812
+ global_value=global_value,
813
+ global_mask=global_mask
814
+ )
815
+
816
+ # Merge global and local-sparse tokens
817
+ context_layer = torch.cat([bos, context_layer], dim=-2)
818
+ context_layer = self.reshape_output(context_layer)
819
+
820
+ return (context_layer,)
821
+
822
+ def cross_attention_forward(
823
+ self,
824
+ query_layer,
825
+ key_layer,
826
+ value_layer,
827
+ attention_mask=None,
828
+ output_attentions=False,
829
+ ):
830
+
831
+ context_layer = self.full_attention(
832
+ query_layer=query_layer,
833
+ key_layer=key_layer,
834
+ value_layer=value_layer,
835
+ attention_mask=attention_mask
836
+ )
837
+ return (self.reshape_output(context_layer), )
838
+
839
+ def chunk(self, x, chunk_size):
840
+
841
+ n, h, t, d = x.size()
842
+ return x.reshape(n, h, -1, chunk_size, d)
843
+
844
+
845
+ class LSGTransformerBlock(nn.Module):
846
+
847
+ def __init__(self, config):
848
+
849
+ nn.Module.__init__(self)
850
+
851
+ assert config.dim % config.n_heads == 0
852
+
853
+ self.attention = LSGSelfAttention(config)
854
+ self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
855
+
856
+ self.ffn = FFN(config)
857
+ self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
858
+
859
+ def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):
860
+ """
861
+ Parameters:
862
+ x: torch.tensor(bs, seq_length, dim)
863
+ attn_mask: torch.tensor(bs, seq_length)
864
+
865
+ Returns:
866
+ sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output:
867
+ torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization.
868
+ """
869
+ # Self-Attention
870
+ sa_output = self.attention(
871
+ hidden_states=x,
872
+ attention_mask=-10000*(1 - attn_mask).unsqueeze(1).unsqueeze(1),
873
+ head_mask=head_mask,
874
+ output_attentions=output_attentions,
875
+ )
876
+ if output_attentions:
877
+ sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
878
+ else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
879
+ assert type(sa_output) == tuple
880
+ sa_output = sa_output[0]
881
+ sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim)
882
+
883
+ # Feed Forward Network
884
+ ffn_output = self.ffn(sa_output) # (bs, seq_length, dim)
885
+ ffn_output = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim)
886
+
887
+ output = (ffn_output,)
888
+ if output_attentions:
889
+ output = (sa_weights,) + output
890
+ return output
891
+
892
+
893
+ class LSGTransformer(Transformer):
894
+
895
+ def __init__(self, config):
896
+
897
+ nn.Module.__init__(self)
898
+
899
+ self.n_layers = config.n_layers
900
+ self.layer = nn.ModuleList([LSGTransformerBlock(config) for _ in range(config.n_layers)])
901
+
902
+
903
+ class LSGDistilBertPreTrainedModel(DistilBertPreTrainedModel):
904
+ """
905
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
906
+ models.
907
+ """
908
+
909
+ config_class = LSGDistilBertConfig
910
+
911
+
912
+ class LSGDistilBertModel(LSGDistilBertPreTrainedModel, DistilBertModel):
913
+
914
+ def __init__(self, config):
915
+
916
+ LSGDistilBertPreTrainedModel.__init__(self, config)
917
+
918
+ self.embeddings = LSGEmbeddings(config) # Embeddings
919
+ self.transformer = LSGTransformer(config) # Encoder
920
+
921
+ assert hasattr(config, "num_global_tokens")
922
+ self.num_global_tokens = config.num_global_tokens
923
+ self.pad_idx = config.pad_token_id
924
+
925
+ assert hasattr(config, "block_size") and hasattr(config, "adaptive")
926
+ self.block_size = config.block_size
927
+ self.adaptive = config.adaptive
928
+ self.pool_with_global = config.pool_with_global
929
+
930
+ # Initialize weights and apply final processing
931
+ self.post_init()
932
+
933
+ def forward(
934
+ self,
935
+ input_ids=None,
936
+ attention_mask=None,
937
+ head_mask=None,
938
+ inputs_embeds=None,
939
+ output_attentions=None,
940
+ output_hidden_states=None,
941
+ return_dict=None,
942
+ ):
943
+
944
+ inputs_ = input_ids if input_ids is not None else inputs_embeds
945
+ n, t = inputs_.size()[:2]
946
+
947
+ if attention_mask is None:
948
+ attention_mask = torch.ones(n, t, device=inputs_.device)
949
+
950
+ b = self.block_size * 2
951
+ pad = t % self.block_size
952
+
953
+ # Check if t is multiple of block_size and pad
954
+ if self.adaptive and t > b and pad > 0:
955
+ pad_length = self.block_size - pad
956
+ if input_ids is not None:
957
+ input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
958
+ else:
959
+ inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
960
+
961
+ attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
962
+
963
+ n, t_ = attention_mask.size()
964
+
965
+ encoder_outputs = self._forward(
966
+ input_ids=input_ids,
967
+ attention_mask=attention_mask,
968
+ head_mask=head_mask,
969
+ inputs_embeds=inputs_embeds,
970
+ output_attentions=output_attentions,
971
+ output_hidden_states=output_hidden_states,
972
+ return_dict=return_dict,
973
+ )
974
+
975
+ context = encoder_outputs[0]
976
+ if self.pool_with_global:
977
+ context[:, self.num_global_tokens] = context[:, 0]
978
+
979
+ diff = t - t_
980
+ n, _, d = context.size()
981
+ context = context[..., self.num_global_tokens:, :]
982
+
983
+ # Adapt sequence to initial shape
984
+ if diff < 0:
985
+ context = context[:, :t]
986
+
987
+ if not return_dict:
988
+ return (context, ) + encoder_outputs[1:]
989
+
990
+ return BaseModelOutput(
991
+ last_hidden_state=context,
992
+ hidden_states=encoder_outputs.hidden_states,
993
+ attentions=encoder_outputs.attentions,
994
+ )
995
+
996
+ def _forward(
997
+ self,
998
+ input_ids=None,
999
+ attention_mask=None,
1000
+ head_mask=None,
1001
+ inputs_embeds=None,
1002
+ output_attentions=None,
1003
+ output_hidden_states=None,
1004
+ return_dict=None,
1005
+ ):
1006
+
1007
+ # Prepare head mask if needed
1008
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1009
+ inputs_embeds = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim)
1010
+ return self.transformer(
1011
+ x=inputs_embeds,
1012
+ attn_mask=attention_mask,
1013
+ head_mask=head_mask,
1014
+ output_attentions=output_attentions,
1015
+ output_hidden_states=output_hidden_states,
1016
+ return_dict=return_dict,
1017
+ )
1018
+
1019
+
1020
+ class LSGDistilBertForMaskedLM(LSGDistilBertPreTrainedModel, DistilBertForMaskedLM):
1021
+
1022
+ def __init__(self, config):
1023
+
1024
+ LSGDistilBertPreTrainedModel.__init__(self, config)
1025
+
1026
+ self.activation = get_activation(config.activation)
1027
+
1028
+ self.distilbert = LSGDistilBertModel(config)
1029
+ self.vocab_transform = nn.Linear(config.dim, config.dim)
1030
+ self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
1031
+ self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
1032
+
1033
+ # Initialize weights and apply final processing
1034
+ self.post_init()
1035
+
1036
+ self.mlm_loss_fct = nn.CrossEntropyLoss()
1037
+
1038
+
1039
+ class LSGDistilBertForSequenceClassification(LSGDistilBertPreTrainedModel, DistilBertForSequenceClassification):
1040
+
1041
+ def __init__(self, config):
1042
+
1043
+ LSGDistilBertPreTrainedModel.__init__(self, config)
1044
+
1045
+ self.num_labels = config.num_labels
1046
+ self.config = config
1047
+
1048
+ self.distilbert = LSGDistilBertModel(config)
1049
+ self.pre_classifier = nn.Linear(config.dim, config.dim)
1050
+ self.classifier = nn.Linear(config.dim, config.num_labels)
1051
+ self.dropout = nn.Dropout(config.seq_classif_dropout)
1052
+
1053
+ # Initialize weights and apply final processing
1054
+ self.post_init()
1055
+
1056
+
1057
+ class LSGDistilBertForQuestionAnswering(LSGDistilBertPreTrainedModel, DistilBertForQuestionAnswering):
1058
+
1059
+ def __init__(self, config):
1060
+
1061
+ LSGDistilBertPreTrainedModel.__init__(self, config)
1062
+
1063
+ self.distilbert = LSGDistilBertModel(config)
1064
+ self.qa_outputs = nn.Linear(config.dim, config.num_labels)
1065
+ assert config.num_labels == 2
1066
+ self.dropout = nn.Dropout(config.qa_dropout)
1067
+
1068
+ # Initialize weights and apply final processing
1069
+ self.post_init()
1070
+
1071
+
1072
+ class LSGDistilBertForTokenClassification(LSGDistilBertPreTrainedModel, DistilBertForTokenClassification):
1073
+
1074
+ def __init__(self, config):
1075
+
1076
+ LSGDistilBertPreTrainedModel.__init__(self, config)
1077
+
1078
+ self.num_labels = config.num_labels
1079
+
1080
+ self.distilbert = LSGDistilBertModel(config)
1081
+ self.dropout = nn.Dropout(config.dropout)
1082
+ self.classifier = nn.Linear(config.dim, config.num_labels)
1083
+
1084
+ # Initialize weights and apply final processing
1085
+ self.post_init()
1086
+
1087
+
1088
+ class LSGDistilBertForMultipleChoice(LSGDistilBertPreTrainedModel, DistilBertForMultipleChoice):
1089
+
1090
+ def __init__(self, config):
1091
+
1092
+ LSGDistilBertPreTrainedModel.__init__(self, config)
1093
+
1094
+ self.distilbert = LSGDistilBertModel(config)
1095
+ self.pre_classifier = nn.Linear(config.dim, config.dim)
1096
+ self.classifier = nn.Linear(config.dim, 1)
1097
+ self.dropout = nn.Dropout(config.seq_classif_dropout)
1098
+
1099
+ # Initialize weights and apply final processing
1100
+ self.post_init()
1101
+
1102
+
1103
+ def str_to_class(classname):
1104
+ return getattr(sys.modules[__name__], classname)
1105
+
1106
+ # Register model in Auto API
1107
+ try:
1108
+ LSGDistilBertConfig.register_for_auto_class()
1109
+ for key, value in AUTO_MAP.items():
1110
+ str_to_class(value.split(".")[-1]).register_for_auto_class(key)
1111
+ except:
1112
+ warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
1113
+ warn("Update to transformers >= 4.17.0 to fix.")
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:000339bb8ceb93f662bf72fc5c4232a62ac88f2ab7ec28b72a728078a7531ba3
3
+ size 282131437
special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 4096, "special_tokens_map_file": null, "name_or_path": "distilbert-base-uncased", "tokenizer_class": "DistilBertTokenizer"}
vocab.txt ADDED
The diff for this file is too large to render. See raw diff