Spaces:
Running
on
Zero
Running
on
Zero
| import inspect | |
| import typing | |
| from functools import wraps | |
| from . import util | |
def format_figure(func):
    """Decorator for formatting figures produced by the code below.

    See :py:func:`audiotools.core.util.format_figure` for more.

    Keyword arguments whose names match parameters of
    ``util.format_figure`` are removed from the call to ``func`` and are
    passed to ``util.format_figure`` instead, which is invoked after ``func``
    returns. All other arguments are forwarded to ``func`` unchanged.

    Parameters
    ----------
    func : Callable
        Plotting function that is decorated by this function.

    Returns
    -------
    Callable
        The wrapped function. It keeps ``func``'s name, docstring and
        signature metadata (via ``functools.wraps``) and returns whatever
        ``func`` returns.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Resolved at call time rather than at decoration time.
        format_keys = inspect.signature(util.format_figure).parameters.keys()
        format_kwargs = {k: v for k, v in kwargs.items() if k in format_keys}
        func_kwargs = {k: v for k, v in kwargs.items() if k not in format_keys}
        result = func(*args, **func_kwargs)
        util.format_figure(**format_kwargs)
        return result

    return wrapper


class DisplayMixin:
    @format_figure
    def specshow(
        self,
        preemphasis: bool = False,
        x_axis: str = "time",
        y_axis: str = "linear",
        n_mels: int = 128,
        **kwargs,
    ):
        """Displays a spectrogram, using ``librosa.display.specshow``.

        Parameters
        ----------
        preemphasis : bool, optional
            Whether or not to apply preemphasis, which makes high
            frequency detail easier to see, by default False
        x_axis : str, optional
            How to label the x axis, by default "time"
        y_axis : str, optional
            How to label the y axis, by default "linear"
        n_mels : int, optional
            If displaying a mel spectrogram with ``y_axis = "mel"``,
            this controls the number of mels, by default 128.
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
            Keywords not accepted by ``format_figure`` are passed on to
            ``librosa.display.specshow``.
        """
        import librosa
        import librosa.display

        # Always re-compute the STFT data before showing it, in case
        # it changed.
        signal = self.clone()
        signal.stft_data = None

        if preemphasis:
            signal.preemphasis()

        ref = signal.magnitude.max()
        log_mag = signal.log_magnitude(ref_value=ref)

        if y_axis == "mel":
            log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10()
            log_mag -= log_mag.max()

        # detach()/cpu() so tensors that require grad or live on an
        # accelerator can still be converted to NumPy. Then take the first
        # batch item and average over its leading axis (presumably channels).
        spec = log_mag.detach().cpu().numpy()[0].mean(axis=0)

        librosa.display.specshow(
            spec,
            x_axis=x_axis,
            y_axis=y_axis,
            sr=signal.sample_rate,
            **kwargs,
        )

    @format_figure
    def waveplot(self, x_axis: str = "time", **kwargs):
        """Displays a waveform plot, using ``librosa.display.waveshow``.

        Parameters
        ----------
        x_axis : str, optional
            How to label the x axis, by default "time"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
            Keywords not accepted by ``format_figure`` are passed on to the
            librosa waveform plotting function.
        """
        import librosa
        import librosa.display

        # First batch item, averaged over its leading axis (presumably
        # channels) to get one waveform; detach()/cpu() allow the NumPy
        # conversion for grad-tracking or accelerator tensors.
        audio_data = self.audio_data[0].mean(dim=0)
        audio_data = audio_data.detach().cpu().numpy()

        # Fall back to ``waveplot`` on librosa versions without ``waveshow``.
        plot_fn = "waveshow" if hasattr(librosa.display, "waveshow") else "waveplot"
        wave_plot_fn = getattr(librosa.display, plot_fn)
        wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs)

    @format_figure
    def wavespec(self, x_axis: str = "time", **kwargs):
        """Displays a waveform plot stacked above a spectrogram.

        The waveform (:py:func:`waveplot`) occupies the top sixth of the
        current figure and the spectrogram (:py:func:`specshow`) the
        remaining five sixths.

        Parameters
        ----------
        x_axis : str, optional
            How to label the x axis, by default "time"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`.
        """
        import matplotlib.pyplot as plt
        from matplotlib.gridspec import GridSpec

        gs = GridSpec(6, 1)
        plt.subplot(gs[0, :])
        self.waveplot(x_axis=x_axis)
        plt.subplot(gs[1:, :])
        self.specshow(x_axis=x_axis, **kwargs)

    def write_audio_to_tb(
        self,
        tag: str,
        writer,
        step: int = None,
        plot_fn: typing.Union[typing.Callable, str] = "specshow",
        **kwargs,
    ):
        """Writes a signal and its spectrogram to Tensorboard. Will show up
        under the Audio and Images tab in Tensorboard.

        Parameters
        ----------
        tag : str
            Tag to write signal to (e.g. ``clean/sample_0.wav``). The image is
            written under the same tag with its last ``"wav"`` replaced by
            ``"png"`` (e.g. ``clean/sample_0.png``).
        writer : SummaryWriter
            A SummaryWriter object from PyTorch library.
        step : int, optional
            The step to write the signal to, by default None
        plot_fn : typing.Union[typing.Callable, str], optional
            How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
            whatever ``plot_fn`` is set to.
        """
        # Only the first batch item's first channel is logged.
        audio_data = self.audio_data[0, 0].detach().cpu()
        sample_rate = self.sample_rate
        writer.add_audio(tag, audio_data, step, sample_rate)

        if plot_fn is None:
            return

        # Imported lazily so logging audio alone does not require matplotlib.
        import matplotlib.pyplot as plt

        if isinstance(plot_fn, str):
            plot_fn = getattr(self, plot_fn)

        # Replace only the last "wav" (normally the extension) so that tags
        # like "waveform/sample_0.wav" keep the rest of their path intact.
        head, sep, tail = tag.rpartition("wav")
        image_tag = f"{head}png{tail}" if sep else tag

        fig = plt.figure()
        try:
            plot_fn(**kwargs)
            writer.add_figure(image_tag, fig, step)
        finally:
            # Do not leak the figure if plotting or logging raises.
            plt.close(fig)

    def save_image(
        self,
        image_path: str,
        plot_fn: typing.Union[typing.Callable, str] = "specshow",
        **kwargs,
    ):
        """Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to
        a specified file.

        Parameters
        ----------
        image_path : str
            Where to save the file to.
        plot_fn : typing.Union[typing.Callable, str], optional
            How to create the image. Set to ``None`` to avoid plotting, in
            which case no file is written, by default "specshow"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
            whatever ``plot_fn`` is set to.
        """
        # Honor the documented ``plot_fn=None``: nothing to draw, nothing saved.
        if plot_fn is None:
            return

        import matplotlib.pyplot as plt

        if isinstance(plot_fn, str):
            plot_fn = getattr(self, plot_fn)

        plt.clf()
        try:
            plot_fn(**kwargs)
            plt.savefig(image_path, bbox_inches="tight", pad_inches=0)
        finally:
            plt.close()