zRzRzRzRzRzRzR
committed on
Commit
•
37fe000
1
Parent(s):
37f2196
support transformers>=4.37.2 for finetuning
Browse files — modeling_chatglm.py (+5, -4)
modeling_chatglm.py
CHANGED
@@ -634,7 +634,8 @@ class GLMTransformer(torch.nn.Module):
|
|
634 |
attention_mask,
|
635 |
rotary_pos_emb,
|
636 |
kv_caches[index],
|
637 |
-
use_cache
|
|
|
638 |
)
|
639 |
else:
|
640 |
layer_ret = layer(
|
@@ -697,9 +698,9 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
|
|
697 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
698 |
return position_ids
|
699 |
|
700 |
-
def
|
701 |
-
if
|
702 |
-
|
703 |
|
704 |
|
705 |
class Embedding(torch.nn.Module):
|
|
|
634 |
attention_mask,
|
635 |
rotary_pos_emb,
|
636 |
kv_caches[index],
|
637 |
+
use_cache,
|
638 |
+
use_reentrant=False
|
639 |
)
|
640 |
else:
|
641 |
layer_ret = layer(
|
|
|
698 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
699 |
return position_ids
|
700 |
|
701 |
+
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
|
702 |
+
if not self.supports_gradient_checkpointing:
|
703 |
+
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
704 |
|
705 |
|
706 |
class Embedding(torch.nn.Module):
|