bos_token + readme

- README.md +12 -7
- modeling_lsg_bert.py +39 -12
README.md
CHANGED

```diff
@@ -69,26 +69,31 @@ model = AutoModel.from_pretrained("ccdv/lsg-bert-base-uncased-4096",
 
 ## Sparse selection type
 
-There are 5 different sparse selection patterns. The best type is task dependent. \
+There are 6 different sparse selection patterns. The best type is task dependent. \
+If `sparse_block_size=0` or `sparsity_type="none"`, only local attention is considered. \
 Note that for sequences with length < 2*block_size, the type has no effect.
-
-* sparsity_type="norm"
+* `sparsity_type="bos_pooling"` (new)
+    * Weighted average pooling using the BOS token
+    * Works best in general, especially with a rather large sparsity_factor (8, 16, 32)
+    * Additional parameters:
+        * None
+* `sparsity_type="norm"`, select the highest-norm tokens
     * Works best for a small sparsity_factor (2 to 4)
     * Additional parameters:
         * None
-* sparsity_type="pooling"
+* `sparsity_type="pooling"`, use average pooling to merge tokens
     * Works best for a small sparsity_factor (2 to 4)
     * Additional parameters:
         * None
-* sparsity_type="lsh"
+* `sparsity_type="lsh"`, use the LSH algorithm to cluster similar tokens
     * Works best for a large sparsity_factor (4+)
     * LSH relies on random projections, thus inference may differ slightly with different seeds
     * Additional parameters:
         * lsh_num_pre_rounds=1, pre-merge tokens n times before computing centroids
-* sparsity_type="stride"
+* `sparsity_type="stride"`, use a striding mechanism per head
     * Each head will use different tokens strided by sparsity_factor
     * Not recommended if sparsity_factor > num_heads
-* sparsity_type="block_stride"
+* `sparsity_type="block_stride"`, use a striding mechanism per head
     * Each head will use a block of tokens strided by sparsity_factor
     * Not recommended if sparsity_factor > num_heads
 
```
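For context, this README section sits right under the `AutoModel.from_pretrained("ccdv/lsg-bert-base-uncased-4096", ...)` call shown in the hunk header. A minimal usage sketch, assuming the LSG config fields defined in `modeling_lsg_bert.py` (`sparsity_type`, `sparsity_factor`) can be overridden as `from_pretrained` kwargs the way Hugging Face configs normally allow:

```python
from transformers import AutoModel, AutoTokenizer

# trust_remote_code=True is required: the LSG attention code lives in the repo itself
model = AutoModel.from_pretrained(
    "ccdv/lsg-bert-base-uncased-4096",
    trust_remote_code=True,
    sparsity_type="bos_pooling",  # the new selection pattern added by this commit
    sparsity_factor=8,            # bos_pooling favors larger factors (8, 16, 32)
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-bert-base-uncased-4096")
```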
modeling_lsg_bert.py
CHANGED

```diff
@@ -54,16 +54,16 @@ class LSGBertConfig(BertConfig):
         self.sparsity_factor = sparsity_factor
         self.sparsity_type = sparsity_type
 
-        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
+        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride", "bos_pooling"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride', 'bos_pooling'], \
                     setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
         if self.sparsity_type in ["stride", "block_stride"]:
-            if self.sparsity_factor > self.
+            if self.sparsity_factor > self.num_attention_heads:
                 logger.warning(
-                    "[WARNING CONFIG]: sparsity_factor > 
+                    "[WARNING CONFIG]: sparsity_factor > num_attention_heads is not recommended for stride/block_stride sparsity"
                 )
 
         if self.num_global_tokens < 1:
```
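Note that this validation fails soft: an unrecognized `sparsity_type` only logs a warning and disables sparse attention, it does not raise. A minimal sketch of that behavior, assuming `LSGBertConfig` is importable from this repo's `modeling_lsg_bert.py` and constructible with defaults (`"topk"` below is a deliberately invalid, hypothetical value):

```python
from modeling_lsg_bert import LSGBertConfig

# "topk" is not in the accepted list, so the config warns and falls back to None
config = LSGBertConfig(sparsity_type="topk")
assert config.sparsity_type is None  # sparse attention will be skipped
```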
```diff
@@ -491,15 +491,16 @@ class LSGSelfAttention(BaseSelfAttention):
             "lsh": self.get_sparse_tokens_with_lsh,
             "stride": self.get_sparse_tokens_with_stride,
             "block_stride": self.get_sparse_tokens_with_block_stride,
+            "bos_pooling": self.get_sparse_tokens_with_bos_pooling
         }
 
         self.sparsity_type = config.sparsity_type
-        self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
+        self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda w, x, y, z: (None, None, None))
 
         if config.sparsity_type == "lsh":
             self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
 
-    def get_sparse_tokens_with_norm(self, keys, values, mask):
+    def get_sparse_tokens_with_norm(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
```
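The thread running through this and the following hunks: every selection function now takes `queries` as its first argument (BOS pooling needs the BOS query to score tokens), so the no-op fallback grows from three parameters to four to match the call sites in `forward`. A small self-contained sketch of the dispatch pattern (the function names here are illustrative stand-ins, not the repo's methods):

```python
def with_pooling(queries, keys, values, mask):
    # stand-in for get_sparse_tokens_with_pooling; the real method merges tokens
    return keys, values, mask

sparse_functions = {"pooling": with_pooling}

# unknown types fall back to a four-argument no-op returning (None, None, None)
get_sparse_elements = sparse_functions.get("unknown", lambda w, x, y, z: (None, None, None))
print(get_sparse_elements("q", "k", "v", "m"))  # (None, None, None)
```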
```diff
@@ -527,7 +528,7 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_pooling(self, keys, values, mask):
+    def get_sparse_tokens_with_pooling(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
```
```diff
@@ -550,7 +551,7 @@ class LSGSelfAttention(BaseSelfAttention):
         mask *= torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
-    def get_sparse_tokens_with_stride(self, keys, values, mask):
+    def get_sparse_tokens_with_stride(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
```
```diff
@@ -566,7 +567,7 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_block_stride(self, keys, values, mask):
+    def get_sparse_tokens_with_block_stride(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
```
```diff
@@ -586,10 +587,13 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_lsh(self, keys, values, mask):
+    def get_sparse_tokens_with_lsh(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
+
+        if self.sparsity_factor == self.sparse_block_size:
+            return self.get_sparse_tokens_with_bos_pooling(queries, keys, values, mask)
 
         block_size = min(self.block_size, self.sparse_block_size)
         keys = self.chunk(keys, block_size)
```
```diff
@@ -638,6 +642,29 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
 
+    def get_sparse_tokens_with_bos_pooling(self, queries, keys, values, mask):
+
+        if self.sparsity_factor == 1:
+            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
+
+        queries = queries.unsqueeze(-3)
+        mask = self.chunk(mask.transpose(-1, -2), self.sparsity_factor).transpose(-1, -2)
+        keys = self.chunk(keys, self.sparsity_factor)
+        values = self.chunk(values, self.sparsity_factor)
+
+        n, h, b, t, d = keys.size()
+        scores = (queries[..., :1, :] @ keys.transpose(-1, -2)) / math.sqrt(d)
+        if mask is not None:
+            scores = scores + mask
+
+        scores = torch.softmax(scores, dim=-1)
+        keys = scores @ keys
+        values = scores @ values
+        mask = mask.mean(dim=-1)
+        mask[mask != torch.finfo(mask.dtype).min] = 0
+
+        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
+
     def forward(
         self,
         hidden_states,
```
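The added `get_sparse_tokens_with_bos_pooling` collapses each block of `sparsity_factor` keys/values into a single token, weighted by the block's attention scores against the BOS query. A self-contained re-derivation with explicit shapes (`bos_pooling` and its flat mask layout are my own illustration of the idea, not the repo's API):

```python
import math
import torch

def bos_pooling(queries, keys, values, mask, sparsity_factor):
    # queries/keys/values: (n, h, t, d); mask: (n, 1, 1, t), additive (0 or very negative)
    n, h, t, d = keys.size()
    b = t // sparsity_factor  # number of pooled blocks

    # split the sequence into blocks of sparsity_factor tokens
    keys = keys.reshape(n, h, b, sparsity_factor, d)
    values = values.reshape(n, h, b, sparsity_factor, d)
    mask = mask.reshape(n, 1, b, 1, sparsity_factor)

    # score every token of a block against the BOS query: (n, h, b, 1, factor)
    bos_query = queries[:, :, :1, :].unsqueeze(-3)  # (n, h, 1, 1, d)
    scores = bos_query @ keys.transpose(-1, -2) / math.sqrt(d)
    scores = torch.softmax(scores + mask, dim=-1)

    # weighted average pooling: one key/value per block
    keys = (scores @ keys).reshape(n, h, b, d)
    values = (scores @ values).reshape(n, h, b, d)

    # a pooled block stays masked only if every token in it was masked
    block_mask = mask.mean(dim=-1)  # (n, 1, b, 1)
    block_mask[block_mask != torch.finfo(block_mask.dtype).min] = 0
    return keys, values, block_mask.reshape(n, 1, 1, b)

q = k = v = torch.randn(2, 12, 64, 64)  # batch 2, 12 heads, 64 tokens, head dim 64
m = torch.zeros(2, 1, 1, 64)            # nothing masked
sk, sv, sm = bos_pooling(q, k, v, m, sparsity_factor=8)
print(sk.shape, sv.shape, sm.shape)     # (2, 12, 8, 64), (2, 12, 8, 64), (2, 1, 1, 8)
```

With `sparsity_factor=8`, the 64 sparse keys/values shrink to 8, which is why the README recommends this type with rather large factors: the BOS-weighted average keeps a summary of every block rather than discarding tokens outright.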
```diff
@@ -757,7 +784,7 @@ class LSGSelfAttention(BaseSelfAttention):
         # Get sparse idx
         sparse_key, sparse_value, sparse_mask = (None, None, None)
         if self.sparse_block_size and self.sparsity_factor > 0:
-            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
+            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
 
         # Expand masks on heads
         attention_mask = attention_mask.expand(-1, h, -1, -1)
```
```diff
@@ -830,7 +857,7 @@ class LSGSelfAttention(BaseSelfAttention):
         sparse_key, sparse_value, sparse_mask = (None, None, None)
 
         if self.sparse_block_size and self.sparsity_factor > 0:
-            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
+            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
 
         # Expand masks on heads
         attention_mask = attention_mask.expand(-1, h, -1, -1)
```