Spaces:
Runtime error
Runtime error
Model can be loaded from local directory (#69)
Browse files
audiocraft/models/loaders.py
CHANGED
|
@@ -51,6 +51,10 @@ def _get_state_dict(
|
|
| 51 |
if os.path.isfile(file_or_url_or_id):
|
| 52 |
return torch.load(file_or_url_or_id, map_location=device)
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
elif file_or_url_or_id.startswith('https://'):
|
| 55 |
return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
|
| 56 |
|
|
|
|
| 51 |
if os.path.isfile(file_or_url_or_id):
|
| 52 |
return torch.load(file_or_url_or_id, map_location=device)
|
| 53 |
|
| 54 |
+
if os.path.isdir(file_or_url_or_id):
|
| 55 |
+
file = f"{file_or_url_or_id}/{filename}"
|
| 56 |
+
return torch.load(file, map_location=device)
|
| 57 |
+
|
| 58 |
elif file_or_url_or_id.startswith('https://'):
|
| 59 |
return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
|
| 60 |
|
audiocraft/models/musicgen.py
CHANGED
|
@@ -89,10 +89,11 @@ class MusicGen:
|
|
| 89 |
return MusicGen(name, compression_model, lm)
|
| 90 |
|
| 91 |
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
| 96 |
|
| 97 |
cache_dir = os.environ.get('MUSICGEN_ROOT', None)
|
| 98 |
compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
|
|
|
|
| 89 |
return MusicGen(name, compression_model, lm)
|
| 90 |
|
| 91 |
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
| 92 |
+
if not os.path.isfile(name) and not os.path.isdir(name):
|
| 93 |
+
raise ValueError(
|
| 94 |
+
f"{name} is not a valid checkpoint name. "
|
| 95 |
+
f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
|
| 96 |
+
)
|
| 97 |
|
| 98 |
cache_dir = os.environ.get('MUSICGEN_ROOT', None)
|
| 99 |
compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
|