import os
import warnings

import numpy as np
import torch

try:
    import soundfile as sf
except (ImportError, OSError):
    # Nothing visible in this module uses soundfile, so a missing package (or a
    # missing libsndfile, which surfaces as OSError) should not make the
    # tensor/array helpers below unimportable. The name stays bound so that
    # `sf` can still be checked against None.
    sf = None


def get_device(tensor_or_module, default=None):
    """Return the device an object lives on.

    Resolution order:
      1. the object's own ``device`` attribute, if it has one;
      2. the device of the first item yielded by its ``parameters()`` method;
      3. ``torch.device(default)`` when ``default`` is given.

    Args:
        tensor_or_module: Object to inspect (e.g. a tensor or a module).
        default: Fallback passed to ``torch.device`` when the device cannot be
            read from the object. ``None`` means "no fallback".

    Returns:
        The resolved device.

    Raises:
        TypeError: If the device cannot be determined and ``default`` is None.
    """
    if hasattr(tensor_or_module, "device"):
        return tensor_or_module.device
    if hasattr(tensor_or_module, "parameters"):
        # A module without parameters yields nothing; fall through to the
        # default handling instead of leaking a bare StopIteration.
        first_param = next(iter(tensor_or_module.parameters()), None)
        if first_param is not None:
            return first_param.device
    if default is None:
        raise TypeError(
            f"Don't know how to get device of {type(tensor_or_module)} object"
        )
    return torch.device(default)


class Separator:
    """Interface stub for separation models; both methods must be overridden."""

    def forward_wav(self, wav, **kwargs):
        """Process ``wav`` (with optional ``kwargs``); must be overridden.

        `torch_separate` calls this method when the model defines it.
        """
        raise NotImplementedError

    def sample_rate(self):
        """Must be overridden by subclasses; the stub only raises."""
        raise NotImplementedError


def separate(model, wav, **kwargs):
    """Run ``model`` on ``wav`` and return the result in the same container type.

    Args:
        model: Separation model, forwarded to the backend-specific helper.
        wav: Input audio as a ``np.ndarray`` or a ``torch.Tensor``.
        **kwargs: Extra keyword arguments forwarded to the model call.

    Returns:
        A ``np.ndarray`` for numpy input, a ``torch.Tensor`` for tensor input.

    Raises:
        ValueError: If ``wav`` is neither a numpy array nor a torch tensor.
    """
    if isinstance(wav, np.ndarray):
        return numpy_separate(model, wav, **kwargs)
    if isinstance(wav, torch.Tensor):
        return torch_separate(model, wav, **kwargs)
    # Only the two array types above are dispatched here, so the message must
    # not advertise filename support.
    raise ValueError(
        f"Only support numpy arrays and torch tensors, received {type(wav)}"
    )


@torch.no_grad()
def torch_separate(model: Separator, wav: torch.Tensor, **kwargs) -> torch.Tensor:
    """Core logic of `separate`.

    Moves ``wav`` to the model's device, runs the model, rescales the output so
    that its summed absolute value matches that of the input, and moves the
    result back to the device ``wav`` came from. Runs under ``torch.no_grad()``.

    Args:
        model: Object exposing ``forward_wav``; if it has no such attribute the
            model itself is called. An optional ``in_channels`` attribute, when
            present and not None, is checked against ``wav.shape[-2]``.
        wav: Input tensor. It is not modified.
        **kwargs: Extra keyword arguments forwarded to the model call.

    Returns:
        The rescaled model output, on the input's original device.

    Raises:
        RuntimeError: If ``in_channels`` is set and differs from
            ``wav.shape[-2]``.
    """
    # Plain callables / modules are accepted below (see the `forward_wav`
    # fallback), so do not require them to define `in_channels`.
    in_channels = getattr(model, "in_channels", None)
    if in_channels is not None and wav.shape[-2] != in_channels:
        raise RuntimeError(
            f"Model supports {in_channels}-channel inputs but found audio with "
            f"{wav.shape[-2]} channels. Please match the number of channels."
        )

    # Handle device placement
    input_device = get_device(wav, default="cpu")
    model_device = get_device(model, default="cpu")
    wav = wav.to(model_device)

    # Forward
    separate_func = getattr(model, "forward_wav", model)
    out_wavs = separate_func(wav, **kwargs)

    # FIXME: for now this is the best we can do.
    # Match the output's summed magnitude to the input's. When the output has
    # no energy the ratio is undefined (0/0 -> NaN), so leave it unscaled.
    in_energy = wav.abs().sum()
    out_energy = out_wavs.abs().sum()
    ratio = in_energy / out_energy
    scale = torch.where(out_energy > 0, ratio, torch.ones_like(ratio))
    # Out-of-place on purpose: the model may return a view of `wav`, and an
    # in-place multiply would then write through to the caller's data.
    out_wavs = out_wavs * scale

    # Back to input device
    return out_wavs.to(input_device)


def numpy_separate(model: Separator, wav: np.ndarray, **kwargs) -> np.ndarray:
    """Numpy interface to `separate`.

    Wraps ``wav`` as a tensor (``torch.from_numpy`` shares memory with the
    array), runs `torch_separate`, and converts the result back to numpy.

    Args:
        model: Separation model, forwarded to `torch_separate`.
        wav: Input audio array.
        **kwargs: Extra keyword arguments forwarded to the model call.

    Returns:
        The separated audio as a numpy array.
    """
    wav_tensor = torch.from_numpy(wav)
    out_wavs = torch_separate(model, wav_tensor, **kwargs)
    return out_wavs.detach().numpy()