Spaces:
Paused
Paused
calculating
commited on
Commit
·
4f90f1b
1
Parent(s):
896572e
committing...
Browse files- app.py +1 -1
- sample.wav +0 -0
- utils/dist.py +11 -12
app.py
CHANGED
|
@@ -14,7 +14,7 @@ import os
|
|
| 14 |
# Global variables for model and tokenizer
|
| 15 |
global_generator = None
|
| 16 |
global_tokenizer = None
|
| 17 |
-
default_audio_path = "
|
| 18 |
|
| 19 |
def init_model(use_pure_audio_ablation: bool = False) -> Tuple[nn.Module, object]:
|
| 20 |
"""Initialize the model and tokenizer"""
|
|
|
|
| 14 |
# Global variables for model and tokenizer
|
| 15 |
global_generator = None
|
| 16 |
global_tokenizer = None
|
| 17 |
+
default_audio_path = "sample.wav" # Changed from "testingtesting.wav"
|
| 18 |
|
| 19 |
def init_model(use_pure_audio_ablation: bool = False) -> Tuple[nn.Module, object]:
|
| 20 |
"""Initialize the model and tokenizer"""
|
sample.wav
ADDED
|
Binary file (786 kB). View file
|
|
|
utils/dist.py
CHANGED
|
@@ -8,6 +8,7 @@ import requests
|
|
| 8 |
import hashlib
|
| 9 |
|
| 10 |
from io import BytesIO
|
|
|
|
| 11 |
|
| 12 |
def rank0():
|
| 13 |
rank = os.environ.get('RANK')
|
|
@@ -75,17 +76,12 @@ def init_dist():
|
|
| 75 |
return rank, local_rank, world_size
|
| 76 |
|
| 77 |
def load_ckpt(load_from_location, expected_hash=None):
|
|
|
|
| 78 |
if local0():
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
save_path = f"
|
| 82 |
-
|
| 83 |
-
response = requests.get(url, stream=True)
|
| 84 |
-
total_size = int(response.headers.get('content-length', 0))
|
| 85 |
-
with open(save_path, 'wb') as f, tqdm(total=total_size, desc=f'Downloading {load_from_location}.pt', unit='GB', unit_scale=1/(1024*1024*1024)) as pbar:
|
| 86 |
-
for chunk in response.iter_content(chunk_size=8192):
|
| 87 |
-
f.write(chunk)
|
| 88 |
-
pbar.update(len(chunk))
|
| 89 |
if expected_hash is not None:
|
| 90 |
with open(save_path, 'rb') as f:
|
| 91 |
file_hash = hashlib.md5(f.read()).hexdigest()
|
|
@@ -94,6 +90,9 @@ def load_ckpt(load_from_location, expected_hash=None):
|
|
| 94 |
os.remove(save_path)
|
| 95 |
return load_ckpt(load_from_location, expected_hash)
|
| 96 |
if T.distributed.is_initialized():
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
| 99 |
return loaded
|
|
|
|
| 8 |
import hashlib
|
| 9 |
|
| 10 |
from io import BytesIO
|
| 11 |
+
from huggingface_hub import hf_hub_download
|
| 12 |
|
| 13 |
def rank0():
|
| 14 |
rank = os.environ.get('RANK')
|
|
|
|
| 76 |
return rank, local_rank, world_size
|
| 77 |
|
| 78 |
def load_ckpt(load_from_location, expected_hash=None):
|
| 79 |
+
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1' #Disable this to speed up debugging errors with downloading from the hub
|
| 80 |
if local0():
|
| 81 |
+
repo_id = "si-pbc/hertz-dev"
|
| 82 |
+
print0(f'Loading checkpoint from repo_id {repo_id} and filename {load_from_location}.pt. This may take a while...')
|
| 83 |
+
save_path = hf_hub_download(repo_id=repo_id, filename=f"{load_from_location}.pt")
|
| 84 |
+
print0(f'Downloaded checkpoint to {save_path}')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
if expected_hash is not None:
|
| 86 |
with open(save_path, 'rb') as f:
|
| 87 |
file_hash = hashlib.md5(f.read()).hexdigest()
|
|
|
|
| 90 |
os.remove(save_path)
|
| 91 |
return load_ckpt(load_from_location, expected_hash)
|
| 92 |
if T.distributed.is_initialized():
|
| 93 |
+
save_path = [save_path]
|
| 94 |
+
T.distributed.broadcast_object_list(save_path, src=0)
|
| 95 |
+
save_path = save_path[0]
|
| 96 |
+
loaded = T.load(save_path, weights_only=False, map_location='cpu')
|
| 97 |
+
print0(f'Loaded checkpoint from {save_path}')
|
| 98 |
return loaded
|