Commit
·
2c40901
1
Parent(s):
ca9ce08
fix import err
Browse files- modeling_muddpythia.py +6 -5
modeling_muddpythia.py
CHANGED
@@ -9,10 +9,11 @@ from torch import Tensor
|
|
9 |
from torch.nn import functional as F
|
10 |
from torch.utils.checkpoint import checkpoint
|
11 |
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
|
|
16 |
|
17 |
from transformers.modeling_utils import PreTrainedModel
|
18 |
|
@@ -451,4 +452,4 @@ def match_weight_muddpythia(model, w, strict=False, pythia=True):
|
|
451 |
v = v+1
|
452 |
state_dict[k] = torch.tensor(v)
|
453 |
model.load_state_dict(state_dict, strict=strict)
|
454 |
-
return model
|
|
|
9 |
from torch.nn import functional as F
|
10 |
from torch.utils.checkpoint import checkpoint
|
11 |
|
12 |
+
from .configuration_muddpythia import MUDDPythiaConfig
|
13 |
+
#try:
|
14 |
+
# from .configuration_muddpythia import MUDDPythiaConfig
|
15 |
+
#except:
|
16 |
+
# from configuration_muddpythia import MUDDPythiaConfig
|
17 |
|
18 |
from transformers.modeling_utils import PreTrainedModel
|
19 |
|
|
|
452 |
v = v+1
|
453 |
state_dict[k] = torch.tensor(v)
|
454 |
model.load_state_dict(state_dict, strict=strict)
|
455 |
+
return model
|