Alexandre Défossez
committed on

Improve demo (#51)

* allowing sharing directly, changelog, reduce volume.
* activate
* plop

- CHANGELOG.md +11 -2
- README.md +1 -1
- app.py +11 -9
- app_batched.py +3 -1
- audiocraft/__init__.py +1 -1
- audiocraft/data/audio.py +3 -1
- audiocraft/data/audio_utils.py +9 -4
- audiocraft/models/musicgen.py +2 -0
- audiocraft/modules/conditioners.py +6 -2
    	
CHANGELOG.md CHANGED

````diff
@@ -4,6 +4,15 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
-## [0.0.1] - TBD
+## [0.0.2a] - TBD
 
-Initial release, with model evaluation only.
+Improved demo, fixed top p (thanks @jnordberg).
+
+Compressor tanh on output to avoid clipping with some style (especially piano).
+Now repeating the conditioning periodically if it is too short.
+
+More options when launching Gradio app locally (thanks @ashleykleynhans).
+
+## [0.0.1] - 2023-06-09
+
+Initial release, with model evaluation only.
````
README.md CHANGED

````diff
@@ -80,7 +80,7 @@ wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), s
 
 for idx, one_wav in enumerate(wav):
     # Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
-    audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness")
+    audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
 ```
 
 
````
app.py CHANGED

````diff
@@ -13,7 +13,6 @@ import gradio as gr
 from audiocraft.models import MusicGen
 from audiocraft.data.audio import audio_write
 
-
 MODEL = None
 
 
@@ -56,7 +55,9 @@ def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
 
     output = output.detach().cpu().float()[0]
     with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
-        audio_write(file.name, output, MODEL.sample_rate, strategy="loudness", add_suffix=False)
+        audio_write(
+            file.name, output, MODEL.sample_rate, strategy="loudness",
+            loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
         waveform_video = gr.make_waveform(file.name)
     return waveform_video
 
@@ -66,7 +67,7 @@ def ui(**kwargs):
         gr.Markdown(
             """
             # MusicGen
-
+
             This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
             presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
             <br/>
@@ -129,19 +130,19 @@ def ui(**kwargs):
         gr.Markdown(
             """
             ### More details
-
+
             The model will generate a short music extract based on the description you provided.
             You can generate up to 30 seconds of audio.
-
+
             We present 4 model variations:
             1. Melody -- a music generation model capable of generating music conditioned on text and melody inputs. **Note**, you can also use text only.
             2. Small -- a 300M transformer decoder conditioned on text only.
             3. Medium -- a 1.5B transformer decoder conditioned on text only.
             4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences).
-
+
             When using `melody`, you can optionally provide a reference audio from
             which a broad melody will be extracted. The model will then try to follow both the description and melody provided.
-
+
             You can also use your own GPU or a Google Colab by following the instructions on our repo.
             See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
             for more details.
@@ -168,7 +169,8 @@ def ui(**kwargs):
         if share:
             launch_kwargs['share'] = share
 
-    interface.launch(**launch_kwargs)
+    interface.queue().launch(**launch_kwargs, max_threads=1)
+
 
 if __name__ == "__main__":
     # torch.cuda.set_per_process_memory_fraction(0.48)
@@ -207,4 +209,4 @@ if __name__ == "__main__":
         server_port=args.server_port,
         share=args.share,
         listen=args.listen
-    )
+    )
````
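The launch change is the mechanism behind "allowing sharing directly": the app now queues requests and forwards the CLI flags into Gradio's `launch()`. A minimal standalone sketch of the same pattern; `demo_fn` and the flag values here are placeholders, not the demo's real code:

```python
import gradio as gr

def demo_fn(text: str) -> str:
    # Placeholder for a long-running generation call (e.g. MusicGen predict).
    return f"generated: {text}"

with gr.Blocks() as interface:
    prompt = gr.Textbox(label="Prompt")
    result = gr.Textbox(label="Result")
    gr.Button("Generate").click(demo_fn, inputs=prompt, outputs=result)

launch_kwargs = {
    "server_name": "0.0.0.0",  # what a --listen flag typically maps to
    "server_port": 7860,       # --server_port
    "share": True,             # ask Gradio for a public share link
}
# queue() serializes requests so long generations don't hit connection
# timeouts; max_threads=1 keeps a single worker, matching a single-GPU demo.
interface.queue().launch(**launch_kwargs, max_threads=1)
```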
    	
app_batched.py CHANGED

````diff
@@ -57,7 +57,9 @@ def predict(texts, melodies):
     out_files = []
     for output in outputs:
         with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
-            audio_write(file.name, output, MODEL.sample_rate, strategy="loudness", add_suffix=False)
+            audio_write(
+                file.name, output, MODEL.sample_rate, strategy="loudness",
+                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
             waveform_video = gr.make_waveform(file.name)
             out_files.append(waveform_video)
     return [out_files]
````
    	
audiocraft/__init__.py CHANGED

````diff
@@ -7,4 +7,4 @@
 # flake8: noqa
 from . import data, modules, models
 
-__version__ = '0.0.1'
+__version__ = '0.0.2a1'
````
    	
audiocraft/data/audio.py CHANGED

````diff
@@ -155,6 +155,7 @@ def audio_write(stem_name: tp.Union[str, Path],
                 format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
                 strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                 rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+                loudness_compressor: bool = False,
                 log_clipping: bool = True, make_parent_dir: bool = True,
                 add_suffix: bool = True) -> Path:
     """Convenience function for saving audio to disk. Returns the filename the audio was written to.
@@ -173,7 +174,8 @@ def audio_write(stem_name: tp.Union[str, Path],
         rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
             than the `peak_clip` one to avoid further clipping.
         loudness_headroom_db (float): Target loudness for loudness normalization.
-        log_clipping (bool): If True, basic logging on stderr when clipping still
+        loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
+        log_clipping (bool): If True, basic logging on stderr when clipping still
             occurs despite strategy (only for 'rms').
         make_parent_dir (bool): Make parent directory if it doesn't exist.
     Returns:
````
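For reference, a minimal usage sketch of the extended signature, on a synthetic input; the parameters are exactly the ones documented above:

```python
import math
import torch
from audiocraft.data.audio import audio_write

# One second of a deliberately hot 440 Hz sine at 32 kHz, shaped [channels, time].
sr = 32000
t = torch.arange(sr) / sr
wav = 0.99 * torch.sin(2 * math.pi * 440 * t).unsqueeze(0)

# Writes sine.wav, loudness-normalized to -14 dB LUFS; the new tanh soft
# clipping catches any peaks the loudness gain pushes past full scale.
audio_write('sine', wav, sr, strategy="loudness",
            loudness_headroom_db=14, loudness_compressor=True)
```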
    	
audiocraft/data/audio_utils.py CHANGED

````diff
@@ -54,8 +54,8 @@ def convert_audio(wav: torch.Tensor, from_rate: float,
     return wav
 
 
-def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
-                       energy_floor: float = 2e-3):
+def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
+                       loudness_compressor: bool = False, energy_floor: float = 2e-3):
     """Normalize an input signal to a user loudness in dB LKFS.
     Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
 
@@ -63,6 +63,7 @@ def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db
         wav (torch.Tensor): Input multichannel audio data.
         sample_rate (int): Sample rate.
         loudness_headroom_db (float): Target loudness of the output in dB LUFS.
+        loudness_compressor (bool): Uses tanh for soft clipping.
         energy_floor (float): anything below that RMS level will not be rescaled.
     Returns:
         output (torch.Tensor): Loudness normalized output data.
@@ -76,6 +77,8 @@ def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db
     delta_loudness = -loudness_headroom_db - input_loudness_db
     gain = 10.0 ** (delta_loudness / 20.0)
     output = gain * wav
+    if loudness_compressor:
+        output = torch.tanh(output)
     assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
     return output
 
@@ -93,7 +96,8 @@ def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None):
 def normalize_audio(wav: torch.Tensor, normalize: bool = True,
                     strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                     rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
-                    log_clipping: bool = False, sample_rate: tp.Optional[int] = None,
+                    loudness_compressor: bool = False, log_clipping: bool = False,
+                    sample_rate: tp.Optional[int] = None,
                     stem_name: tp.Optional[str] = None) -> torch.Tensor:
     """Normalize the audio according to the prescribed strategy (see after).
 
@@ -109,6 +113,7 @@ def normalize_audio(wav: torch.Tensor, normalize: bool = True,
         rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
             than the `peak_clip` one to avoid further clipping.
         loudness_headroom_db (float): Target loudness for loudness normalization.
+        loudness_compressor (bool): If True, uses tanh based soft clipping.
         log_clipping (bool): If True, basic logging on stderr when clipping still
             occurs despite strategy (only for 'rms').
         sample_rate (int): Sample rate for the audio data (required for loudness).
@@ -132,7 +137,7 @@ def normalize_audio(wav: torch.Tensor, normalize: bool = True,
         _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
     elif strategy == 'loudness':
         assert sample_rate is not None, "Loudness normalization requires sample rate."
-        wav = normalize_loudness(wav, sample_rate, loudness_headroom_db)
+        wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
         _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
     else:
         assert wav.abs().max() < 1
````
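The arithmetic here is a plain dB-domain gain: with a measured integrated loudness of, say, -20 dB LUFS and a 14 dB target, delta_loudness = -14 - (-20) = 6 dB, i.e. a linear gain of 10^(6/20) ≈ 2.0. A self-contained sketch of that step plus the new tanh stage; the measured loudness is hard-coded here instead of calling a BS.1770 loudness meter:

```python
import torch

def loudness_gain_demo(wav: torch.Tensor,
                       input_loudness_db: float = -20.0,
                       loudness_headroom_db: float = 14.0,
                       loudness_compressor: bool = False) -> torch.Tensor:
    # Same gain computation as normalize_loudness above.
    delta_loudness = -loudness_headroom_db - input_loudness_db  # -14 - (-20) = 6 dB
    gain = 10.0 ** (delta_loudness / 20.0)                      # 10**(6/20) ~= 2.0
    output = gain * wav
    if loudness_compressor:
        # tanh is ~linear for small samples but saturates smoothly toward
        # +/-1, so boosted peaks are soft-clipped instead of hard-clipped.
        output = torch.tanh(output)
    return output

wav = 0.6 * torch.randn(1, 32000)  # synthetic mono signal with hot peaks
print(loudness_gain_demo(wav).abs().max().item())                             # can exceed 1.0
print(loudness_gain_demo(wav, loudness_compressor=True).abs().max().item())   # always < 1.0
```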
    	
audiocraft/models/musicgen.py CHANGED

````diff
@@ -88,6 +88,8 @@ class MusicGen:
         cache_dir = os.environ.get('MUSICGEN_ROOT', None)
         compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
         lm = load_lm_model(name, device=device, cache_dir=cache_dir)
+        if name == 'melody':
+            lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
 
         return MusicGen(name, compression_model, lm)
 
````
    	
audiocraft/modules/conditioners.py CHANGED

````diff
@@ -9,6 +9,7 @@ from copy import deepcopy
 from dataclasses import dataclass, field
 from itertools import chain
 import logging
+import math
 import random
 import re
 import typing as tp
@@ -484,7 +485,7 @@ class ChromaStemConditioner(WaveformConditioner):
         **kwargs: Additional parameters for the chroma extractor.
     """
     def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
-                 duration: float, match_len_on_eval: bool = False, eval_wavs: tp.Optional[str] = None,
+                 duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
                  n_eval_wavs: int = 0, device: tp.Union[torch.device, str] = "cpu", **kwargs):
         from demucs import pretrained
         super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
@@ -535,7 +536,10 @@ class ChromaStemConditioner(WaveformConditioner):
             chroma = chroma[:, :self.chroma_len]
             logger.debug(f'chroma was truncated! ({t} -> {chroma.shape[1]})')
         elif t < self.chroma_len:
-            chroma = F.pad(chroma, (0, 0, 0, self.chroma_len - t))
+            # chroma = F.pad(chroma, (0, 0, 0, self.chroma_len - t))
+            n_repeat = int(math.ceil(self.chroma_len / t))
+            chroma = chroma.repeat(1, n_repeat, 1)
+            chroma = chroma[:, :self.chroma_len]
             logger.debug(f'chroma was zero-padded! ({t} -> {chroma.shape[1]})')
         return chroma
 
````
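This is the "repeating the conditioning periodically" change from the changelog, and it is easiest to see on a dummy chroma tensor shaped [batch, time, n_chroma]: a melody shorter than chroma_len is now tiled along the time axis and truncated, instead of zero-padded. A standalone sketch with made-up sizes:

```python
import math
import torch

chroma_len = 235                   # frames the model expects (made-up value)
chroma = torch.randn(2, 100, 12)   # [batch, t=100 frames, n_chroma=12]

t = chroma.shape[1]
if t < chroma_len:
    # Tile the short conditioning along time, then cut to the exact length,
    # mirroring the new branch in ChromaStemConditioner above.
    n_repeat = int(math.ceil(chroma_len / t))
    chroma = chroma.repeat(1, n_repeat, 1)
    chroma = chroma[:, :chroma_len]

print(chroma.shape)  # torch.Size([2, 235, 12])
```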
