Update modeling_deberta.py
modeling_deberta.py CHANGED (+1, -37)
@@ -36,23 +36,11 @@ from transformers.modeling_outputs import (
 from transformers.modeling_utils import PreTrainedModel
 from transformers.pytorch_utils import softmax_backward_data
 from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
-from
+from .configuration_deberta import DebertaV2Config


 logger = logging.get_logger(__name__)

-_CONFIG_FOR_DOC = "DebertaV2Config"
-_CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge"
-_QA_TARGET_START_INDEX = 2
-_QA_TARGET_END_INDEX = 9
-
-DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "microsoft/deberta-v2-xlarge",
-    "microsoft/deberta-v2-xxlarge",
-    "microsoft/deberta-v2-xlarge-mnli",
-    "microsoft/deberta-v2-xxlarge-mnli",
-]
-

 # Copied from transformers.models.deberta.modeling_deberta.ContextPooler
 class ContextPooler(nn.Module):
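Note on the first hunk: a bare "from" line is a Python SyntaxError, so this hunk both completes the dangling import and drops module-level constants that only fed the docstring decorators removed below. A quick way to catch this class of bug in a custom-code repo (a hypothetical check, not part of the commit) is to byte-compile the file:

import py_compile

# Raises py_compile.PyCompileError on a dangling "from" line, so the fix in
# this hunk can be verified without loading the model.
py_compile.compile("modeling_deberta.py", doraise=True)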
@@ -910,9 +898,6 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
     models.
     """
-
-    config_class = DebertaV2Config
-    base_model_prefix = "deberta"
     supports_gradient_checkpointing = True

     def _init_weights(self, module):
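This hunk deletes config_class and base_model_prefix from DebertaV2PreTrainedModel. On a PreTrainedModel subclass these attributes tell from_pretrained() which config class to instantiate and under which attribute name task heads expose the bare backbone (model.deberta here); dropping them is only safe if a parent class still defines them. A minimal sketch of the mechanism, not the transformers source, with all Fake* names invented for illustration:

class FakeConfig:
    model_type = "deberta-v2"

class FakePreTrainedModel:
    # from_pretrained() builds the config from this class attribute
    config_class = FakeConfig
    # task models store the backbone under this attribute name, e.g. model.deberta
    base_model_prefix = "deberta"

    def __init__(self, config):
        self.config = config

    @classmethod
    def from_pretrained(cls, path):
        # the real method also downloads and loads weights; this sketch only
        # shows the role of config_class
        return cls(cls.config_class())

model = FakePreTrainedModel.from_pretrained("some/path")
print(type(model.config).__name__)  # FakeConfig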
@@ -1019,12 +1004,6 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
         """
         raise NotImplementedError("The prune function is not implemented in DeBERTa model.")

-    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @add_code_sample_docstrings(
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=BaseModelOutput,
-        config_class=_CONFIG_FOR_DOC,
-    )
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
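This hunk and the two below strip add_start_docstrings_to_model_forward and add_code_sample_docstrings from the forward methods. In transformers these decorators only rewrite the function's __doc__ with a generated usage example, so removing them (together with the now-unused _CHECKPOINT_FOR_DOC and _CONFIG_FOR_DOC constants) changes rendered documentation, not runtime behavior. A minimal sketch of the idea, with add_sample_docstring invented for illustration:

def add_sample_docstring(checkpoint):
    # Appends a generated usage example to the wrapped function's __doc__;
    # the function itself is returned unmodified.
    def decorator(fn):
        example = f'\n\nExample:\n    model = AutoModel.from_pretrained("{checkpoint}")'
        fn.__doc__ = (fn.__doc__ or "") + example
        return fn
    return decorator

@add_sample_docstring("microsoft/deberta-v2-xlarge")
def forward(self, input_ids=None):
    """Runs the encoder over input_ids."""
    return input_ids

print(forward.__doc__)  # docstring now ends with the generated example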
@@ -1128,14 +1107,6 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings

-    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @add_code_sample_docstrings(
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=MaskedLMOutput,
-        config_class=_CONFIG_FOR_DOC,
-        mask="[MASK]",
-    )
-    # Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM.forward with Deberta->DebertaV2
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@@ -1246,13 +1217,6 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
         )
         return model_inputs

-    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @add_code_sample_docstrings(
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=CausalLMOutput,
-        config_class=_CONFIG_FOR_DOC,
-        mask="[MASK]",
-    )
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
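Since every removal above is doc-only and the one addition completes a broken import, a forward pass through the repo's custom code should be unchanged after this commit. A hypothetical smoke test, where "your-org/your-deberta-repo" is a placeholder rather than a real checkpoint id:

from transformers import AutoModel, AutoTokenizer

repo = "your-org/your-deberta-repo"  # placeholder: substitute the actual Hub repo id
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)

inputs = tokenizer("DeBERTa disentangles content and position.", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # torch.Size([1, seq_len, hidden_size])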