import copy
import functools
import hashlib
import math
import pathlib
import tempfile
import typing
import warnings
from collections import namedtuple
from pathlib import Path

import julius
import numpy as np
import soundfile
import torch

from . import util
from .display import DisplayMixin
from .dsp import DSPMixin
from .effects import EffectMixin
from .effects import ImpulseResponseMixin
from .ffmpeg import FFMPEGMixin
from .loudness import LoudnessMixin
from .playback import PlayMixin
from .whisper import WhisperMixin
STFTParams = namedtuple(
    "STFTParams",
    ["window_length", "hop_length", "window_type", "match_stride", "padding_type"],
)
"""
STFTParams object is a container that holds STFT parameters - window_length,
hop_length, window_type, match_stride, and padding_type. Not all parameters
need to be specified. Ones that are not specified will be inferred from the
AudioSignal parameters.

Parameters
----------
window_length : int, optional
    Window length of STFT, by default ``0.032 * self.sample_rate``.
hop_length : int, optional
    Hop length of STFT, by default ``window_length // 4``.
window_type : str, optional
    Type of window to use, by default ``hann``.
match_stride : bool, optional
    Whether to match the stride of convolutional layers, by default False.
padding_type : str, optional
    Type of padding to use, by default ``'reflect'``.
"""
STFTParams.__new__.__defaults__ = (None, None, None, None, None)


class AudioSignal(
    EffectMixin,
    LoudnessMixin,
    PlayMixin,
    ImpulseResponseMixin,
    DSPMixin,
    DisplayMixin,
    FFMPEGMixin,
    WhisperMixin,
):
    """This is the core object of this library. Audio is always
    loaded into an AudioSignal, which then enables all the features
    of this library, including audio augmentations, I/O, playback,
    and more.

    The structure of this object is that the base functionality
    is defined in ``core/audio_signal.py``, while extensions to
    that functionality are defined in the other ``core/*.py``
    files. For example, all the display-based functionality
    (e.g. plot spectrograms, waveforms, write to tensorboard)
    is in ``core/display.py``.

    Parameters
    ----------
    audio_path_or_array : typing.Union[torch.Tensor, str, Path, np.ndarray]
        Object to create AudioSignal from. Can be a tensor, numpy array,
        or a path to a file. The data is always reshaped to three
        dimensions: ``(batch, channels, samples)``.
    sample_rate : int, optional
        Sample rate of the audio. If different from underlying file, resampling is
        performed. If passing in an array or tensor, this must be defined,
        by default None
    stft_params : STFTParams, optional
        Parameters of STFT to use, by default None
    offset : float, optional
        Offset in seconds to read from file, by default 0
    duration : float, optional
        Duration in seconds to read from file, by default None
    device : str, optional
        Device to load audio onto, by default None

    Examples
    --------
    Loading an AudioSignal from an array, at a sample rate of
    44100.

    >>> signal = AudioSignal(torch.randn(5*44100), 44100)

    Note, the signal is reshaped to have a batch size, and one
    audio channel:

    >>> print(signal.shape)
    (1, 1, 220500)

    You can treat AudioSignals like tensors, and many of the same
    functions you might use on tensors are defined for AudioSignals
    as well:

    >>> signal.to("cuda")
    >>> signal.cuda()
    >>> signal.clone()
    >>> signal.detach()

    Indexing AudioSignals returns an AudioSignal:

    >>> signal[..., 3*44100:4*44100]

    The above signal is 1 second long, and is also an AudioSignal.
    """

    def __init__(
        self,
        audio_path_or_array: typing.Union[torch.Tensor, str, Path, np.ndarray],
        sample_rate: int = None,
        stft_params: STFTParams = None,
        offset: float = 0,
        duration: float = None,
        device: str = None,
    ):
        audio_path = None
        audio_array = None

        if isinstance(audio_path_or_array, str):
            audio_path = audio_path_or_array
        elif isinstance(audio_path_or_array, pathlib.Path):
            audio_path = audio_path_or_array
        elif isinstance(audio_path_or_array, np.ndarray):
            audio_array = audio_path_or_array
        elif torch.is_tensor(audio_path_or_array):
            audio_array = audio_path_or_array
        else:
            raise ValueError(
                "audio_path_or_array must be either a Path, "
                "string, numpy array, or torch Tensor!"
            )

        self.path_to_file = None
        self.audio_data = None
        self.sources = None  # List of AudioSignal objects.
        self.stft_data = None

        if audio_path is not None:
            self.load_from_file(
                audio_path, offset=offset, duration=duration, device=device
            )
        elif audio_array is not None:
            assert sample_rate is not None, "Must set sample rate!"
            self.load_from_array(audio_array, sample_rate, device=device)

        self.window = None
        self.stft_params = stft_params

        self.metadata = {
            "offset": offset,
            "duration": duration,
        }

    @property
    def path_to_input_file(
        self,
    ):
        """
        Path to input file, if it exists.
        Alias to ``path_to_file`` for backwards compatibility.
        """
        return self.path_to_file

    @classmethod
    def excerpt(
        cls,
        audio_path: typing.Union[str, Path],
        offset: float = None,
        duration: float = None,
        state: typing.Union[np.random.RandomState, int] = None,
        **kwargs,
    ):
        """Randomly draw an excerpt of ``duration`` seconds from an
        audio file specified at ``audio_path``, between ``offset`` seconds
        and the end of the file. ``state`` can be used to seed the random draw.

        Parameters
        ----------
        audio_path : typing.Union[str, Path]
            Path to audio file to grab excerpt from.
        offset : float, optional
            Lower bound for the start time, in seconds, drawn from
            the file, by default None.
        duration : float, optional
            Duration of excerpt, in seconds, by default None
        state : typing.Union[np.random.RandomState, int], optional
            RandomState or seed of random state, by default None

        Returns
        -------
        AudioSignal
            AudioSignal containing excerpt.

        Examples
        --------
        >>> signal = AudioSignal.excerpt("path/to/audio", duration=5)
        """
        info = util.info(audio_path)
        total_duration = info.duration

        state = util.random_state(state)
        lower_bound = 0 if offset is None else offset
        upper_bound = max(total_duration - duration, 0)
        offset = state.uniform(lower_bound, upper_bound)

        signal = cls(audio_path, offset=offset, duration=duration, **kwargs)
        signal.metadata["offset"] = offset
        signal.metadata["duration"] = duration

        return signal

    @classmethod
    def salient_excerpt(
        cls,
        audio_path: typing.Union[str, Path],
        loudness_cutoff: float = None,
        num_tries: int = 8,
        state: typing.Union[np.random.RandomState, int] = None,
        **kwargs,
    ):
        """Similar to AudioSignal.excerpt, except it extracts excerpts only
        if they are above a specified loudness threshold, which is computed via
        a fast LUFS routine.

        Parameters
        ----------
        audio_path : typing.Union[str, Path]
            Path to audio file to grab excerpt from.
        loudness_cutoff : float, optional
            Loudness threshold in dB. Typical values are ``-40, -60``,
            etc, by default None
        num_tries : int, optional
            Number of tries to grab an excerpt above the threshold
            before giving up, by default 8.
        state : typing.Union[np.random.RandomState, int], optional
            RandomState or seed of random state, by default None
        kwargs : dict
            Keyword arguments to AudioSignal.excerpt

        Returns
        -------
        AudioSignal
            AudioSignal containing excerpt.

        .. warning::
            If ``num_tries`` is set to None, ``salient_excerpt`` may try forever, which can
            result in an infinite loop if ``audio_path`` does not have
            any loud enough excerpts.

        Examples
        --------
        >>> signal = AudioSignal.salient_excerpt(
                "path/to/audio",
                loudness_cutoff=-40,
                duration=5
            )
        """
        state = util.random_state(state)
        if loudness_cutoff is None:
            excerpt = cls.excerpt(audio_path, state=state, **kwargs)
        else:
            loudness = -np.inf
            num_try = 0
            while loudness <= loudness_cutoff:
                excerpt = cls.excerpt(audio_path, state=state, **kwargs)
                loudness = excerpt.loudness()
                num_try += 1
                if num_tries is not None and num_try >= num_tries:
                    break
        return excerpt

    @classmethod
    def zeros(
        cls,
        duration: float,
        sample_rate: int,
        num_channels: int = 1,
        batch_size: int = 1,
        **kwargs,
    ):
        """Helper function to create an AudioSignal of all zeros.

        Parameters
        ----------
        duration : float
            Duration of AudioSignal
        sample_rate : int
            Sample rate of AudioSignal
        num_channels : int, optional
            Number of channels, by default 1
        batch_size : int, optional
            Batch size, by default 1

        Returns
        -------
        AudioSignal
            AudioSignal containing all zeros.

        Examples
        --------
        Generate 5 seconds of all zeros at a sample rate of 44100.

        >>> signal = AudioSignal.zeros(5.0, 44100)
        """
        n_samples = int(duration * sample_rate)
        return cls(
            torch.zeros(batch_size, num_channels, n_samples), sample_rate, **kwargs
        )

    @classmethod
    def wave(
        cls,
        frequency: float,
        duration: float,
        sample_rate: int,
        num_channels: int = 1,
        shape: str = "sine",
        **kwargs,
    ):
        """
        Generate a waveform of a given frequency and shape.

        Parameters
        ----------
        frequency : float
            Frequency of the waveform
        duration : float
            Duration of the waveform
        sample_rate : int
            Sample rate of the waveform
        num_channels : int, optional
            Number of channels, by default 1
        shape : str, optional
            Shape of the waveform, by default "sine".
            One of "sawtooth", "square", "sine", "triangle".
        kwargs : dict
            Keyword arguments to AudioSignal
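
        Examples
        --------
        A minimal usage sketch: generate one second of a 440 Hz sine tone.

        >>> signal = AudioSignal.wave(440, 1.0, 44100)
        >>> signal.duration
        1.0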
| """ | |
| n_samples = int(duration * sample_rate) | |
| t = torch.linspace(0, duration, n_samples) | |
| if shape == "sawtooth": | |
| from scipy.signal import sawtooth | |
| wave_data = sawtooth(2 * np.pi * frequency * t, 0.5) | |
| elif shape == "square": | |
| from scipy.signal import square | |
| wave_data = square(2 * np.pi * frequency * t) | |
| elif shape == "sine": | |
| wave_data = np.sin(2 * np.pi * frequency * t) | |
| elif shape == "triangle": | |
| from scipy.signal import sawtooth | |
| # frequency is doubled by the abs call, so omit the 2 in 2pi | |
| wave_data = sawtooth(np.pi * frequency * t, 0.5) | |
| wave_data = -np.abs(wave_data) * 2 + 1 | |
| else: | |
| raise ValueError(f"Invalid shape {shape}") | |
| wave_data = torch.tensor(wave_data, dtype=torch.float32) | |
| wave_data = wave_data.unsqueeze(0).unsqueeze(0).repeat(1, num_channels, 1) | |
| return cls(wave_data, sample_rate, **kwargs) | |

    @classmethod
    def batch(
        cls,
        audio_signals: list,
        pad_signals: bool = False,
        truncate_signals: bool = False,
        resample: bool = False,
        dim: int = 0,
    ):
        """Creates a batched AudioSignal from a list of AudioSignals.

        Parameters
        ----------
        audio_signals : list[AudioSignal]
            List of AudioSignal objects
        pad_signals : bool, optional
            Whether to pad signals to the length of the longest
            AudioSignal in the list, by default False
        truncate_signals : bool, optional
            Whether to truncate signals to the length of the shortest
            AudioSignal in the list, by default False
        resample : bool, optional
            Whether to resample AudioSignals to the sample rate of
            the first AudioSignal in the list, by default False
        dim : int, optional
            Dimension along which to batch the signals.

        Returns
        -------
        AudioSignal
            Batched AudioSignal.

        Raises
        ------
        RuntimeError
            If not all AudioSignals have the same sample rate, and
            ``resample=False``, an error is raised.
        RuntimeError
            If not all AudioSignals are the same length, and
            both ``pad_signals=False`` and ``truncate_signals=False``,
            an error is raised.

        Examples
        --------
        Batching a bunch of random signals:

        >>> signal_list = [AudioSignal(torch.randn(44100), 44100) for _ in range(10)]
        >>> signal = AudioSignal.batch(signal_list)
        >>> print(signal.shape)
        (10, 1, 44100)
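
        Signals of different lengths can be batched by padding to the
        longest one (a sketch with illustrative lengths):

        >>> signals = [AudioSignal(torch.randn(1, 1, n), 44100) for n in (100, 200)]
        >>> AudioSignal.batch(signals, pad_signals=True).shape[-1]
        200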
| """ | |
| signal_lengths = [x.signal_length for x in audio_signals] | |
| sample_rates = [x.sample_rate for x in audio_signals] | |
| if len(set(sample_rates)) != 1: | |
| if resample: | |
| for x in audio_signals: | |
| x.resample(sample_rates[0]) | |
| else: | |
| raise RuntimeError( | |
| f"Not all signals had the same sample rate! Got {sample_rates}. " | |
| f"All signals must have the same sample rate, or resample must be True. " | |
| ) | |
| if len(set(signal_lengths)) != 1: | |
| if pad_signals: | |
| max_length = max(signal_lengths) | |
| for x in audio_signals: | |
| pad_len = max_length - x.signal_length | |
| x.zero_pad(0, pad_len) | |
| elif truncate_signals: | |
| min_length = min(signal_lengths) | |
| for x in audio_signals: | |
| x.truncate_samples(min_length) | |
| else: | |
| raise RuntimeError( | |
| f"Not all signals had the same length! Got {signal_lengths}. " | |
| f"All signals must be the same length, or pad_signals/truncate_signals " | |
| f"must be True. " | |
| ) | |
| # Concatenate along the specified dimension (default 0) | |
| audio_data = torch.cat([x.audio_data for x in audio_signals], dim=dim) | |
| audio_paths = [x.path_to_file for x in audio_signals] | |
| batched_signal = cls( | |
| audio_data, | |
| sample_rate=audio_signals[0].sample_rate, | |
| ) | |
| batched_signal.path_to_file = audio_paths | |
| return batched_signal | |

    # I/O
    def load_from_file(
        self,
        audio_path: typing.Union[str, Path],
        offset: float,
        duration: float,
        device: str = "cpu",
    ):
        """Loads data from file. Used internally when AudioSignal
        is instantiated with a path to a file.

        Parameters
        ----------
        audio_path : typing.Union[str, Path]
            Path to file
        offset : float
            Offset in seconds
        duration : float
            Duration in seconds
        device : str, optional
            Device to put AudioSignal on, by default "cpu"

        Returns
        -------
        AudioSignal
            AudioSignal loaded from file
        """
        import librosa

        data, sample_rate = librosa.load(
            audio_path,
            offset=offset,
            duration=duration,
            sr=None,
            mono=False,
        )
        data = util.ensure_tensor(data)
        if data.shape[-1] == 0:
            raise RuntimeError(
                f"Audio file {audio_path} with offset {offset} and duration {duration} is empty!"
            )

        if data.ndim < 2:
            data = data.unsqueeze(0)
        if data.ndim < 3:
            data = data.unsqueeze(0)
        self.audio_data = data

        self.original_signal_length = self.signal_length

        self.sample_rate = sample_rate
        self.path_to_file = audio_path
        return self.to(device)

    def load_from_array(
        self,
        audio_array: typing.Union[torch.Tensor, np.ndarray],
        sample_rate: int,
        device: str = "cpu",
    ):
        """Loads data from array, reshaping it to be exactly 3
        dimensions. Used internally when AudioSignal is called
        with a tensor or an array.

        Parameters
        ----------
        audio_array : typing.Union[torch.Tensor, np.ndarray]
            Array/tensor of audio samples.
        sample_rate : int
            Sample rate of audio
        device : str, optional
            Device to move audio onto, by default "cpu"

        Returns
        -------
        AudioSignal
            AudioSignal loaded from array
        """
        audio_data = util.ensure_tensor(audio_array)

        if audio_data.dtype == torch.double:
            audio_data = audio_data.float()

        if audio_data.ndim < 2:
            audio_data = audio_data.unsqueeze(0)
        if audio_data.ndim < 3:
            audio_data = audio_data.unsqueeze(0)
        self.audio_data = audio_data

        self.original_signal_length = self.signal_length

        self.sample_rate = sample_rate
        return self.to(device)

    def write(self, audio_path: typing.Union[str, Path]):
        """Writes audio to a file. Only writes the audio
        that is in the very first item of the batch. To write other items
        in the batch, index the signal along the batch dimension
        before writing. After writing, the signal's ``path_to_file``
        attribute is updated to the new path.

        Parameters
        ----------
        audio_path : typing.Union[str, Path]
            Path to write audio to.

        Returns
        -------
        AudioSignal
            Returns original AudioSignal, so you can use this in a fluent
            interface.

        Examples
        --------
        Creating and writing a signal to disk:

        >>> signal = AudioSignal(torch.randn(10, 1, 44100), 44100)
        >>> signal.write("/tmp/out.wav")

        Writing a different element of the batch:

        >>> signal[5].write("/tmp/out.wav")

        Using this in a fluent interface:

        >>> signal.write("/tmp/original.wav").low_pass(4000).write("/tmp/lowpass.wav")
        """
        if self.audio_data[0].abs().max() > 1:
            warnings.warn("Audio amplitude > 1 clipped when saving")
        soundfile.write(str(audio_path), self.audio_data[0].numpy().T, self.sample_rate)

        self.path_to_file = audio_path
        return self

    def deepcopy(self):
        """Copies the signal and all of its attributes.

        Returns
        -------
        AudioSignal
            Deep copy of the audio signal.
        """
        return copy.deepcopy(self)

    def copy(self):
        """Shallow copy of signal.

        Returns
        -------
        AudioSignal
            Shallow copy of the audio signal.
        """
        return copy.copy(self)

    def clone(self):
        """Clones all tensors contained in the AudioSignal,
        and returns a copy of the signal with everything
        cloned. Useful when using AudioSignal within autograd
        computation graphs.

        Relevant attributes are the stft data, the audio data,
        and the loudness of the file.

        Returns
        -------
        AudioSignal
            Clone of AudioSignal.
        """
        clone = type(self)(
            self.audio_data.clone(),
            self.sample_rate,
            stft_params=self.stft_params,
        )
        if self.stft_data is not None:
            clone.stft_data = self.stft_data.clone()
        if self._loudness is not None:
            clone._loudness = self._loudness.clone()
        clone.path_to_file = copy.deepcopy(self.path_to_file)
        clone.metadata = copy.deepcopy(self.metadata)
        return clone

    def detach(self):
        """Detaches tensors contained in AudioSignal.

        Relevant attributes are the stft data, the audio data,
        and the loudness of the file.

        Returns
        -------
        AudioSignal
            Same signal, but with all tensors detached.
        """
        if self._loudness is not None:
            self._loudness = self._loudness.detach()
        if self.stft_data is not None:
            self.stft_data = self.stft_data.detach()

        self.audio_data = self.audio_data.detach()
        return self

    def hash(self):
        """Writes the audio data to a temporary file, and then
        hashes it using hashlib. Useful for creating a file
        name based on the audio content.

        Returns
        -------
        str
            Hash of audio data.

        Examples
        --------
        Creating a signal, and writing it to a unique file name:

        >>> signal = AudioSignal(torch.randn(44100), 44100)
        >>> hash = signal.hash()
        >>> signal.write(f"{hash}.wav")
        """
        with tempfile.NamedTemporaryFile(suffix=".wav") as f:
            self.write(f.name)
            h = hashlib.sha256()
            b = bytearray(128 * 1024)
            mv = memoryview(b)
            with open(f.name, "rb", buffering=0) as f:
                for n in iter(lambda: f.readinto(mv), 0):
                    h.update(mv[:n])
            file_hash = h.hexdigest()
        return file_hash

    # Signal operations
    def to_mono(self):
        """Converts audio data to mono audio, by taking the mean
        along the channels dimension.

        Returns
        -------
        AudioSignal
            AudioSignal with mean of channels.
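
        Examples
        --------
        A minimal sketch, downmixing a stereo signal:

        >>> signal = AudioSignal(torch.randn(1, 2, 44100), 44100)
        >>> signal.to_mono().num_channels
        1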
| """ | |
| self.audio_data = self.audio_data.mean(1, keepdim=True) | |
| return self | |

    def resample(self, sample_rate: int):
        """Resamples the audio, using sinc interpolation. This works on both
        cpu and gpu, and is much faster on gpu.

        Parameters
        ----------
        sample_rate : int
            Sample rate to resample to.

        Returns
        -------
        AudioSignal
            Resampled AudioSignal
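
        Examples
        --------
        A usage sketch, resampling from 44.1 kHz to 16 kHz:

        >>> signal = AudioSignal(torch.randn(44100), 44100)
        >>> signal.resample(16000).sample_rate
        16000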
| """ | |
| if sample_rate == self.sample_rate: | |
| return self | |
| self.audio_data = julius.resample_frac( | |
| self.audio_data, self.sample_rate, sample_rate | |
| ) | |
| self.sample_rate = sample_rate | |
| return self | |

    # Tensor operations
    def to(self, device: str):
        """Moves all tensors contained in signal to the specified device.

        Parameters
        ----------
        device : str
            Device to move AudioSignal onto. Typical values are
            "cuda", "cpu", or "cuda:n" to specify the nth gpu.

        Returns
        -------
        AudioSignal
            AudioSignal with all tensors moved to specified device.
        """
        if self._loudness is not None:
            self._loudness = self._loudness.to(device)
        if self.stft_data is not None:
            self.stft_data = self.stft_data.to(device)
        if self.audio_data is not None:
            self.audio_data = self.audio_data.to(device)
        return self

    def float(self):
        """Calls ``.float()`` on ``self.audio_data``.

        Returns
        -------
        AudioSignal
        """
        self.audio_data = self.audio_data.float()
        return self

    def cpu(self):
        """Moves AudioSignal to cpu.

        Returns
        -------
        AudioSignal
        """
        return self.to("cpu")

    def cuda(self):  # pragma: no cover
        """Moves AudioSignal to cuda.

        Returns
        -------
        AudioSignal
        """
        return self.to("cuda")

    def numpy(self):
        """Detaches ``self.audio_data``, moves to cpu, and converts to numpy.

        Returns
        -------
        np.ndarray
            Audio data as a numpy array.
        """
        return self.audio_data.detach().cpu().numpy()

    def zero_pad(self, before: int, after: int):
        """Zero pads the audio_data tensor before and after.

        Parameters
        ----------
        before : int
            How many zeros to prepend to audio.
        after : int
            How many zeros to append to audio.

        Returns
        -------
        AudioSignal
            AudioSignal with padding applied.
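
        Examples
        --------
        A sketch with illustrative lengths:

        >>> signal = AudioSignal(torch.zeros(1, 1, 100), 44100)
        >>> signal.zero_pad(10, 20).signal_length
        130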
| """ | |
| self.audio_data = torch.nn.functional.pad(self.audio_data, (before, after)) | |
| return self | |

    def zero_pad_to(self, length: int, mode: str = "after"):
        """Pad with zeros to a specified length, either before or after
        the audio data.

        Parameters
        ----------
        length : int
            Length to pad to
        mode : str, optional
            Whether to prepend or append zeros to signal, by default "after"

        Returns
        -------
        AudioSignal
            AudioSignal with padding applied.
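
        Examples
        --------
        A sketch, padding a 100-sample signal out to 256 samples:

        >>> signal = AudioSignal(torch.zeros(1, 1, 100), 44100)
        >>> signal.zero_pad_to(256).signal_length
        256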
| """ | |
| if mode == "before": | |
| self.zero_pad(max(length - self.signal_length, 0), 0) | |
| elif mode == "after": | |
| self.zero_pad(0, max(length - self.signal_length, 0)) | |
| return self | |

    def trim(self, before: int, after: int):
        """Trims the audio_data tensor before and after.

        Parameters
        ----------
        before : int
            How many samples to trim from beginning.
        after : int
            How many samples to trim from end.

        Returns
        -------
        AudioSignal
            AudioSignal with trimming applied.
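
        Examples
        --------
        A sketch with illustrative lengths:

        >>> signal = AudioSignal(torch.zeros(1, 1, 100), 44100)
        >>> signal.trim(10, 20).signal_length
        70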
| """ | |
| if after == 0: | |
| self.audio_data = self.audio_data[..., before:] | |
| else: | |
| self.audio_data = self.audio_data[..., before:-after] | |
| return self | |

    def truncate_samples(self, length_in_samples: int):
        """Truncate signal to specified length.

        Parameters
        ----------
        length_in_samples : int
            Truncate to this many samples.

        Returns
        -------
        AudioSignal
            AudioSignal with truncation applied.
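
        Examples
        --------
        A sketch, truncating a 100-sample signal to 50 samples:

        >>> signal = AudioSignal(torch.zeros(1, 1, 100), 44100)
        >>> signal.truncate_samples(50).signal_length
        50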
| """ | |
| self.audio_data = self.audio_data[..., :length_in_samples] | |
| return self | |

    @property
    def device(self):
        """Get device that AudioSignal is on.

        Returns
        -------
        torch.device
            Device that AudioSignal is on.
        """
        if self.audio_data is not None:
            device = self.audio_data.device
        elif self.stft_data is not None:
            device = self.stft_data.device
        return device

    # Properties
    @property
    def audio_data(self):
        """Returns the audio data tensor in the object.

        Audio data is always of the shape
        (batch_size, num_channels, num_samples). If value has less
        than 3 dims (e.g. is (num_channels, num_samples)), then it will
        be reshaped to (1, num_channels, num_samples) - a batch size of 1.

        Parameters
        ----------
        data : typing.Union[torch.Tensor, np.ndarray]
            Audio data to set.

        Returns
        -------
        torch.Tensor
            Audio samples.
        """
        return self._audio_data

    @audio_data.setter
    def audio_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
        if data is not None:
            assert torch.is_tensor(data), "audio_data should be torch.Tensor"
            assert data.ndim == 3, "audio_data should be 3-dim (B, C, T)"
        self._audio_data = data
        # Old loudness value not guaranteed to be right, reset it.
        self._loudness = None
        return

    # alias for audio_data
    samples = audio_data

    @property
    def stft_data(self):
        """Returns the STFT data inside the signal. Shape is
        (batch, channels, frequencies, time).

        Returns
        -------
        torch.Tensor
            Complex spectrogram data.
        """
        return self._stft_data

    @stft_data.setter
    def stft_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
        if data is not None:
            assert torch.is_tensor(data) and torch.is_complex(data)
            if self.stft_data is not None and self.stft_data.shape != data.shape:
                warnings.warn("stft_data changed shape")
        self._stft_data = data
        return

    @property
    def batch_size(self):
        """Batch size of audio signal.

        Returns
        -------
        int
            Batch size of signal.
        """
        return self.audio_data.shape[0]

    @property
    def signal_length(self):
        """Length of audio signal.

        Returns
        -------
        int
            Length of signal in samples.
        """
        return self.audio_data.shape[-1]

    # alias for signal_length
    length = signal_length

    @property
    def shape(self):
        """Shape of audio data.

        Returns
        -------
        tuple
            Shape of audio data.
        """
        return self.audio_data.shape

    @property
    def signal_duration(self):
        """Length of audio signal in seconds.

        Returns
        -------
        float
            Length of signal in seconds.
        """
        return self.signal_length / self.sample_rate

    # alias for signal_duration
    duration = signal_duration

    @property
    def num_channels(self):
        """Number of audio channels.

        Returns
        -------
        int
            Number of audio channels.
        """
        return self.audio_data.shape[1]

    # STFT
    @staticmethod
    @functools.lru_cache(None)
    def get_window(window_type: str, window_length: int, device: str):
        """Wrapper around scipy.signal.get_window so one can also get the
        popular sqrt-hann window. This function caches for efficiency
        using functools.lru_cache.

        Parameters
        ----------
        window_type : str
            Type of window to get
        window_length : int
            Length of the window
        device : str
            Device to put window onto.

        Returns
        -------
        torch.Tensor
            Window returned by scipy.signal.get_window, as a tensor.
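
        Examples
        --------
        A sketch, fetching a 2048-sample sqrt-hann window:

        >>> window = AudioSignal.get_window("sqrt_hann", 2048, "cpu")
        >>> window.shape
        torch.Size([2048])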
| """ | |
| from scipy import signal | |
| if window_type == "average": | |
| window = np.ones(window_length) / window_length | |
| elif window_type == "sqrt_hann": | |
| window = np.sqrt(signal.get_window("hann", window_length)) | |
| else: | |
| window = signal.get_window(window_type, window_length) | |
| window = torch.from_numpy(window).to(device).float() | |
| return window | |

    @property
    def stft_params(self):
        """Returns STFTParams object, which can be re-used for other
        AudioSignals.

        This property can be set as well. If values are not defined in STFTParams,
        they are inferred automatically from the signal properties. The default is to use
        32ms windows, with 8ms hop length, and a hann window.

        Returns
        -------
        STFTParams
            STFT parameters for the AudioSignal.

        Examples
        --------
        >>> stft_params = STFTParams(128, 32)
        >>> signal1 = AudioSignal(torch.randn(44100), 44100, stft_params=stft_params)
        >>> signal2 = AudioSignal(torch.randn(44100), 44100, stft_params=signal1.stft_params)
        >>> signal1.stft_params = STFTParams()  # Defaults
        """
        return self._stft_params

    @stft_params.setter
    def stft_params(self, value: STFTParams):
        default_win_len = int(2 ** (np.ceil(np.log2(0.032 * self.sample_rate))))
        default_hop_len = default_win_len // 4
        default_win_type = "hann"
        default_match_stride = False
        default_padding_type = "reflect"

        default_stft_params = STFTParams(
            window_length=default_win_len,
            hop_length=default_hop_len,
            window_type=default_win_type,
            match_stride=default_match_stride,
            padding_type=default_padding_type,
        )._asdict()

        value = value._asdict() if value else default_stft_params

        for key in default_stft_params:
            if value[key] is None:
                value[key] = default_stft_params[key]

        self._stft_params = STFTParams(**value)
        self.stft_data = None

    def compute_stft_padding(
        self, window_length: int, hop_length: int, match_stride: bool
    ):
        """Compute how the STFT should be padded, based on ``match_stride``.

        Parameters
        ----------
        window_length : int
            Window length of STFT.
        hop_length : int
            Hop length of STFT.
        match_stride : bool
            Whether or not to match stride, making the STFT have the same alignment as
            convolutional layers.

        Returns
        -------
        tuple
            Amount to pad on either side of audio.
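
        Examples
        --------
        A sketch with a 2048-sample window and 512-sample hop on a
        one-second, 44.1 kHz signal:

        >>> signal = AudioSignal(torch.randn(44100), 44100)
        >>> signal.compute_stft_padding(2048, 512, match_stride=True)
        (444, 768)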
| """ | |
| length = self.signal_length | |
| if match_stride: | |
| assert ( | |
| hop_length == window_length // 4 | |
| ), "For match_stride, hop must equal n_fft // 4" | |
| right_pad = math.ceil(length / hop_length) * hop_length - length | |
| pad = (window_length - hop_length) // 2 | |
| else: | |
| right_pad = 0 | |
| pad = 0 | |
| return right_pad, pad | |

    def stft(
        self,
        window_length: int = None,
        hop_length: int = None,
        window_type: str = None,
        match_stride: bool = None,
        padding_type: str = None,
    ):
        """Computes the short-time Fourier transform of the audio data,
        with specified STFT parameters.

        Parameters
        ----------
        window_length : int, optional
            Window length of STFT, by default ``0.032 * self.sample_rate``.
        hop_length : int, optional
            Hop length of STFT, by default ``window_length // 4``.
        window_type : str, optional
            Type of window to use, by default ``hann``.
        match_stride : bool, optional
            Whether to match the stride of convolutional layers, by default False
        padding_type : str, optional
            Type of padding to use, by default 'reflect'

        Returns
        -------
        torch.Tensor
            STFT of audio data.

        Examples
        --------
        Compute the STFT of an AudioSignal:

        >>> signal = AudioSignal(torch.randn(44100), 44100)
        >>> signal.stft()

        Vary the window and hop length:

        >>> stft_params = [STFTParams(128, 32), STFTParams(512, 128)]
        >>> for stft_param in stft_params:
        >>>     signal.stft_params = stft_param
        >>>     signal.stft()
        """
        window_length = (
            self.stft_params.window_length
            if window_length is None
            else int(window_length)
        )
        hop_length = (
            self.stft_params.hop_length if hop_length is None else int(hop_length)
        )
        window_type = (
            self.stft_params.window_type if window_type is None else window_type
        )
        match_stride = (
            self.stft_params.match_stride if match_stride is None else match_stride
        )
        padding_type = (
            self.stft_params.padding_type if padding_type is None else padding_type
        )

        window = self.get_window(window_type, window_length, self.audio_data.device)
        window = window.to(self.audio_data.device)

        audio_data = self.audio_data
        right_pad, pad = self.compute_stft_padding(
            window_length, hop_length, match_stride
        )
        audio_data = torch.nn.functional.pad(
            audio_data, (pad, pad + right_pad), padding_type
        )
        stft_data = torch.stft(
            audio_data.reshape(-1, audio_data.shape[-1]),
            n_fft=window_length,
            hop_length=hop_length,
            window=window,
            return_complex=True,
            center=True,
        )
        _, nf, nt = stft_data.shape
        stft_data = stft_data.reshape(self.batch_size, self.num_channels, nf, nt)

        if match_stride:
            # Drop first two and last two frames, which are added
            # because of padding. Now num_frames * hop_length = num_samples.
            stft_data = stft_data[..., 2:-2]
        self.stft_data = stft_data

        return stft_data

    def istft(
        self,
        window_length: int = None,
        hop_length: int = None,
        window_type: str = None,
        match_stride: bool = None,
        length: int = None,
    ):
        """Computes the inverse STFT and sets it to ``audio_data``.

        Parameters
        ----------
        window_length : int, optional
            Window length of STFT, by default ``0.032 * self.sample_rate``.
        hop_length : int, optional
            Hop length of STFT, by default ``window_length // 4``.
        window_type : str, optional
            Type of window to use, by default ``hann``.
        match_stride : bool, optional
            Whether to match the stride of convolutional layers, by default False
        length : int, optional
            Original length of signal, by default None

        Returns
        -------
        AudioSignal
            AudioSignal with istft applied.

        Raises
        ------
        RuntimeError
            Raises an error if stft was not called prior to istft on the signal,
            or if stft_data is not set.
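
        Examples
        --------
        A round-trip sketch: the inverse transform restores the original
        number of samples.

        >>> signal = AudioSignal(torch.randn(44100), 44100)
        >>> _ = signal.stft()
        >>> signal.istft().signal_length
        44100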
| """ | |
| if self.stft_data is None: | |
| raise RuntimeError("Cannot do inverse STFT without self.stft_data!") | |
| window_length = ( | |
| self.stft_params.window_length | |
| if window_length is None | |
| else int(window_length) | |
| ) | |
| hop_length = ( | |
| self.stft_params.hop_length if hop_length is None else int(hop_length) | |
| ) | |
| window_type = ( | |
| self.stft_params.window_type if window_type is None else window_type | |
| ) | |
| match_stride = ( | |
| self.stft_params.match_stride if match_stride is None else match_stride | |
| ) | |
| window = self.get_window(window_type, window_length, self.stft_data.device) | |
| nb, nch, nf, nt = self.stft_data.shape | |
| stft_data = self.stft_data.reshape(nb * nch, nf, nt) | |
| right_pad, pad = self.compute_stft_padding( | |
| window_length, hop_length, match_stride | |
| ) | |
| if length is None: | |
| length = self.original_signal_length | |
| length = length + 2 * pad + right_pad | |
| if match_stride: | |
| # Zero-pad the STFT on either side, putting back the frames that were | |
| # dropped in stft(). | |
| stft_data = torch.nn.functional.pad(stft_data, (2, 2)) | |
| audio_data = torch.istft( | |
| stft_data, | |
| n_fft=window_length, | |
| hop_length=hop_length, | |
| window=window, | |
| length=length, | |
| center=True, | |
| ) | |
| audio_data = audio_data.reshape(nb, nch, -1) | |
| if match_stride: | |
| audio_data = audio_data[..., pad : -(pad + right_pad)] | |
| self.audio_data = audio_data | |
| return self | |

    @staticmethod
    def get_mel_filters(
        sr: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax: float = None
    ):
        """Create a Filterbank matrix to combine FFT bins into Mel-frequency bins.

        Parameters
        ----------
        sr : int
            Sample rate of audio
        n_fft : int
            Number of FFT bins
        n_mels : int
            Number of mels
        fmin : float, optional
            Lowest frequency, in Hz, by default 0.0
        fmax : float, optional
            Highest frequency, by default None

        Returns
        -------
        np.ndarray [shape=(n_mels, 1 + n_fft/2)]
            Mel transform matrix
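
        Examples
        --------
        A sketch: an 80-band filterbank for a 2048-point FFT has shape
        (n_mels, 1 + n_fft/2).

        >>> filters = AudioSignal.get_mel_filters(44100, 2048, 80)
        >>> filters.shape
        (80, 1025)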
| """ | |
| from librosa.filters import mel as librosa_mel_fn | |
| return librosa_mel_fn( | |
| sr=sr, | |
| n_fft=n_fft, | |
| n_mels=n_mels, | |
| fmin=fmin, | |
| fmax=fmax, | |
| ) | |

    def mel_spectrogram(
        self, n_mels: int = 80, mel_fmin: float = 0.0, mel_fmax: float = None, **kwargs
    ):
        """Computes a Mel spectrogram.

        Parameters
        ----------
        n_mels : int, optional
            Number of mels, by default 80
        mel_fmin : float, optional
            Lowest frequency, in Hz, by default 0.0
        mel_fmax : float, optional
            Highest frequency, by default None
        kwargs : dict, optional
            Keyword arguments to self.stft().

        Returns
        -------
        torch.Tensor [shape=(batch, channels, mels, time)]
            Mel spectrogram.
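
        Examples
        --------
        A minimal sketch; the mel dimension matches ``n_mels``:

        >>> signal = AudioSignal(torch.randn(44100), 44100)
        >>> signal.mel_spectrogram(n_mels=80).shape[-2]
        80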
| """ | |
| stft = self.stft(**kwargs) | |
| magnitude = torch.abs(stft) | |
| nf = magnitude.shape[2] | |
| mel_basis = self.get_mel_filters( | |
| sr=self.sample_rate, | |
| n_fft=2 * (nf - 1), | |
| n_mels=n_mels, | |
| fmin=mel_fmin, | |
| fmax=mel_fmax, | |
| ) | |
| mel_basis = torch.from_numpy(mel_basis).to(self.device) | |
| mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T | |
| mel_spectrogram = mel_spectrogram.transpose(-1, 2) | |
| return mel_spectrogram | |

    @staticmethod
    def get_dct(n_mfcc: int, n_mels: int, norm: str = "ortho", device: str = None):
        """Create a discrete cosine transform (DCT) transformation matrix with
        shape (``n_mels``, ``n_mfcc``), normalized depending on ``norm``. For more
        information about the DCT:
        http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II

        Parameters
        ----------
        n_mfcc : int
            Number of mfccs
        n_mels : int
            Number of mels
        norm : str
            Use "ortho" to get an orthogonal matrix, or None, by default "ortho"
        device : str, optional
            Device to load the transformation matrix on, by default None

        Returns
        -------
        torch.Tensor [shape=(n_mels, n_mfcc)]
            The dct transformation matrix.
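
        Examples
        --------
        A sketch: the matrix maps 80 mel bands to 40 coefficients.

        >>> dct = AudioSignal.get_dct(n_mfcc=40, n_mels=80, device="cpu")
        >>> dct.shape
        torch.Size([80, 40])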
| """ | |
| from torchaudio.functional import create_dct | |
| return create_dct(n_mfcc, n_mels, norm).to(device) | |

    def mfcc(
        self, n_mfcc: int = 40, n_mels: int = 80, log_offset: float = 1e-6, **kwargs
    ):
        """Computes mel-frequency cepstral coefficients (MFCCs).

        Parameters
        ----------
        n_mfcc : int, optional
            Number of MFCCs, by default 40
        n_mels : int, optional
            Number of mels, by default 80
        log_offset : float, optional
            Small value to prevent numerical issues when trying to compute log(0), by default 1e-6
        kwargs : dict, optional
            Keyword arguments to self.mel_spectrogram(), note that some of them will be used for self.stft()

        Returns
        -------
        torch.Tensor [shape=(batch, channels, mfccs, time)]
            MFCCs.
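
        Examples
        --------
        A minimal sketch; the coefficient dimension matches ``n_mfcc``:

        >>> signal = AudioSignal(torch.randn(44100), 44100)
        >>> signal.mfcc(n_mfcc=40).shape[-2]
        40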
| """ | |
| mel_spectrogram = self.mel_spectrogram(n_mels, **kwargs) | |
| mel_spectrogram = torch.log(mel_spectrogram + log_offset) | |
| dct_mat = self.get_dct(n_mfcc, n_mels, "ortho", self.device) | |
| mfcc = mel_spectrogram.transpose(-1, -2) @ dct_mat | |
| mfcc = mfcc.transpose(-1, -2) | |
| return mfcc | |

    @property
    def magnitude(self):
        """Computes and returns the absolute value of the STFT, which
        is the magnitude. This value can also be set to some tensor.
        When set, ``self.stft_data`` is manipulated so that its magnitude
        matches what this is set to, and modulated by the phase.

        Returns
        -------
        torch.Tensor
            Magnitude of STFT.

        Examples
        --------
        >>> signal = AudioSignal(torch.randn(44100), 44100)
        >>> magnitude = signal.magnitude  # Computes stft if not computed
        >>> magnitude[magnitude < magnitude.mean()] = 0
        >>> signal.magnitude = magnitude
        >>> signal.istft()
        """
        if self.stft_data is None:
            self.stft()
        return torch.abs(self.stft_data)

    @magnitude.setter
    def magnitude(self, value):
        self.stft_data = value * torch.exp(1j * self.phase)
        return

    def log_magnitude(
        self, ref_value: float = 1.0, amin: float = 1e-5, top_db: float = 80.0
    ):
        """Computes the log-magnitude of the spectrogram.

        Parameters
        ----------
        ref_value : float, optional
            The magnitude is scaled relative to ``ref``: ``20 * log10(S / ref)``.
            Zeros in the output correspond to positions where ``S == ref``,
            by default 1.0
        amin : float, optional
            Minimum threshold for ``S`` and ``ref``, by default 1e-5
        top_db : float, optional
            Threshold the output at ``top_db`` below the peak:
            ``max(10 * log10(S/ref)) - top_db``, by default 80.0

        Returns
        -------
        torch.Tensor
            Log-magnitude spectrogram
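
        Examples
        --------
        A sketch: with the default ``top_db`` of 80, the dynamic range of
        the output is at most 80 dB.

        >>> signal = AudioSignal(torch.randn(44100), 44100)
        >>> log_mag = signal.log_magnitude()
        >>> bool(log_mag.max() - log_mag.min() <= 80.0)
        True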
| """ | |
| magnitude = self.magnitude | |
| amin = amin**2 | |
| log_spec = 10.0 * torch.log10(magnitude.pow(2).clamp(min=amin)) | |
| log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value)) | |
| if top_db is not None: | |
| log_spec = torch.maximum(log_spec, log_spec.max() - top_db) | |
| return log_spec | |

    @property
    def phase(self):
        """Computes and returns the phase of the STFT.
        This value can also be set to some tensor.
        When set, ``self.stft_data`` is manipulated so that its phase
        matches what this is set to, with the original magnitude.

        Returns
        -------
        torch.Tensor
            Phase of STFT.

        Examples
        --------
        >>> signal = AudioSignal(torch.randn(44100), 44100)
        >>> phase = signal.phase  # Computes stft if not computed
        >>> phase[phase < phase.mean()] = 0
        >>> signal.phase = phase
        >>> signal.istft()
        """
        if self.stft_data is None:
            self.stft()
        return torch.angle(self.stft_data)

    @phase.setter
    def phase(self, value):
        self.stft_data = self.magnitude * torch.exp(1j * value)
        return

    # Operator overloading
    def __add__(self, other):
        new_signal = self.clone()
        new_signal.audio_data += util._get_value(other)
        return new_signal

    def __iadd__(self, other):
        self.audio_data += util._get_value(other)
        return self

    def __radd__(self, other):
        return self + other

    def __sub__(self, other):
        new_signal = self.clone()
        new_signal.audio_data -= util._get_value(other)
        return new_signal

    def __isub__(self, other):
        self.audio_data -= util._get_value(other)
        return self

    def __mul__(self, other):
        new_signal = self.clone()
        new_signal.audio_data *= util._get_value(other)
        return new_signal

    def __imul__(self, other):
        self.audio_data *= util._get_value(other)
        return self

    def __rmul__(self, other):
        return self * other

    # Representation
    def _info(self):
        dur = f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]"
        info = {
            "duration": f"{dur} seconds",
            "batch_size": self.batch_size,
            "path": self.path_to_file if self.path_to_file else "path unknown",
            "sample_rate": self.sample_rate,
            "num_channels": self.num_channels if self.num_channels else "[unknown]",
            "audio_data.shape": self.audio_data.shape,
            "stft_params": self.stft_params,
            "device": self.device,
        }
        return info

    def markdown(self):
        """Produces a markdown representation of AudioSignal, in a markdown table.

        Returns
        -------
        str
            Markdown representation of AudioSignal.

        Examples
        --------
        >>> signal = AudioSignal(torch.randn(44100), 44100)
        >>> print(signal.markdown())
        | Key | Value
        |---|---
        | duration | 1.000 seconds |
        | batch_size | 1 |
        | path | path unknown |
        | sample_rate | 44100 |
        | num_channels | 1 |
        | audio_data.shape | torch.Size([1, 1, 44100]) |
        | stft_params | STFTParams(window_length=2048, hop_length=512, window_type='hann', match_stride=False, padding_type='reflect') |
        | device | cpu |
        """
        info = self._info()

        FORMAT = "| Key | Value \n" "|---|--- \n"
        for k, v in info.items():
            row = f"| {k} | {v} |\n"
            FORMAT += row
        return FORMAT

    def __str__(self):
        info = self._info()

        desc = ""
        for k, v in info.items():
            desc += f"{k}: {v}\n"
        return desc

    def __rich__(self):
        from rich.table import Table

        info = self._info()

        table = Table(title=f"{self.__class__.__name__}")
        table.add_column("Key", style="green")
        table.add_column("Value", style="cyan")

        for k, v in info.items():
            table.add_row(k, str(v))
        return table

    # Comparison
    def __eq__(self, other):
        for k, v in list(self.__dict__.items()):
            if torch.is_tensor(v):
                if not torch.allclose(v, other.__dict__[k], atol=1e-6):
                    max_error = (v - other.__dict__[k]).abs().max()
                    print(f"Max abs error for {k}: {max_error}")
                    return False
        return True

    # Indexing
    def __getitem__(self, key):
        if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
            assert self.batch_size == 1
            audio_data = self.audio_data
            _loudness = self._loudness
            stft_data = self.stft_data

        elif isinstance(key, (bool, int, list, slice, tuple)) or (
            torch.is_tensor(key) and key.ndim <= 1
        ):
            # Indexing only on the batch dimension.
            # Then let's copy over relevant stuff.
            # Future work: make this work for time-indexing
            # as well, using the hop length.
            audio_data = self.audio_data[key]
            _loudness = self._loudness[key] if self._loudness is not None else None
            stft_data = self.stft_data[key] if self.stft_data is not None else None

        sources = None

        copy = type(self)(audio_data, self.sample_rate, stft_params=self.stft_params)
        copy._loudness = _loudness
        copy._stft_data = stft_data
        copy.sources = sources

        return copy

    def __setitem__(self, key, value):
        if not isinstance(value, type(self)):
            self.audio_data[key] = value
            return

        if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
            assert self.batch_size == 1
            self.audio_data = value.audio_data
            self._loudness = value._loudness
            self.stft_data = value.stft_data
            return
        elif isinstance(key, (bool, int, list, slice, tuple)) or (
            torch.is_tensor(key) and key.ndim <= 1
        ):
            if self.audio_data is not None and value.audio_data is not None:
                self.audio_data[key] = value.audio_data
            if self._loudness is not None and value._loudness is not None:
                self._loudness[key] = value._loudness
            if self.stft_data is not None and value.stft_data is not None:
                self.stft_data[key] = value.stft_data
            return

    def __ne__(self, other):
        return not self == other