Commit
·
b4f2b16
1
Parent(s):
e36c994
rename to jina bert
Browse files- modeling_bert.py +75 -129
modeling_bert.py
CHANGED
|
@@ -54,7 +54,7 @@ from transformers.utils import (
|
|
| 54 |
logging,
|
| 55 |
replace_return_docstrings,
|
| 56 |
)
|
| 57 |
-
from .configuration_bert import
|
| 58 |
|
| 59 |
try:
|
| 60 |
from tqdm.autonotebook import trange
|
|
@@ -66,7 +66,7 @@ except ImportError:
|
|
| 66 |
logger = logging.get_logger(__name__)
|
| 67 |
|
| 68 |
_CHECKPOINT_FOR_DOC = "bert-base-uncased"
|
| 69 |
-
_CONFIG_FOR_DOC = "
|
| 70 |
|
| 71 |
# TokenClassification docstring
|
| 72 |
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = (
|
|
@@ -197,10 +197,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
|
| 197 |
return model
|
| 198 |
|
| 199 |
|
| 200 |
-
class
|
| 201 |
"""Construct the embeddings from word, position and token_type embeddings."""
|
| 202 |
|
| 203 |
-
def __init__(self, config:
|
| 204 |
super().__init__()
|
| 205 |
self.word_embeddings = nn.Embedding(
|
| 206 |
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
|
|
@@ -280,7 +280,7 @@ class MyBertEmbeddings(nn.Module):
|
|
| 280 |
return embeddings
|
| 281 |
|
| 282 |
|
| 283 |
-
class
|
| 284 |
def __init__(self, config, position_embedding_type=None):
|
| 285 |
super().__init__()
|
| 286 |
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
|
@@ -448,7 +448,7 @@ class MyBertSelfAttention(nn.Module):
|
|
| 448 |
return outputs
|
| 449 |
|
| 450 |
|
| 451 |
-
class
|
| 452 |
def __init__(self, config):
|
| 453 |
super().__init__()
|
| 454 |
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
|
@@ -464,13 +464,13 @@ class MyBertSelfOutput(nn.Module):
|
|
| 464 |
return hidden_states
|
| 465 |
|
| 466 |
|
| 467 |
-
class
|
| 468 |
def __init__(self, config, position_embedding_type=None):
|
| 469 |
super().__init__()
|
| 470 |
-
self.self =
|
| 471 |
config, position_embedding_type=position_embedding_type
|
| 472 |
)
|
| 473 |
-
self.output =
|
| 474 |
self.pruned_heads = set()
|
| 475 |
|
| 476 |
def prune_heads(self, heads):
|
|
@@ -524,7 +524,7 @@ class MyBertAttention(nn.Module):
|
|
| 524 |
return outputs
|
| 525 |
|
| 526 |
|
| 527 |
-
class
|
| 528 |
def __init__(self, config):
|
| 529 |
super().__init__()
|
| 530 |
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
|
@@ -539,8 +539,8 @@ class MyBertIntermediate(nn.Module):
|
|
| 539 |
return hidden_states
|
| 540 |
|
| 541 |
|
| 542 |
-
class
|
| 543 |
-
def __init__(self, config:
|
| 544 |
super().__init__()
|
| 545 |
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 546 |
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
|
@@ -555,8 +555,8 @@ class MyBertOutput(nn.Module):
|
|
| 555 |
return hidden_states
|
| 556 |
|
| 557 |
|
| 558 |
-
class
|
| 559 |
-
def __init__(self, config:
|
| 560 |
super().__init__()
|
| 561 |
self.config = config
|
| 562 |
self.gated_layers = nn.Linear(
|
|
@@ -589,12 +589,12 @@ class MyBertGLUMLP(nn.Module):
|
|
| 589 |
return hidden_states
|
| 590 |
|
| 591 |
|
| 592 |
-
class
|
| 593 |
-
def __init__(self, config:
|
| 594 |
super().__init__()
|
| 595 |
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 596 |
self.seq_len_dim = 1
|
| 597 |
-
self.attention =
|
| 598 |
self.is_decoder = config.is_decoder
|
| 599 |
self.add_cross_attention = config.add_cross_attention
|
| 600 |
self.feed_forward_type = config.feed_forward_type
|
|
@@ -603,14 +603,14 @@ class MyBertLayer(nn.Module):
|
|
| 603 |
raise ValueError(
|
| 604 |
f"{self} should be used as a decoder model if cross attention is added"
|
| 605 |
)
|
| 606 |
-
self.crossattention =
|
| 607 |
config, position_embedding_type="absolute"
|
| 608 |
)
|
| 609 |
if self.feed_forward_type.endswith('glu'):
|
| 610 |
-
self.mlp =
|
| 611 |
else:
|
| 612 |
-
self.intermediate =
|
| 613 |
-
self.output =
|
| 614 |
|
| 615 |
def forward(
|
| 616 |
self,
|
|
@@ -699,12 +699,12 @@ class MyBertLayer(nn.Module):
|
|
| 699 |
return layer_output
|
| 700 |
|
| 701 |
|
| 702 |
-
class
|
| 703 |
-
def __init__(self, config:
|
| 704 |
super().__init__()
|
| 705 |
self.config = config
|
| 706 |
self.layer = nn.ModuleList(
|
| 707 |
-
[
|
| 708 |
)
|
| 709 |
self.gradient_checkpointing = False
|
| 710 |
self.num_attention_heads = config.num_attention_heads
|
|
@@ -724,26 +724,6 @@ class MyBertEncoder(nn.Module):
|
|
| 724 |
# will be applied, it is necessary to construct the diagonal mask.
|
| 725 |
n_heads = self.num_attention_heads
|
| 726 |
|
| 727 |
-
# Mosaics one
|
| 728 |
-
# def _get_alibi_head_slopes(n_heads: int) -> List[float]:
|
| 729 |
-
# def get_slopes_power_of_2(n_heads: int) -> List[float]:
|
| 730 |
-
# start = 2 ** (-(2 ** -(math.log2(n_heads) - 3)))
|
| 731 |
-
# ratio = start
|
| 732 |
-
# return [start * ratio**i for i in range(n_heads)]
|
| 733 |
-
|
| 734 |
-
# # In the paper, they only train models that have 2^a heads for some a. This function
|
| 735 |
-
# # has some good properties that only occur when the input is a power of 2. To
|
| 736 |
-
# # maintain that even when the number of heads is not a power of 2, we use a
|
| 737 |
-
# # workaround.
|
| 738 |
-
# if math.log2(n_heads).is_integer():
|
| 739 |
-
# return get_slopes_power_of_2(n_heads)
|
| 740 |
-
|
| 741 |
-
# closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
|
| 742 |
-
# slopes_a = get_slopes_power_of_2(closest_power_of_2)
|
| 743 |
-
# slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2)
|
| 744 |
-
# slopes_b = slopes_b[0::2][: n_heads - closest_power_of_2]
|
| 745 |
-
# return slopes_a + slopes_b
|
| 746 |
-
|
| 747 |
def _get_alibi_head_slopes(n_heads: int) -> List[float]:
|
| 748 |
def get_slopes_power_of_2(n):
|
| 749 |
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
|
|
@@ -893,7 +873,7 @@ class MyBertEncoder(nn.Module):
|
|
| 893 |
)
|
| 894 |
|
| 895 |
|
| 896 |
-
class
|
| 897 |
def __init__(self, config):
|
| 898 |
super().__init__()
|
| 899 |
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
|
@@ -908,7 +888,7 @@ class MyBertPooler(nn.Module):
|
|
| 908 |
return pooled_output
|
| 909 |
|
| 910 |
|
| 911 |
-
class
|
| 912 |
def __init__(self, config):
|
| 913 |
super().__init__()
|
| 914 |
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
|
@@ -925,10 +905,10 @@ class MyBertPredictionHeadTransform(nn.Module):
|
|
| 925 |
return hidden_states
|
| 926 |
|
| 927 |
|
| 928 |
-
class
|
| 929 |
def __init__(self, config):
|
| 930 |
super().__init__()
|
| 931 |
-
self.transform =
|
| 932 |
|
| 933 |
# The output weights are the same as the input embeddings, but there is
|
| 934 |
# an output-only bias for each token.
|
|
@@ -945,17 +925,17 @@ class MyBertLMPredictionHead(nn.Module):
|
|
| 945 |
return hidden_states
|
| 946 |
|
| 947 |
|
| 948 |
-
class
|
| 949 |
def __init__(self, config):
|
| 950 |
super().__init__()
|
| 951 |
-
self.predictions =
|
| 952 |
|
| 953 |
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
| 954 |
prediction_scores = self.predictions(sequence_output)
|
| 955 |
return prediction_scores
|
| 956 |
|
| 957 |
|
| 958 |
-
class
|
| 959 |
def __init__(self, config):
|
| 960 |
super().__init__()
|
| 961 |
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
|
@@ -965,10 +945,10 @@ class MyBertOnlyNSPHead(nn.Module):
|
|
| 965 |
return seq_relationship_score
|
| 966 |
|
| 967 |
|
| 968 |
-
class
|
| 969 |
def __init__(self, config):
|
| 970 |
super().__init__()
|
| 971 |
-
self.predictions =
|
| 972 |
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
| 973 |
|
| 974 |
def forward(self, sequence_output, pooled_output):
|
|
@@ -977,13 +957,13 @@ class MyBertPreTrainingHeads(nn.Module):
|
|
| 977 |
return prediction_scores, seq_relationship_score
|
| 978 |
|
| 979 |
|
| 980 |
-
class
|
| 981 |
"""
|
| 982 |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 983 |
models.
|
| 984 |
"""
|
| 985 |
|
| 986 |
-
config_class =
|
| 987 |
load_tf_weights = load_tf_weights_in_bert
|
| 988 |
base_model_prefix = "bert"
|
| 989 |
supports_gradient_checkpointing = True
|
|
@@ -1005,12 +985,12 @@ class MyBertPreTrainedModel(PreTrainedModel):
|
|
| 1005 |
module.weight.data.fill_(1.0)
|
| 1006 |
|
| 1007 |
def _set_gradient_checkpointing(self, module, value=False):
|
| 1008 |
-
if isinstance(module,
|
| 1009 |
module.gradient_checkpointing = value
|
| 1010 |
|
| 1011 |
|
| 1012 |
@dataclass
|
| 1013 |
-
class
|
| 1014 |
"""
|
| 1015 |
Output type of [`BertForPreTraining`].
|
| 1016 |
|
|
@@ -1113,7 +1093,7 @@ BERT_INPUTS_DOCSTRING = r"""
|
|
| 1113 |
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
| 1114 |
BERT_START_DOCSTRING,
|
| 1115 |
)
|
| 1116 |
-
class
|
| 1117 |
"""
|
| 1118 |
|
| 1119 |
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
|
@@ -1126,7 +1106,7 @@ class MyBertModel(MyBertPreTrainedModel):
|
|
| 1126 |
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
| 1127 |
"""
|
| 1128 |
|
| 1129 |
-
def __init__(self, config:
|
| 1130 |
super().__init__(config)
|
| 1131 |
self.config = config
|
| 1132 |
|
|
@@ -1137,17 +1117,17 @@ class MyBertModel(MyBertPreTrainedModel):
|
|
| 1137 |
|
| 1138 |
self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
|
| 1139 |
|
| 1140 |
-
self.embeddings =
|
| 1141 |
-
self.encoder =
|
| 1142 |
|
| 1143 |
-
self.pooler =
|
| 1144 |
|
| 1145 |
# Initialize weights and apply final processing
|
| 1146 |
self.post_init()
|
| 1147 |
|
| 1148 |
@torch.inference_mode()
|
| 1149 |
def encode(
|
| 1150 |
-
self: '
|
| 1151 |
sentences: Union[str, List[str]],
|
| 1152 |
batch_size: int = 32,
|
| 1153 |
show_progress_bar: Optional[bool] = None,
|
|
@@ -1479,14 +1459,14 @@ class MyBertModel(MyBertPreTrainedModel):
|
|
| 1479 |
""",
|
| 1480 |
BERT_START_DOCSTRING,
|
| 1481 |
)
|
| 1482 |
-
class
|
| 1483 |
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
| 1484 |
|
| 1485 |
def __init__(self, config):
|
| 1486 |
super().__init__(config)
|
| 1487 |
|
| 1488 |
-
self.bert =
|
| 1489 |
-
self.cls =
|
| 1490 |
|
| 1491 |
# Initialize weights and apply final processing
|
| 1492 |
self.post_init()
|
|
@@ -1501,7 +1481,7 @@ class MyBertForPreTraining(MyBertPreTrainedModel):
|
|
| 1501 |
BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
|
| 1502 |
)
|
| 1503 |
@replace_return_docstrings(
|
| 1504 |
-
output_type=
|
| 1505 |
)
|
| 1506 |
def forward(
|
| 1507 |
self,
|
|
@@ -1516,7 +1496,7 @@ class MyBertForPreTraining(MyBertPreTrainedModel):
|
|
| 1516 |
output_attentions: Optional[bool] = None,
|
| 1517 |
output_hidden_states: Optional[bool] = None,
|
| 1518 |
return_dict: Optional[bool] = None,
|
| 1519 |
-
) -> Union[Tuple[torch.Tensor],
|
| 1520 |
r"""
|
| 1521 |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1522 |
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
|
@@ -1532,22 +1512,6 @@ class MyBertForPreTraining(MyBertPreTrainedModel):
|
|
| 1532 |
Used to hide legacy arguments that have been deprecated.
|
| 1533 |
|
| 1534 |
Returns:
|
| 1535 |
-
|
| 1536 |
-
Example:
|
| 1537 |
-
|
| 1538 |
-
```python
|
| 1539 |
-
>>> from transformers import AutoTokenizer, MyBertForPreTraining
|
| 1540 |
-
>>> import torch
|
| 1541 |
-
|
| 1542 |
-
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
| 1543 |
-
>>> model = MyBertForPreTraining.from_pretrained("bert-base-uncased")
|
| 1544 |
-
|
| 1545 |
-
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 1546 |
-
>>> outputs = model(**inputs)
|
| 1547 |
-
|
| 1548 |
-
>>> prediction_logits = outputs.prediction_logits
|
| 1549 |
-
>>> seq_relationship_logits = outputs.seq_relationship_logits
|
| 1550 |
-
```
|
| 1551 |
"""
|
| 1552 |
return_dict = (
|
| 1553 |
return_dict if return_dict is not None else self.config.use_return_dict
|
|
@@ -1585,7 +1549,7 @@ class MyBertForPreTraining(MyBertPreTrainedModel):
|
|
| 1585 |
output = (prediction_scores, seq_relationship_score) + outputs[2:]
|
| 1586 |
return ((total_loss,) + output) if total_loss is not None else output
|
| 1587 |
|
| 1588 |
-
return
|
| 1589 |
loss=total_loss,
|
| 1590 |
prediction_logits=prediction_scores,
|
| 1591 |
seq_relationship_logits=seq_relationship_score,
|
|
@@ -1595,10 +1559,10 @@ class MyBertForPreTraining(MyBertPreTrainedModel):
|
|
| 1595 |
|
| 1596 |
|
| 1597 |
@add_start_docstrings(
|
| 1598 |
-
"""
|
| 1599 |
BERT_START_DOCSTRING,
|
| 1600 |
)
|
| 1601 |
-
class
|
| 1602 |
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
| 1603 |
|
| 1604 |
def __init__(self, config):
|
|
@@ -1606,11 +1570,11 @@ class MyBertLMHeadModel(MyBertPreTrainedModel):
|
|
| 1606 |
|
| 1607 |
if not config.is_decoder:
|
| 1608 |
logger.warning(
|
| 1609 |
-
"If you want to use `
|
| 1610 |
)
|
| 1611 |
|
| 1612 |
-
self.bert =
|
| 1613 |
-
self.cls =
|
| 1614 |
|
| 1615 |
# Initialize weights and apply final processing
|
| 1616 |
self.post_init()
|
|
@@ -1755,9 +1719,9 @@ class MyBertLMHeadModel(MyBertPreTrainedModel):
|
|
| 1755 |
|
| 1756 |
|
| 1757 |
@add_start_docstrings(
|
| 1758 |
-
"""
|
| 1759 |
)
|
| 1760 |
-
class
|
| 1761 |
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
| 1762 |
|
| 1763 |
def __init__(self, config):
|
|
@@ -1765,12 +1729,12 @@ class MyBertForMaskedLM(MyBertPreTrainedModel):
|
|
| 1765 |
|
| 1766 |
if config.is_decoder:
|
| 1767 |
logger.warning(
|
| 1768 |
-
"If you want to use `
|
| 1769 |
"bi-directional self-attention."
|
| 1770 |
)
|
| 1771 |
|
| 1772 |
-
self.bert =
|
| 1773 |
-
self.cls =
|
| 1774 |
|
| 1775 |
# Initialize weights and apply final processing
|
| 1776 |
self.post_init()
|
|
@@ -1880,15 +1844,15 @@ class MyBertForMaskedLM(MyBertPreTrainedModel):
|
|
| 1880 |
|
| 1881 |
|
| 1882 |
@add_start_docstrings(
|
| 1883 |
-
"""
|
| 1884 |
BERT_START_DOCSTRING,
|
| 1885 |
)
|
| 1886 |
-
class
|
| 1887 |
def __init__(self, config):
|
| 1888 |
super().__init__(config)
|
| 1889 |
|
| 1890 |
-
self.bert =
|
| 1891 |
-
self.cls =
|
| 1892 |
|
| 1893 |
# Initialize weights and apply final processing
|
| 1894 |
self.post_init()
|
|
@@ -1922,24 +1886,6 @@ class MyBertForNextSentencePrediction(MyBertPreTrainedModel):
|
|
| 1922 |
- 1 indicates sequence B is a random sequence.
|
| 1923 |
|
| 1924 |
Returns:
|
| 1925 |
-
|
| 1926 |
-
Example:
|
| 1927 |
-
|
| 1928 |
-
```python
|
| 1929 |
-
>>> from transformers import AutoTokenizer, MyBertForNextSentencePrediction
|
| 1930 |
-
>>> import torch
|
| 1931 |
-
|
| 1932 |
-
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
| 1933 |
-
>>> model = MyBertForNextSentencePrediction.from_pretrained("bert-base-uncased")
|
| 1934 |
-
|
| 1935 |
-
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
| 1936 |
-
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
| 1937 |
-
>>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
|
| 1938 |
-
|
| 1939 |
-
>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
|
| 1940 |
-
>>> logits = outputs.logits
|
| 1941 |
-
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
|
| 1942 |
-
```
|
| 1943 |
"""
|
| 1944 |
|
| 1945 |
if "next_sentence_label" in kwargs:
|
|
@@ -1995,18 +1941,18 @@ class MyBertForNextSentencePrediction(MyBertPreTrainedModel):
|
|
| 1995 |
|
| 1996 |
@add_start_docstrings(
|
| 1997 |
"""
|
| 1998 |
-
|
| 1999 |
output) e.g. for GLUE tasks.
|
| 2000 |
""",
|
| 2001 |
BERT_START_DOCSTRING,
|
| 2002 |
)
|
| 2003 |
-
class
|
| 2004 |
def __init__(self, config):
|
| 2005 |
super().__init__(config)
|
| 2006 |
self.num_labels = config.num_labels
|
| 2007 |
self.config = config
|
| 2008 |
|
| 2009 |
-
self.bert =
|
| 2010 |
classifier_dropout = (
|
| 2011 |
config.classifier_dropout
|
| 2012 |
if config.classifier_dropout is not None
|
|
@@ -2106,16 +2052,16 @@ class MyBertForSequenceClassification(MyBertPreTrainedModel):
|
|
| 2106 |
|
| 2107 |
@add_start_docstrings(
|
| 2108 |
"""
|
| 2109 |
-
|
| 2110 |
softmax) e.g. for RocStories/SWAG tasks.
|
| 2111 |
""",
|
| 2112 |
BERT_START_DOCSTRING,
|
| 2113 |
)
|
| 2114 |
-
class
|
| 2115 |
def __init__(self, config):
|
| 2116 |
super().__init__(config)
|
| 2117 |
|
| 2118 |
-
self.bert =
|
| 2119 |
classifier_dropout = (
|
| 2120 |
config.classifier_dropout
|
| 2121 |
if config.classifier_dropout is not None
|
|
@@ -2222,17 +2168,17 @@ class MyBertForMultipleChoice(MyBertPreTrainedModel):
|
|
| 2222 |
|
| 2223 |
@add_start_docstrings(
|
| 2224 |
"""
|
| 2225 |
-
|
| 2226 |
Named-Entity-Recognition (NER) tasks.
|
| 2227 |
""",
|
| 2228 |
BERT_START_DOCSTRING,
|
| 2229 |
)
|
| 2230 |
-
class
|
| 2231 |
def __init__(self, config):
|
| 2232 |
super().__init__(config)
|
| 2233 |
self.num_labels = config.num_labels
|
| 2234 |
|
| 2235 |
-
self.bert =
|
| 2236 |
classifier_dropout = (
|
| 2237 |
config.classifier_dropout
|
| 2238 |
if config.classifier_dropout is not None
|
|
@@ -2311,17 +2257,17 @@ class MyBertForTokenClassification(MyBertPreTrainedModel):
|
|
| 2311 |
|
| 2312 |
@add_start_docstrings(
|
| 2313 |
"""
|
| 2314 |
-
|
| 2315 |
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
| 2316 |
""",
|
| 2317 |
BERT_START_DOCSTRING,
|
| 2318 |
)
|
| 2319 |
-
class
|
| 2320 |
def __init__(self, config):
|
| 2321 |
super().__init__(config)
|
| 2322 |
self.num_labels = config.num_labels
|
| 2323 |
|
| 2324 |
-
self.bert =
|
| 2325 |
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
| 2326 |
|
| 2327 |
# Initialize weights and apply final processing
|
|
|
|
| 54 |
logging,
|
| 55 |
replace_return_docstrings,
|
| 56 |
)
|
| 57 |
+
from .configuration_bert import JinaBertConfig
|
| 58 |
|
| 59 |
try:
|
| 60 |
from tqdm.autonotebook import trange
|
|
|
|
| 66 |
logger = logging.get_logger(__name__)
|
| 67 |
|
| 68 |
_CHECKPOINT_FOR_DOC = "bert-base-uncased"
|
| 69 |
+
_CONFIG_FOR_DOC = "JinaBertConfig"
|
| 70 |
|
| 71 |
# TokenClassification docstring
|
| 72 |
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = (
|
|
|
|
| 197 |
return model
|
| 198 |
|
| 199 |
|
| 200 |
+
class JinaBertEmbeddings(nn.Module):
|
| 201 |
"""Construct the embeddings from word, position and token_type embeddings."""
|
| 202 |
|
| 203 |
+
def __init__(self, config: JinaBertConfig):
|
| 204 |
super().__init__()
|
| 205 |
self.word_embeddings = nn.Embedding(
|
| 206 |
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
|
|
|
|
| 280 |
return embeddings
|
| 281 |
|
| 282 |
|
| 283 |
+
class JinaBertSelfAttention(nn.Module):
|
| 284 |
def __init__(self, config, position_embedding_type=None):
|
| 285 |
super().__init__()
|
| 286 |
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
|
|
|
| 448 |
return outputs
|
| 449 |
|
| 450 |
|
| 451 |
+
class JinaBertSelfOutput(nn.Module):
|
| 452 |
def __init__(self, config):
|
| 453 |
super().__init__()
|
| 454 |
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
|
|
|
| 464 |
return hidden_states
|
| 465 |
|
| 466 |
|
| 467 |
+
class JinaBertAttention(nn.Module):
|
| 468 |
def __init__(self, config, position_embedding_type=None):
|
| 469 |
super().__init__()
|
| 470 |
+
self.self = JinaBertSelfAttention(
|
| 471 |
config, position_embedding_type=position_embedding_type
|
| 472 |
)
|
| 473 |
+
self.output = JinaBertSelfOutput(config)
|
| 474 |
self.pruned_heads = set()
|
| 475 |
|
| 476 |
def prune_heads(self, heads):
|
|
|
|
| 524 |
return outputs
|
| 525 |
|
| 526 |
|
| 527 |
+
class JinaBertIntermediate(nn.Module):
|
| 528 |
def __init__(self, config):
|
| 529 |
super().__init__()
|
| 530 |
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
|
|
|
| 539 |
return hidden_states
|
| 540 |
|
| 541 |
|
| 542 |
+
class JinaBertOutput(nn.Module):
|
| 543 |
+
def __init__(self, config: JinaBertConfig):
|
| 544 |
super().__init__()
|
| 545 |
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 546 |
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
|
|
|
| 555 |
return hidden_states
|
| 556 |
|
| 557 |
|
| 558 |
+
class JinaBertGLUMLP(nn.Module):
|
| 559 |
+
def __init__(self, config: JinaBertConfig):
|
| 560 |
super().__init__()
|
| 561 |
self.config = config
|
| 562 |
self.gated_layers = nn.Linear(
|
|
|
|
| 589 |
return hidden_states
|
| 590 |
|
| 591 |
|
| 592 |
+
class JinaBertLayer(nn.Module):
|
| 593 |
+
def __init__(self, config: JinaBertConfig):
|
| 594 |
super().__init__()
|
| 595 |
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 596 |
self.seq_len_dim = 1
|
| 597 |
+
self.attention = JinaBertAttention(config)
|
| 598 |
self.is_decoder = config.is_decoder
|
| 599 |
self.add_cross_attention = config.add_cross_attention
|
| 600 |
self.feed_forward_type = config.feed_forward_type
|
|
|
|
| 603 |
raise ValueError(
|
| 604 |
f"{self} should be used as a decoder model if cross attention is added"
|
| 605 |
)
|
| 606 |
+
self.crossattention = JinaBertAttention(
|
| 607 |
config, position_embedding_type="absolute"
|
| 608 |
)
|
| 609 |
if self.feed_forward_type.endswith('glu'):
|
| 610 |
+
self.mlp = JinaBertGLUMLP(config)
|
| 611 |
else:
|
| 612 |
+
self.intermediate = JinaBertIntermediate(config)
|
| 613 |
+
self.output = JinaBertOutput(config)
|
| 614 |
|
| 615 |
def forward(
|
| 616 |
self,
|
|
|
|
| 699 |
return layer_output
|
| 700 |
|
| 701 |
|
| 702 |
+
class JinaBertEncoder(nn.Module):
|
| 703 |
+
def __init__(self, config: JinaBertConfig):
|
| 704 |
super().__init__()
|
| 705 |
self.config = config
|
| 706 |
self.layer = nn.ModuleList(
|
| 707 |
+
[JinaBertLayer(config) for _ in range(config.num_hidden_layers)]
|
| 708 |
)
|
| 709 |
self.gradient_checkpointing = False
|
| 710 |
self.num_attention_heads = config.num_attention_heads
|
|
|
|
| 724 |
# will be applied, it is necessary to construct the diagonal mask.
|
| 725 |
n_heads = self.num_attention_heads
|
| 726 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 727 |
def _get_alibi_head_slopes(n_heads: int) -> List[float]:
|
| 728 |
def get_slopes_power_of_2(n):
|
| 729 |
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
|
|
|
|
| 873 |
)
|
| 874 |
|
| 875 |
|
| 876 |
+
class JinaBertPooler(nn.Module):
|
| 877 |
def __init__(self, config):
|
| 878 |
super().__init__()
|
| 879 |
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
|
|
|
| 888 |
return pooled_output
|
| 889 |
|
| 890 |
|
| 891 |
+
class JinaBertPredictionHeadTransform(nn.Module):
|
| 892 |
def __init__(self, config):
|
| 893 |
super().__init__()
|
| 894 |
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
|
|
|
| 905 |
return hidden_states
|
| 906 |
|
| 907 |
|
| 908 |
+
class JinaBertLMPredictionHead(nn.Module):
|
| 909 |
def __init__(self, config):
|
| 910 |
super().__init__()
|
| 911 |
+
self.transform = JinaBertPredictionHeadTransform(config)
|
| 912 |
|
| 913 |
# The output weights are the same as the input embeddings, but there is
|
| 914 |
# an output-only bias for each token.
|
|
|
|
| 925 |
return hidden_states
|
| 926 |
|
| 927 |
|
| 928 |
+
class JinaBertOnlyMLMHead(nn.Module):
|
| 929 |
def __init__(self, config):
|
| 930 |
super().__init__()
|
| 931 |
+
self.predictions = JinaBertLMPredictionHead(config)
|
| 932 |
|
| 933 |
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
| 934 |
prediction_scores = self.predictions(sequence_output)
|
| 935 |
return prediction_scores
|
| 936 |
|
| 937 |
|
| 938 |
+
class JinaBertOnlyNSPHead(nn.Module):
|
| 939 |
def __init__(self, config):
|
| 940 |
super().__init__()
|
| 941 |
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
|
|
|
| 945 |
return seq_relationship_score
|
| 946 |
|
| 947 |
|
| 948 |
+
class JinaBertPreTrainingHeads(nn.Module):
|
| 949 |
def __init__(self, config):
|
| 950 |
super().__init__()
|
| 951 |
+
self.predictions = JinaBertLMPredictionHead(config)
|
| 952 |
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
| 953 |
|
| 954 |
def forward(self, sequence_output, pooled_output):
|
|
|
|
| 957 |
return prediction_scores, seq_relationship_score
|
| 958 |
|
| 959 |
|
| 960 |
+
class JinaBertPreTrainedModel(PreTrainedModel):
|
| 961 |
"""
|
| 962 |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 963 |
models.
|
| 964 |
"""
|
| 965 |
|
| 966 |
+
config_class = JinaBertConfig
|
| 967 |
load_tf_weights = load_tf_weights_in_bert
|
| 968 |
base_model_prefix = "bert"
|
| 969 |
supports_gradient_checkpointing = True
|
|
|
|
| 985 |
module.weight.data.fill_(1.0)
|
| 986 |
|
| 987 |
def _set_gradient_checkpointing(self, module, value=False):
|
| 988 |
+
if isinstance(module, JinaBertEncoder):
|
| 989 |
module.gradient_checkpointing = value
|
| 990 |
|
| 991 |
|
| 992 |
@dataclass
|
| 993 |
+
class JinaBertForPreTrainingOutput(ModelOutput):
|
| 994 |
"""
|
| 995 |
Output type of [`BertForPreTraining`].
|
| 996 |
|
|
|
|
| 1093 |
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
| 1094 |
BERT_START_DOCSTRING,
|
| 1095 |
)
|
| 1096 |
+
class JinaBertModel(JinaBertPreTrainedModel):
|
| 1097 |
"""
|
| 1098 |
|
| 1099 |
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
|
|
|
| 1106 |
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
| 1107 |
"""
|
| 1108 |
|
| 1109 |
+
def __init__(self, config: JinaBertConfig, add_pooling_layer=True):
|
| 1110 |
super().__init__(config)
|
| 1111 |
self.config = config
|
| 1112 |
|
|
|
|
| 1117 |
|
| 1118 |
self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
|
| 1119 |
|
| 1120 |
+
self.embeddings = JinaBertEmbeddings(config)
|
| 1121 |
+
self.encoder = JinaBertEncoder(config)
|
| 1122 |
|
| 1123 |
+
self.pooler = JinaBertPooler(config) if add_pooling_layer else None
|
| 1124 |
|
| 1125 |
# Initialize weights and apply final processing
|
| 1126 |
self.post_init()
|
| 1127 |
|
| 1128 |
@torch.inference_mode()
|
| 1129 |
def encode(
|
| 1130 |
+
self: 'JinaBertModel',
|
| 1131 |
sentences: Union[str, List[str]],
|
| 1132 |
batch_size: int = 32,
|
| 1133 |
show_progress_bar: Optional[bool] = None,
|
|
|
|
| 1459 |
""",
|
| 1460 |
BERT_START_DOCSTRING,
|
| 1461 |
)
|
| 1462 |
+
class JinaBertForPreTraining(JinaBertPreTrainedModel):
|
| 1463 |
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
| 1464 |
|
| 1465 |
def __init__(self, config):
|
| 1466 |
super().__init__(config)
|
| 1467 |
|
| 1468 |
+
self.bert = JinaBertModel(config)
|
| 1469 |
+
self.cls = JinaBertPreTrainingHeads(config)
|
| 1470 |
|
| 1471 |
# Initialize weights and apply final processing
|
| 1472 |
self.post_init()
|
|
|
|
| 1481 |
BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
|
| 1482 |
)
|
| 1483 |
@replace_return_docstrings(
|
| 1484 |
+
output_type=JinaBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
|
| 1485 |
)
|
| 1486 |
def forward(
|
| 1487 |
self,
|
|
|
|
| 1496 |
output_attentions: Optional[bool] = None,
|
| 1497 |
output_hidden_states: Optional[bool] = None,
|
| 1498 |
return_dict: Optional[bool] = None,
|
| 1499 |
+
) -> Union[Tuple[torch.Tensor], JinaBertForPreTrainingOutput]:
|
| 1500 |
r"""
|
| 1501 |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1502 |
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
|
|
|
| 1512 |
Used to hide legacy arguments that have been deprecated.
|
| 1513 |
|
| 1514 |
Returns:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1515 |
"""
|
| 1516 |
return_dict = (
|
| 1517 |
return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
|
| 1549 |
output = (prediction_scores, seq_relationship_score) + outputs[2:]
|
| 1550 |
return ((total_loss,) + output) if total_loss is not None else output
|
| 1551 |
|
| 1552 |
+
return JinaBertForPreTrainingOutput(
|
| 1553 |
loss=total_loss,
|
| 1554 |
prediction_logits=prediction_scores,
|
| 1555 |
seq_relationship_logits=seq_relationship_score,
|
|
|
|
| 1559 |
|
| 1560 |
|
| 1561 |
@add_start_docstrings(
|
| 1562 |
+
"""JinaBert Model with a `language modeling` head on top for CLM fine-tuning.""",
|
| 1563 |
BERT_START_DOCSTRING,
|
| 1564 |
)
|
| 1565 |
+
class JinaBertLMHeadModel(JinaBertPreTrainedModel):
|
| 1566 |
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
| 1567 |
|
| 1568 |
def __init__(self, config):
|
|
|
|
| 1570 |
|
| 1571 |
if not config.is_decoder:
|
| 1572 |
logger.warning(
|
| 1573 |
+
"If you want to use `JinaBertLMHeadModel` as a standalone, add `is_decoder=True.`"
|
| 1574 |
)
|
| 1575 |
|
| 1576 |
+
self.bert = JinaBertModel(config, add_pooling_layer=False)
|
| 1577 |
+
self.cls = JinaBertOnlyMLMHead(config)
|
| 1578 |
|
| 1579 |
# Initialize weights and apply final processing
|
| 1580 |
self.post_init()
|
|
|
|
| 1719 |
|
| 1720 |
|
| 1721 |
@add_start_docstrings(
|
| 1722 |
+
"""JinaBert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING
|
| 1723 |
)
|
| 1724 |
+
class JinaBertForMaskedLM(JinaBertPreTrainedModel):
|
| 1725 |
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
| 1726 |
|
| 1727 |
def __init__(self, config):
|
|
|
|
| 1729 |
|
| 1730 |
if config.is_decoder:
|
| 1731 |
logger.warning(
|
| 1732 |
+
"If you want to use `JinaBertForMaskedLM` make sure `config.is_decoder=False` for "
|
| 1733 |
"bi-directional self-attention."
|
| 1734 |
)
|
| 1735 |
|
| 1736 |
+
self.bert = JinaBertModel(config, add_pooling_layer=False)
|
| 1737 |
+
self.cls = JinaBertOnlyMLMHead(config)
|
| 1738 |
|
| 1739 |
# Initialize weights and apply final processing
|
| 1740 |
self.post_init()
|
|
|
|
| 1844 |
|
| 1845 |
|
| 1846 |
@add_start_docstrings(
|
| 1847 |
+
"""JinaBert Model with a `next sentence prediction (classification)` head on top.""",
|
| 1848 |
BERT_START_DOCSTRING,
|
| 1849 |
)
|
| 1850 |
+
class JinaBertForNextSentencePrediction(JinaBertPreTrainedModel):
|
| 1851 |
def __init__(self, config):
|
| 1852 |
super().__init__(config)
|
| 1853 |
|
| 1854 |
+
self.bert = JinaBertModel(config)
|
| 1855 |
+
self.cls = JinaBertOnlyNSPHead(config)
|
| 1856 |
|
| 1857 |
# Initialize weights and apply final processing
|
| 1858 |
self.post_init()
|
|
|
|
| 1886 |
- 1 indicates sequence B is a random sequence.
|
| 1887 |
|
| 1888 |
Returns:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1889 |
"""
|
| 1890 |
|
| 1891 |
if "next_sentence_label" in kwargs:
|
|
|
|
| 1941 |
|
| 1942 |
@add_start_docstrings(
|
| 1943 |
"""
|
| 1944 |
+
JinaBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
|
| 1945 |
output) e.g. for GLUE tasks.
|
| 1946 |
""",
|
| 1947 |
BERT_START_DOCSTRING,
|
| 1948 |
)
|
| 1949 |
+
class JinaBertForSequenceClassification(JinaBertPreTrainedModel):
|
| 1950 |
def __init__(self, config):
|
| 1951 |
super().__init__(config)
|
| 1952 |
self.num_labels = config.num_labels
|
| 1953 |
self.config = config
|
| 1954 |
|
| 1955 |
+
self.bert = JinaBertModel(config)
|
| 1956 |
classifier_dropout = (
|
| 1957 |
config.classifier_dropout
|
| 1958 |
if config.classifier_dropout is not None
|
|
|
|
| 2052 |
|
| 2053 |
@add_start_docstrings(
|
| 2054 |
"""
|
| 2055 |
+
JinaBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
|
| 2056 |
softmax) e.g. for RocStories/SWAG tasks.
|
| 2057 |
""",
|
| 2058 |
BERT_START_DOCSTRING,
|
| 2059 |
)
|
| 2060 |
+
class JinaBertForMultipleChoice(JinaBertPreTrainedModel):
|
| 2061 |
def __init__(self, config):
|
| 2062 |
super().__init__(config)
|
| 2063 |
|
| 2064 |
+
self.bert = JinaBertModel(config)
|
| 2065 |
classifier_dropout = (
|
| 2066 |
config.classifier_dropout
|
| 2067 |
if config.classifier_dropout is not None
|
|
|
|
| 2168 |
|
| 2169 |
@add_start_docstrings(
|
| 2170 |
"""
|
| 2171 |
+
JinaBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
| 2172 |
Named-Entity-Recognition (NER) tasks.
|
| 2173 |
""",
|
| 2174 |
BERT_START_DOCSTRING,
|
| 2175 |
)
|
| 2176 |
+
class JinaBertForTokenClassification(JinaBertPreTrainedModel):
|
| 2177 |
def __init__(self, config):
|
| 2178 |
super().__init__(config)
|
| 2179 |
self.num_labels = config.num_labels
|
| 2180 |
|
| 2181 |
+
self.bert = JinaBertModel(config, add_pooling_layer=False)
|
| 2182 |
classifier_dropout = (
|
| 2183 |
config.classifier_dropout
|
| 2184 |
if config.classifier_dropout is not None
|
|
|
|
| 2257 |
|
| 2258 |
@add_start_docstrings(
|
| 2259 |
"""
|
| 2260 |
+
JinaBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
| 2261 |
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
| 2262 |
""",
|
| 2263 |
BERT_START_DOCSTRING,
|
| 2264 |
)
|
| 2265 |
+
class JinaBertForQuestionAnswering(JinaBertPreTrainedModel):
|
| 2266 |
def __init__(self, config):
|
| 2267 |
super().__init__(config)
|
| 2268 |
self.num_labels = config.num_labels
|
| 2269 |
|
| 2270 |
+
self.bert = JinaBertModel(config, add_pooling_layer=False)
|
| 2271 |
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
| 2272 |
|
| 2273 |
# Initialize weights and apply final processing
|