Spaces:
Running
Running
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- model/cfm.py +2 -1
- model/utils.py +2 -1
model/cfm.py
CHANGED
|
@@ -96,7 +96,8 @@ class CFM(nn.Module):
|
|
| 96 |
):
|
| 97 |
self.eval()
|
| 98 |
|
| 99 |
-
cond
|
|
|
|
| 100 |
|
| 101 |
# raw wave
|
| 102 |
|
|
|
|
| 96 |
):
|
| 97 |
self.eval()
|
| 98 |
|
| 99 |
+
if cond.device != torch.device('cpu'):
|
| 100 |
+
cond = cond.half()
|
| 101 |
|
| 102 |
# raw wave
|
| 103 |
|
model/utils.py
CHANGED
|
@@ -555,7 +555,8 @@ def repetition_found(text, length = 2, tolerance = 10):
|
|
| 555 |
# load model checkpoint for inference
|
| 556 |
|
| 557 |
def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
| 558 |
-
|
|
|
|
| 559 |
|
| 560 |
ckpt_type = ckpt_path.split(".")[-1]
|
| 561 |
if ckpt_type == "safetensors":
|
|
|
|
| 555 |
# load model checkpoint for inference
|
| 556 |
|
| 557 |
def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
| 558 |
+
if device != "cpu":
|
| 559 |
+
model = model.half()
|
| 560 |
|
| 561 |
ckpt_type = ckpt_path.split(".")[-1]
|
| 562 |
if ckpt_type == "safetensors":
|