Spaces:
Running
Running
Update model/utils.py
Browse files- model/utils.py +14 -28
model/utils.py
CHANGED
|
@@ -557,38 +557,24 @@ def repetition_found(text, length = 2, tolerance = 10):
|
|
| 557 |
|
| 558 |
# load model checkpoint for inference
|
| 559 |
|
| 560 |
-
def load_checkpoint(model, ckpt_path, device,
|
| 561 |
-
|
| 562 |
-
dtype = (
|
| 563 |
-
torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
|
| 564 |
-
)
|
| 565 |
-
model = model.to(dtype)
|
| 566 |
|
| 567 |
ckpt_type = ckpt_path.split(".")[-1]
|
| 568 |
if ckpt_type == "safetensors":
|
| 569 |
from safetensors.torch import load_file
|
| 570 |
-
|
| 571 |
-
checkpoint = load_file(ckpt_path)
|
| 572 |
else:
|
| 573 |
-
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=
|
| 574 |
-
|
|
|
|
|
|
|
| 575 |
if ckpt_type == "safetensors":
|
| 576 |
-
checkpoint
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
if k not in ["initted", "step"]
|
| 581 |
-
}
|
| 582 |
-
|
| 583 |
-
# patch for backward compatibility, 305e3ea
|
| 584 |
-
for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
|
| 585 |
-
if key in checkpoint["model_state_dict"]:
|
| 586 |
-
del checkpoint["model_state_dict"][key]
|
| 587 |
-
|
| 588 |
-
model.load_state_dict(checkpoint["model_state_dict"])
|
| 589 |
else:
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
return model.to(device)
|
|
|
|
| 557 |
|
| 558 |
# load model checkpoint for inference
|
| 559 |
|
| 560 |
+
def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
| 561 |
+
from ema_pytorch import EMA
|
|
|
|
|
|
|
|
|
|
|
|
|
| 562 |
|
| 563 |
ckpt_type = ckpt_path.split(".")[-1]
|
| 564 |
if ckpt_type == "safetensors":
|
| 565 |
from safetensors.torch import load_file
|
| 566 |
+
checkpoint = load_file(ckpt_path, device=device)
|
|
|
|
| 567 |
else:
|
| 568 |
+
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
|
| 569 |
+
|
| 570 |
+
if use_ema == True:
|
| 571 |
+
ema_model = EMA(model, include_online_model = False).to(device)
|
| 572 |
if ckpt_type == "safetensors":
|
| 573 |
+
ema_model.load_state_dict(checkpoint)
|
| 574 |
+
else:
|
| 575 |
+
ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
| 576 |
+
ema_model.copy_params_from_ema_to_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
else:
|
| 578 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 579 |
+
|
| 580 |
+
return model
|
|
|
|
|
|