|
import torch |
|
import yaml |
|
from audiosr import download_checkpoint, default_audioldm_config, LatentDiffusion |
|
|
|
|
|
def load_audiosr(ckpt_path=None, config=None, device=None, model_name="basic"):
    """Build an AudioSR ``LatentDiffusion`` model and load its weights.

    Args:
        ckpt_path: Optional path to a checkpoint file. If ``None``, the
            checkpoint for ``model_name`` is downloaded via
            ``download_checkpoint``.
        config: Optional path to a YAML config file. If ``None``, the
            default config for ``model_name`` is used.
        device: Target device ("cuda", "mps", "cpu", a ``torch.device``,
            or ``None``/"auto" to pick the best available one).
        model_name: Name of the pretrained AudioSR variant (e.g. "basic").

    Returns:
        The ``LatentDiffusion`` model in eval mode, moved to ``device``.

    Raises:
        AssertionError: If ``config`` is given but is not a string path.
    """
    # Auto-select the best available accelerator when not specified.
    if device is None or device == "auto":
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

    print("Loading AudioSR: %s" % model_name)
    print("Loading model on %s" % device)

    # Only download when the caller did not supply a checkpoint; previously
    # the ckpt_path argument was silently overwritten and ignored.
    if ckpt_path is None:
        ckpt_path = download_checkpoint(model_name)

    if config is not None:
        assert isinstance(config, str)
        # Use a context manager so the config file handle is closed.
        with open(config, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    config["model"]["params"]["device"] = device

    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    resume_from_checkpoint = ckpt_path

    # Load on CPU first to avoid GPU OOM for large checkpoints; the model
    # is moved to the target device after the weights are in place.
    checkpoint = torch.load(resume_from_checkpoint, map_location="cpu")

    latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=True)

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)

    return latent_diffusion
|
|