add encoder

- encoder/__init__.py +0 -0
- encoder/__pycache__/__init__.cpython-37.pyc +0 -0
- encoder/__pycache__/audio.cpython-37.pyc +0 -0
- encoder/__pycache__/inference.cpython-37.pyc +0 -0
- encoder/__pycache__/model.cpython-37.pyc +0 -0
- encoder/__pycache__/params_data.cpython-37.pyc +0 -0
- encoder/__pycache__/params_model.cpython-37.pyc +0 -0
- encoder/audio.py +117 -0
- encoder/config.py +45 -0
- encoder/data_objects/__init__.py +2 -0
- encoder/data_objects/random_cycler.py +37 -0
- encoder/data_objects/speaker.py +40 -0
- encoder/data_objects/speaker_batch.py +13 -0
- encoder/data_objects/speaker_verification_dataset.py +56 -0
- encoder/data_objects/utterance.py +26 -0
- encoder/inference.py +178 -0
- encoder/model.py +135 -0
- encoder/params_data.py +29 -0
- encoder/params_model.py +11 -0
- encoder/preprocess.py +184 -0
- encoder/train.py +125 -0
- encoder/visualizations.py +179 -0
encoder/__init__.py ADDED
    File without changes

encoder/__pycache__/__init__.cpython-37.pyc ADDED
    Binary file (159 Bytes)

encoder/__pycache__/audio.cpython-37.pyc ADDED
    Binary file (3.97 kB)

encoder/__pycache__/inference.cpython-37.pyc ADDED
    Binary file (7.17 kB)

encoder/__pycache__/model.cpython-37.pyc ADDED
    Binary file (4.77 kB)

encoder/__pycache__/params_data.cpython-37.pyc ADDED
    Binary file (466 Bytes)

encoder/__pycache__/params_model.cpython-37.pyc ADDED
    Binary file (346 Bytes)
    	
encoder/audio.py ADDED
@@ -0,0 +1,117 @@
+from scipy.ndimage.morphology import binary_dilation
+from encoder.params_data import *
+from pathlib import Path
+from typing import Optional, Union
+from warnings import warn
+import numpy as np
+import librosa
+import struct
+
+try:
+    import webrtcvad
+except:
+    warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
+    webrtcvad=None
+
+int16_max = (2 ** 15) - 1
+
+
+def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
+                   source_sr: Optional[int] = None,
+                   normalize: Optional[bool] = True,
+                   trim_silence: Optional[bool] = True):
+    """
+    Applies the preprocessing operations used in training the Speaker Encoder to a waveform
+    either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
+
+    :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
+    just .wav), either the waveform as a numpy array of floats.
+    :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
+    preprocessing. After preprocessing, the waveform's sampling rate will match the data
+    hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
+    this argument will be ignored.
+    """
+    # Load the wav from disk if needed
+    if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
+        wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
+    else:
+        wav = fpath_or_wav
+
+    # Resample the wav if needed
+    if source_sr is not None and source_sr != sampling_rate:
+        wav = librosa.resample(wav, source_sr, sampling_rate)
+
+    # Apply the preprocessing: normalize volume and shorten long silences
+    if normalize:
+        wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
+    if webrtcvad and trim_silence:
+        wav = trim_long_silences(wav)
+
+    return wav
+
+
+def wav_to_mel_spectrogram(wav):
+    """
+    Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
+    Note: this not a log-mel spectrogram.
+    """
+    frames = librosa.feature.melspectrogram(
+        wav,
+        sampling_rate,
+        n_fft=int(sampling_rate * mel_window_length / 1000),
+        hop_length=int(sampling_rate * mel_window_step / 1000),
+        n_mels=mel_n_channels
+    )
+    return frames.astype(np.float32).T
+
+
+def trim_long_silences(wav):
+    """
+    Ensures that segments without voice in the waveform remain no longer than a
+    threshold determined by the VAD parameters in params.py.
+
+    :param wav: the raw waveform as a numpy array of floats
+    :return: the same waveform with silences trimmed away (length <= original wav length)
+    """
+    # Compute the voice detection window size
+    samples_per_window = (vad_window_length * sampling_rate) // 1000
+
+    # Trim the end of the audio to have a multiple of the window size
+    wav = wav[:len(wav) - (len(wav) % samples_per_window)]
+
+    # Convert the float waveform to 16-bit mono PCM
+    pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
+
+    # Perform voice activation detection
+    voice_flags = []
+    vad = webrtcvad.Vad(mode=3)
+    for window_start in range(0, len(wav), samples_per_window):
+        window_end = window_start + samples_per_window
+        voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
+                                         sample_rate=sampling_rate))
+    voice_flags = np.array(voice_flags)
+
+    # Smooth the voice detection with a moving average
+    def moving_average(array, width):
+        array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
+        ret = np.cumsum(array_padded, dtype=float)
+        ret[width:] = ret[width:] - ret[:-width]
+        return ret[width - 1:] / width
+
+    audio_mask = moving_average(voice_flags, vad_moving_average_width)
+    audio_mask = np.round(audio_mask).astype(np.bool)
+
+    # Dilate the voiced regions
+    audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
+    audio_mask = np.repeat(audio_mask, samples_per_window)
+
+    return wav[audio_mask == True]
+
+
+def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
+    if increase_only and decrease_only:
+        raise ValueError("Both increase only and decrease only are set")
+    dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
+    if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
+        return wav
+    return wav * (10 ** (dBFS_change / 20))
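
Editor's note: a minimal usage sketch for this module, not part of the commit. It assumes the encoder package from this commit is importable with its dependencies installed, and "speech.wav" is a hypothetical input file.

# Usage sketch (not part of the commit); "speech.wav" is a placeholder input file.
from encoder import audio

# Load, resample to the training sampling rate, normalize volume and trim long silences
wav = audio.preprocess_wav("speech.wav")

# Convert to the (n_frames, mel_n_channels) mel frames the encoder consumes
mel = audio.wav_to_mel_spectrogram(wav)
print(wav.shape, mel.shape)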
    	
encoder/config.py ADDED
@@ -0,0 +1,45 @@
+librispeech_datasets = {
+    "train": {
+        "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
+        "other": ["LibriSpeech/train-other-500"]
+    },
+    "test": {
+        "clean": ["LibriSpeech/test-clean"],
+        "other": ["LibriSpeech/test-other"]
+    },
+    "dev": {
+        "clean": ["LibriSpeech/dev-clean"],
+        "other": ["LibriSpeech/dev-other"]
+    },
+}
+libritts_datasets = {
+    "train": {
+        "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
+        "other": ["LibriTTS/train-other-500"]
+    },
+    "test": {
+        "clean": ["LibriTTS/test-clean"],
+        "other": ["LibriTTS/test-other"]
+    },
+    "dev": {
+        "clean": ["LibriTTS/dev-clean"],
+        "other": ["LibriTTS/dev-other"]
+    },
+}
+voxceleb_datasets = {
+    "voxceleb1" : {
+        "train": ["VoxCeleb1/wav"],
+        "test": ["VoxCeleb1/test_wav"]
+    },
+    "voxceleb2" : {
+        "train": ["VoxCeleb2/dev/aac"],
+        "test": ["VoxCeleb2/test_wav"]
+    }
+}
+
+other_datasets = [
+    "LJSpeech-1.1",
+    "VCTK-Corpus/wav48",
+]
+
+anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
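
Editor's note: the entries above are dataset paths relative to a datasets root. A short sketch, not part of the commit, of how they would typically be resolved; the root path is hypothetical.

# Sketch (not part of the commit): join the relative paths above to a hypothetical datasets root.
from pathlib import Path
from encoder.config import librispeech_datasets

datasets_root = Path("/path/to/datasets")   # hypothetical location of the raw datasets
train_clean_dirs = [datasets_root.joinpath(d) for d in librispeech_datasets["train"]["clean"]]
print(train_clean_dirs)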
    	
encoder/data_objects/__init__.py ADDED
@@ -0,0 +1,2 @@
+from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
+from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
    	
encoder/data_objects/random_cycler.py ADDED
@@ -0,0 +1,37 @@
+import random
+
+class RandomCycler:
+    """
+    Creates an internal copy of a sequence and allows access to its items in a constrained random
+    order. For a source sequence of n items and one or several consecutive queries of a total
+    of m items, the following guarantees hold (one implies the other):
+        - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
+        - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
+    """
+
+    def __init__(self, source):
+        if len(source) == 0:
+            raise Exception("Can't create RandomCycler from an empty collection")
+        self.all_items = list(source)
+        self.next_items = []
+
+    def sample(self, count: int):
+        shuffle = lambda l: random.sample(l, len(l))
+
+        out = []
+        while count > 0:
+            if count >= len(self.all_items):
+                out.extend(shuffle(list(self.all_items)))
+                count -= len(self.all_items)
+                continue
+            n = min(count, len(self.next_items))
+            out.extend(self.next_items[:n])
+            count -= n
+            self.next_items = self.next_items[n:]
+            if len(self.next_items) == 0:
+                self.next_items = shuffle(list(self.all_items))
+        return out
+
+    def __next__(self):
+        return self.sample(1)[0]
+
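
Editor's note: a small sketch, not part of the commit, illustrating the guarantee documented in the class docstring: with n = 3 items and m = 9 draws, each item must come up exactly m // n = 3 times.

# Sketch (not part of the commit): the constrained-random sampling guarantee in action.
from collections import Counter
from encoder.data_objects.random_cycler import RandomCycler

cycler = RandomCycler(["a", "b", "c"])
draws = cycler.sample(9)
print(Counter(draws))   # each of 'a', 'b', 'c' appears exactly 3 times, in shuffled order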
    	
encoder/data_objects/speaker.py ADDED
@@ -0,0 +1,40 @@
+from encoder.data_objects.random_cycler import RandomCycler
+from encoder.data_objects.utterance import Utterance
+from pathlib import Path
+
+# Contains the set of utterances of a single speaker
+class Speaker:
+    def __init__(self, root: Path):
+        self.root = root
+        self.name = root.name
+        self.utterances = None
+        self.utterance_cycler = None
+
+    def _load_utterances(self):
+        with self.root.joinpath("_sources.txt").open("r") as sources_file:
+            sources = [l.split(",") for l in sources_file]
+        sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
+        self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
+        self.utterance_cycler = RandomCycler(self.utterances)
+
+    def random_partial(self, count, n_frames):
+        """
+        Samples a batch of <count> unique partial utterances from the disk in a way that all
+        utterances come up at least once every two cycles and in a random order every time.
+
+        :param count: The number of partial utterances to sample from the set of utterances from
+        that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
+        the number of utterances available.
+        :param n_frames: The number of frames in the partial utterance.
+        :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
+        frames are the frames of the partial utterances and range is the range of the partial
+        utterance with regard to the complete utterance.
+        """
+        if self.utterances is None:
+            self._load_utterances()
+
+        utterances = self.utterance_cycler.sample(count)
+
+        a = [(u,) + u.random_partial(n_frames) for u in utterances]
+
+        return a
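
Editor's note: a sketch, not part of the commit, of the on-disk layout _load_utterances() expects (a _sources.txt file mapping .npy mel-frame files to their source audio, as also implied by utterance.py below). The file names, the 40-channel frame array and the wave name are hypothetical.

# Sketch (not part of the commit): build a minimal speaker directory and sample partials from it.
import numpy as np
from pathlib import Path
from tempfile import TemporaryDirectory
from encoder.data_objects.speaker import Speaker

with TemporaryDirectory() as tmp:
    root = Path(tmp, "speaker_0001")
    root.mkdir()
    # One fake utterance: 300 mel frames of 40 channels, plus the _sources.txt entry for it
    np.save(root.joinpath("utt_0.npy"), np.random.rand(300, 40).astype(np.float32))
    root.joinpath("_sources.txt").write_text("utt_0.npy,utt_0.flac\n")

    speaker = Speaker(root)
    partials = speaker.random_partial(count=2, n_frames=160)
    for utterance, frames, (start, end) in partials:
        print(frames.shape, (start, end))   # (160, 40) and the crop range within the utterance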
    	
encoder/data_objects/speaker_batch.py ADDED
@@ -0,0 +1,13 @@
+import numpy as np
+from typing import List
+from encoder.data_objects.speaker import Speaker
+
+
+class SpeakerBatch:
+    def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
+        self.speakers = speakers
+        self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
+
+        # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
+        # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
+        self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
    	
encoder/data_objects/speaker_verification_dataset.py ADDED
@@ -0,0 +1,56 @@
+from encoder.data_objects.random_cycler import RandomCycler
+from encoder.data_objects.speaker_batch import SpeakerBatch
+from encoder.data_objects.speaker import Speaker
+from encoder.params_data import partials_n_frames
+from torch.utils.data import Dataset, DataLoader
+from pathlib import Path
+
+# TODO: improve with a pool of speakers for data efficiency
+
+class SpeakerVerificationDataset(Dataset):
+    def __init__(self, datasets_root: Path):
+        self.root = datasets_root
+        speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
+        if len(speaker_dirs) == 0:
+            raise Exception("No speakers found. Make sure you are pointing to the directory "
+                            "containing all preprocessed speaker directories.")
+        self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
+        self.speaker_cycler = RandomCycler(self.speakers)
+
+    def __len__(self):
+        return int(1e10)
+
+    def __getitem__(self, index):
+        return next(self.speaker_cycler)
+
+    def get_logs(self):
+        log_string = ""
+        for log_fpath in self.root.glob("*.txt"):
+            with log_fpath.open("r") as log_file:
+                log_string += "".join(log_file.readlines())
+        return log_string
+
+
+class SpeakerVerificationDataLoader(DataLoader):
+    def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
+                 batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
+                 worker_init_fn=None):
+        self.utterances_per_speaker = utterances_per_speaker
+
+        super().__init__(
+            dataset=dataset,
+            batch_size=speakers_per_batch,
+            shuffle=False,
+            sampler=sampler,
+            batch_sampler=batch_sampler,
+            num_workers=num_workers,
+            collate_fn=self.collate,
+            pin_memory=pin_memory,
+            drop_last=False,
+            timeout=timeout,
+            worker_init_fn=worker_init_fn
+        )
+
+    def collate(self, speakers):
+        return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
+
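
Editor's note: a sketch, not part of the commit, of how the dataset and loader above are meant to be wired together. The root path is hypothetical and must contain preprocessed speaker directories (as in speaker.py above); the batch sizes are illustrative.

# Sketch (not part of the commit): iterate batches of speakers as SpeakerBatch objects.
from pathlib import Path
from encoder.data_objects import SpeakerVerificationDataset, SpeakerVerificationDataLoader

clean_data_root = Path("/path/to/preprocessed_speakers")   # hypothetical preprocessed root
dataset = SpeakerVerificationDataset(clean_data_root)
loader = SpeakerVerificationDataLoader(dataset,
                                       speakers_per_batch=64,
                                       utterances_per_speaker=10,
                                       num_workers=4)

for speaker_batch in loader:
    # speaker_batch.data has shape (speakers_per_batch * utterances_per_speaker,
    # partials_n_frames, mel_n_channels), as described in speaker_batch.py
    print(speaker_batch.data.shape)
    break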
    	
encoder/data_objects/utterance.py ADDED
@@ -0,0 +1,26 @@
+import numpy as np
+
+
+class Utterance:
+    def __init__(self, frames_fpath, wave_fpath):
+        self.frames_fpath = frames_fpath
+        self.wave_fpath = wave_fpath
+
+    def get_frames(self):
+        return np.load(self.frames_fpath)
+
+    def random_partial(self, n_frames):
+        """
+        Crops the frames into a partial utterance of n_frames
+
+        :param n_frames: The number of frames of the partial utterance
+        :return: the partial utterance frames and a tuple indicating the start and end of the
+        partial utterance in the complete utterance.
+        """
+        frames = self.get_frames()
+        if frames.shape[0] == n_frames:
+            start = 0
+        else:
+            start = np.random.randint(0, frames.shape[0] - n_frames)
+        end = start + n_frames
+        return frames[start:end], (start, end)
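
Editor's note: a sketch, not part of the commit, of random_partial's crop behaviour on a temporary .npy frames file; the file and wave name are hypothetical. When the stored utterance is exactly n_frames long, the crop starts at 0.

# Sketch (not part of the commit): crop a partial utterance from a saved frames array.
import numpy as np
from tempfile import NamedTemporaryFile
from encoder.data_objects.utterance import Utterance

with NamedTemporaryFile(suffix=".npy", delete=False) as f:
    np.save(f, np.zeros((160, 40), dtype=np.float32))

utterance = Utterance(frames_fpath=f.name, wave_fpath="dummy.wav")
frames, (start, end) = utterance.random_partial(n_frames=160)
print(frames.shape, start, end)   # (160, 40) 0 160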
    	
encoder/inference.py ADDED
@@ -0,0 +1,178 @@
+from encoder.params_data import *
+from encoder.model import SpeakerEncoder
+from encoder.audio import preprocess_wav   # We want to expose this function from here
+from matplotlib import cm
+from encoder import audio
+from pathlib import Path
+import numpy as np
+import torch
+
+_model = None # type: SpeakerEncoder
+_device = None # type: torch.device
+
+
+def load_model(weights_fpath: Path, device=None):
+    """
+    Loads the model in memory. If this function is not explicitely called, it will be run on the
+    first call to embed_frames() with the default weights file.
+
+    :param weights_fpath: the path to saved model weights.
+    :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
+    model will be loaded and will run on this device. Outputs will however always be on the cpu.
+    If None, will default to your GPU if it"s available, otherwise your CPU.
+    """
+    # TODO: I think the slow loading of the encoder might have something to do with the device it
+    #   was saved on. Worth investigating.
+    global _model, _device
+    if device is None:
+        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    elif isinstance(device, str):
+        _device = torch.device(device)
+    _model = SpeakerEncoder(_device, torch.device("cpu"))
+    checkpoint = torch.load(weights_fpath, _device)
+    _model.load_state_dict(checkpoint["model_state"])
+    _model.eval()
+    print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
+
+
+def is_loaded():
+    return _model is not None
+
+
+def embed_frames_batch(frames_batch):
+    """
+    Computes embeddings for a batch of mel spectrogram.
+
+    :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
+    (batch_size, n_frames, n_channels)
+    :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
+    """
+    if _model is None:
+        raise Exception("Model was not loaded. Call load_model() before inference.")
+
+    frames = torch.from_numpy(frames_batch).to(_device)
+    embed = _model.forward(frames).detach().cpu().numpy()
+    return embed
+
+
+def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
+                           min_pad_coverage=0.75, overlap=0.5):
+    """
+    Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
+    partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
+    spectrogram slices are returned, so as to make each partial utterance waveform correspond to
+    its spectrogram. This function assumes that the mel spectrogram parameters used are those
+    defined in params_data.py.
+
+    The returned ranges may be indexing further than the length of the waveform. It is
+    recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
+
+    :param n_samples: the number of samples in the waveform
+    :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
+    utterance
+    :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
+    enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
+    then the last partial utterance will be considered, as if we padded the audio. Otherwise,
+    it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
+    utterance, this parameter is ignored so that the function always returns at least 1 slice.
+    :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
+    utterances are entirely disjoint.
+    :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
+    respectively the waveform and the mel spectrogram with these slices to obtain the partial
+    utterances.
+    """
+    assert 0 <= overlap < 1
+    assert 0 < min_pad_coverage <= 1
+
+    samples_per_frame = int((sampling_rate * mel_window_step / 1000))
+    n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
+    frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
+
+    # Compute the slices
+    wav_slices, mel_slices = [], []
+    steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
+    for i in range(0, steps, frame_step):
+        mel_range = np.array([i, i + partial_utterance_n_frames])
+        wav_range = mel_range * samples_per_frame
+        mel_slices.append(slice(*mel_range))
+        wav_slices.append(slice(*wav_range))
+
+    # Evaluate whether extra padding is warranted or not
+    last_wav_range = wav_slices[-1]
+    coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
+    if coverage < min_pad_coverage and len(mel_slices) > 1:
+        mel_slices = mel_slices[:-1]
+        wav_slices = wav_slices[:-1]
+
+    return wav_slices, mel_slices
+
+
+def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
+    """
+    Computes an embedding for a single utterance.
+
+    # TODO: handle multiple wavs to benefit from batching on GPU
+    :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
+    :param using_partials: if True, then the utterance is split in partial utterances of
+    <partial_utterance_n_frames> frames and the utterance embedding is computed from their
+    normalized average. If False, the utterance is instead computed from feeding the entire
+    spectogram to the network.
+    :param return_partials: if True, the partial embeddings will also be returned along with the
+    wav slices that correspond to the partial embeddings.
+    :param kwargs: additional arguments to compute_partial_splits()
+    :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
+    <return_partials> is True, the partial utterances as a numpy array of float32 of shape
+    (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
+    returned. If <using_partials> is simultaneously set to False, both these values will be None
+    instead.
+    """
+    # Process the entire utterance if not using partials
+    if not using_partials:
+        frames = audio.wav_to_mel_spectrogram(wav)
+        embed = embed_frames_batch(frames[None, ...])[0]
+        if return_partials:
+            return embed, None, None
+        return embed
+
+    # Compute where to split the utterance into partials and pad if necessary
+    wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
+    max_wave_length = wave_slices[-1].stop
+    if max_wave_length >= len(wav):
+        wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
         
     | 
| 142 | 
         
            +
             
     | 
| 143 | 
         
            +
                # Split the utterance into partials
         
     | 
| 144 | 
         
            +
                frames = audio.wav_to_mel_spectrogram(wav)
         
     | 
| 145 | 
         
            +
                frames_batch = np.array([frames[s] for s in mel_slices])
         
     | 
| 146 | 
         
            +
                partial_embeds = embed_frames_batch(frames_batch)
         
     | 
| 147 | 
         
            +
             
     | 
| 148 | 
         
            +
                # Compute the utterance embedding from the partial embeddings
         
     | 
| 149 | 
         
            +
                raw_embed = np.mean(partial_embeds, axis=0)
         
     | 
| 150 | 
         
            +
                embed = raw_embed / np.linalg.norm(raw_embed, 2)
         
     | 
| 151 | 
         
            +
             
     | 
| 152 | 
         
            +
                if return_partials:
         
     | 
| 153 | 
         
            +
                    return embed, partial_embeds, wave_slices
         
     | 
| 154 | 
         
            +
                return embed
         
     | 
| 155 | 
         
            +
             
     | 
| 156 | 
         
            +
             
     | 
| 157 | 
         
            +
            def embed_speaker(wavs, **kwargs):
         
     | 
| 158 | 
         
            +
                raise NotImplemented()
         
     | 
| 159 | 
         
            +
             
     | 
| 160 | 
         
            +
             
     | 
| 161 | 
         
            +
            def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
         
     | 
| 162 | 
         
            +
                import matplotlib.pyplot as plt
         
     | 
| 163 | 
         
            +
                if ax is None:
         
     | 
| 164 | 
         
            +
                    ax = plt.gca()
         
     | 
| 165 | 
         
            +
             
     | 
| 166 | 
         
            +
                if shape is None:
         
     | 
| 167 | 
         
            +
                    height = int(np.sqrt(len(embed)))
         
     | 
| 168 | 
         
            +
                    shape = (height, -1)
         
     | 
| 169 | 
         
            +
                embed = embed.reshape(shape)
         
     | 
| 170 | 
         
            +
             
     | 
| 171 | 
         
            +
                cmap = cm.get_cmap()
         
     | 
| 172 | 
         
            +
                mappable = ax.imshow(embed, cmap=cmap)
         
     | 
| 173 | 
         
            +
                cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
         
     | 
| 174 | 
         
            +
                sm = cm.ScalarMappable(cmap=cmap)
         
     | 
| 175 | 
         
            +
                sm.set_clim(*color_range)
         
     | 
| 176 | 
         
            +
             
     | 
| 177 | 
         
            +
                ax.set_xticks([]), ax.set_yticks([])
         
     | 
| 178 | 
         
            +
                ax.set_title(title)
         
     | 
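
For context, a minimal usage sketch of the inference API added above. It assumes the top of this file defines load_model() and re-exports preprocess_wav() (as elsewhere in this encoder package); the checkpoint and audio paths are hypothetical, not part of this commit.

# Sketch only, not part of the commit: embedding one utterance from overlapping partials.
from pathlib import Path
import numpy as np
from encoder import inference as encoder_inference

encoder_inference.load_model(Path("saved_models/encoder.pt"))                  # hypothetical checkpoint
wav = encoder_inference.preprocess_wav(Path("speaker_001/utterance_01.flac"))  # hypothetical audio file

# Utterance embedding from overlapping partials (the default), plus the partials themselves.
embed, partial_embeds, wav_slices = encoder_inference.embed_utterance(wav, return_partials=True)
print(embed.shape)            # (model_embedding_size,) == (256,)
print(partial_embeds.shape)   # (n_partials, model_embedding_size)
assert np.isclose(np.linalg.norm(embed), 1.0, atol=1e-4)  # embed_utterance L2-normalizes the result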
    	
encoder/model.py
ADDED
@@ -0,0 +1,135 @@
from encoder.params_model import *
from encoder.params_data import *
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
from torch.nn.utils import clip_grad_norm_
from scipy.optimize import brentq
from torch import nn
import numpy as np
import torch


class SpeakerEncoder(nn.Module):
    def __init__(self, device, loss_device):
        super().__init__()
        self.loss_device = loss_device

        # Network definition
        self.lstm = nn.LSTM(input_size=mel_n_channels,
                            hidden_size=model_hidden_size,
                            num_layers=model_num_layers,
                            batch_first=True).to(device)
        self.linear = nn.Linear(in_features=model_hidden_size,
                                out_features=model_embedding_size).to(device)
        self.relu = torch.nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values)
        self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
        self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        # Gradient scale
        self.similarity_weight.grad *= 0.01
        self.similarity_bias.grad *= 0.01

        # Gradient clipping
        clip_grad_norm_(self.parameters(), 3, norm_type=2)

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
        batch_size, hidden_size). Will default to a tensor of zeros if None.
        :return: the embeddings as a tensor of shape (batch_size, embedding_size)
        """
        # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
        # and the final cell state.
        out, (hidden, cell) = self.lstm(utterances, hidden_init)

        # We take only the hidden state of the last layer
        embeds_raw = self.relu(self.linear(hidden[-1]))

        # L2-normalize it
        embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)

        return embeds

    def similarity_matrix(self, embeds):
        """
        Computes the similarity matrix according to section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the similarity matrix as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, speakers_per_batch)
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
        centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
        centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)

        # Exclusive centroids (1 per utterance)
        centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
        centroids_excl /= (utterances_per_speaker - 1)
        centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)

        # Similarity matrix. The cosine similarity of already L2-normed vectors is simply the dot
        # product of these vectors (which is just an element-wise multiplication reduced by a sum).
        # We vectorize the computation for efficiency.
        sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                 speakers_per_batch).to(self.loss_device)
        mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
        for j in range(speakers_per_batch):
            mask = np.where(mask_matrix[j])[0]
            sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
            sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)

        ## Even more vectorized version (slower maybe because of transpose)
        # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
        #                           ).to(self.loss_device)
        # eye = np.eye(speakers_per_batch, dtype=int)
        # mask = np.where(1 - eye)
        # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
        # mask = np.where(eye)
        # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
        # sim_matrix2 = sim_matrix2.transpose(1, 2)

        sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
        return sim_matrix

    def loss(self, embeds):
        """
        Computes the softmax loss according to section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss
        sim_matrix = self.similarity_matrix(embeds)
        sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
                                         speakers_per_batch))
        ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
        target = torch.from_numpy(ground_truth).long().to(self.loss_device)
        loss = self.loss_fn(sim_matrix, target)

        # EER (not backpropagated)
        with torch.no_grad():
            inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
            labels = np.array([inv_argmax(i) for i in ground_truth])
            preds = sim_matrix.detach().cpu().numpy()

            # Snippet from https://yangcha.github.io/EER-ROC/
            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
            eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

        return loss, eer
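
As a quick sanity check of the shapes involved, a hedged sketch of exercising SpeakerEncoder on random data (CPU only; batch geometry taken from encoder/params_model.py and encoder/params_data.py):

# Sketch only, not part of the commit: shape check for SpeakerEncoder on random inputs.
import torch
from encoder.model import SpeakerEncoder
from encoder.params_data import mel_n_channels, partials_n_frames
from encoder.params_model import speakers_per_batch, utterances_per_speaker, model_embedding_size

device = loss_device = torch.device("cpu")
model = SpeakerEncoder(device, loss_device)

# One batch: 64 speakers x 10 utterances, each a (160, 40) mel spectrogram.
inputs = torch.randn(speakers_per_batch * utterances_per_speaker, partials_n_frames, mel_n_channels)
embeds = model(inputs)                                   # (640, 256), rows L2-normalized
embeds = embeds.view(speakers_per_batch, utterances_per_speaker, model_embedding_size)
sim = model.similarity_matrix(embeds)                    # (64, 10, 64)
loss, eer = model.loss(embeds)                           # scalar loss + equal error rate
print(sim.shape, float(loss), eer)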
    	
encoder/params_data.py
ADDED
@@ -0,0 +1,29 @@

## Mel-filterbank
mel_window_length = 25  # In milliseconds
mel_window_step = 10    # In milliseconds
mel_n_channels = 40


## Audio
sampling_rate = 16000
# Number of spectrogram frames in a partial utterance
partials_n_frames = 160     # 1600 ms
# Number of spectrogram frames at inference
inference_n_frames = 80     #  800 ms


## Voice Activity Detection
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
# This sets the granularity of the VAD. Should not need to be changed.
vad_window_length = 30  # In milliseconds
# Number of frames to average together when performing the moving average smoothing.
# The larger this value, the larger the VAD variations must be to not get smoothed out.
vad_moving_average_width = 8
# Maximum number of consecutive silent frames a segment can have.
vad_max_silence_length = 6


## Audio volume normalization
audio_norm_target_dBFS = -30
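
To relate these values: at a 16 kHz sampling rate, a 10 ms window step is 160 samples per frame, so the 160-frame training partials span 1.6 s of audio and the 80-frame inference partials span 0.8 s. A small sketch of that arithmetic:

# Sketch only: how the frame counts above translate to durations.
sampling_rate = 16000
mel_window_step = 10          # ms
partials_n_frames = 160
inference_n_frames = 80

samples_per_frame = sampling_rate * mel_window_step // 1000   # 160 samples per frame
print(partials_n_frames * samples_per_frame / sampling_rate)  # 1.6 seconds per training partial
print(inference_n_frames * samples_per_frame / sampling_rate) # 0.8 seconds per inference partial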
    	
encoder/params_model.py
ADDED
@@ -0,0 +1,11 @@

## Model parameters
model_hidden_size = 256
model_embedding_size = 256
model_num_layers = 3


## Training parameters
learning_rate_init = 1e-4
speakers_per_batch = 64
utterances_per_speaker = 10
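
For scale, one training batch built from these values contains speakers_per_batch x utterances_per_speaker = 640 partial utterances, each embedded into a 256-dimensional vector. A tiny sketch of that arithmetic:

# Sketch only: batch geometry implied by these parameters and encoder/params_data.py.
speakers_per_batch = 64
utterances_per_speaker = 10
partials_n_frames, mel_n_channels = 160, 40   # from encoder/params_data.py

print(speakers_per_batch * utterances_per_speaker)   # 640 utterances per batch
print((speakers_per_batch * utterances_per_speaker, partials_n_frames, mel_n_channels))  # model input shape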
    	
encoder/preprocess.py
ADDED
@@ -0,0 +1,184 @@
from datetime import datetime
from functools import partial
from multiprocessing import Pool
from pathlib import Path

import numpy as np
from tqdm import tqdm

from encoder import audio
from encoder.config import librispeech_datasets, anglophone_nationalites
from encoder.params_data import *


_AUDIO_EXTENSIONS = ("wav", "flac", "m4a", "mp3")

class DatasetLog:
    """
    Registers metadata about the dataset in a text file.
    """
    def __init__(self, root, name):
        self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
        self.sample_data = dict()

        start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
        self.write_line("Creating dataset %s on %s" % (name, start_time))
        self.write_line("-----")
        self._log_params()

    def _log_params(self):
        from encoder import params_data
        self.write_line("Parameter values:")
        for param_name in (p for p in dir(params_data) if not p.startswith("__")):
            value = getattr(params_data, param_name)
            self.write_line("\t%s: %s" % (param_name, value))
        self.write_line("-----")

    def write_line(self, line):
        self.text_file.write("%s\n" % line)

    def add_sample(self, **kwargs):
        for param_name, value in kwargs.items():
            if param_name not in self.sample_data:
                self.sample_data[param_name] = []
            self.sample_data[param_name].append(value)

    def finalize(self):
        self.write_line("Statistics:")
        for param_name, values in self.sample_data.items():
            self.write_line("\t%s:" % param_name)
            self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
            self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
        self.write_line("-----")
        end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
        self.write_line("Finished on %s" % end_time)
        self.text_file.close()


def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
    dataset_root = datasets_root.joinpath(dataset_name)
    if not dataset_root.exists():
        print("Couldn't find %s, skipping this dataset." % dataset_root)
        return None, None
    return dataset_root, DatasetLog(out_dir, dataset_name)


def _preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, skip_existing: bool):
    # Give a name to the speaker that includes its dataset
    speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)

    # Create an output directory with that name, as well as a txt file containing a
    # reference to each source file.
    speaker_out_dir = out_dir.joinpath(speaker_name)
    speaker_out_dir.mkdir(exist_ok=True)
    sources_fpath = speaker_out_dir.joinpath("_sources.txt")

    # There's a possibility that the preprocessing was interrupted earlier, check if
    # there already is a sources file.
    if sources_fpath.exists():
        try:
            with sources_fpath.open("r") as sources_file:
                existing_fnames = {line.split(",")[0] for line in sources_file}
        except Exception:
            existing_fnames = set()
    else:
        existing_fnames = set()

    # Gather all audio files for that speaker recursively
    sources_file = sources_fpath.open("a" if skip_existing else "w")
    audio_durs = []
    for extension in _AUDIO_EXTENSIONS:
        for in_fpath in speaker_dir.glob("**/*.%s" % extension):
            # Check if the target output file already exists
            out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
            out_fname = out_fname.replace(".%s" % extension, ".npy")
            if skip_existing and out_fname in existing_fnames:
                continue

            # Load and preprocess the waveform
            wav = audio.preprocess_wav(in_fpath)
            if len(wav) == 0:
                continue

            # Create the mel spectrogram, discard those that are too short
            frames = audio.wav_to_mel_spectrogram(wav)
            if len(frames) < partials_n_frames:
                continue

            out_fpath = speaker_out_dir.joinpath(out_fname)
            np.save(out_fpath, frames)
            sources_file.write("%s,%s\n" % (out_fname, in_fpath))
            audio_durs.append(len(wav) / sampling_rate)

    sources_file.close()

    return audio_durs


def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger):
    print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))

    # Process the utterances for each speaker
    work_fn = partial(_preprocess_speaker, datasets_root=datasets_root, out_dir=out_dir, skip_existing=skip_existing)
    with Pool(4) as pool:
        tasks = pool.imap(work_fn, speaker_dirs)
        for sample_durs in tqdm(tasks, dataset_name, len(speaker_dirs), unit="speakers"):
            for sample_dur in sample_durs:
                logger.add_sample(duration=sample_dur)

    logger.finalize()
    print("Done preprocessing %s.\n" % dataset_name)


def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
    for dataset_name in librispeech_datasets["train"]["other"]:
        # Initialize the preprocessing
        dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
        if not dataset_root:
            return

        # Preprocess all speakers
        speaker_dirs = list(dataset_root.glob("*"))
        _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)


def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
    # Initialize the preprocessing
    dataset_name = "VoxCeleb1"
    dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
    if not dataset_root:
        return

    # Get the contents of the meta file
    with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
        metadata = [line.split("\t") for line in metafile][1:]

    # Select the ID and the nationality, filter out non-anglophone speakers
    nationalities = {line[0]: line[3] for line in metadata}
    keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
                        nationality.lower() in anglophone_nationalites]
    print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
          (len(keep_speaker_ids), len(nationalities)))

    # Get the speaker directories for anglophone speakers only
    speaker_dirs = dataset_root.joinpath("wav").glob("*")
    speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
                    speaker_dir.name in keep_speaker_ids]
    print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
          (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))

    # Preprocess all speakers
    _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)


def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
    # Initialize the preprocessing
    dataset_name = "VoxCeleb2"
    dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
    if not dataset_root:
        return

    # Get the speaker directories
    # Preprocess all speakers
    speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
    _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
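
A hedged sketch of driving these entry points directly; the datasets_root layout and the SV2TTS/encoder output directory below are hypothetical choices, not something this commit defines:

# Sketch only, not part of the commit: running the preprocessing functions above.
from pathlib import Path
from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2

datasets_root = Path("datasets")                 # hypothetical root holding LibriSpeech/, VoxCeleb1/, VoxCeleb2/
out_dir = datasets_root / "SV2TTS" / "encoder"   # hypothetical output location
out_dir.mkdir(exist_ok=True, parents=True)

# Each call writes one .npy mel spectrogram per kept utterance, plus a _sources.txt per speaker
# and a Log_<dataset>.txt with the parameter values and duration statistics.
preprocess_librispeech(datasets_root, out_dir, skip_existing=True)
preprocess_voxceleb1(datasets_root, out_dir, skip_existing=True)
preprocess_voxceleb2(datasets_root, out_dir, skip_existing=True)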
    	
encoder/train.py
ADDED
@@ -0,0 +1,125 @@
| 1 | 
         
            +
            from pathlib import Path
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            import torch
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
         
     | 
| 6 | 
         
            +
            from encoder.model import SpeakerEncoder
         
     | 
| 7 | 
         
            +
            from encoder.params_model import *
         
     | 
| 8 | 
         
            +
            from encoder.visualizations import Visualizations
         
     | 
| 9 | 
         
            +
            from utils.profiler import Profiler
         
def sync(device: torch.device):
    # For correct profiling (cuda operations are async)
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
          backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
          no_visdom: bool):
    # Create a dataset and a dataloader
    dataset = SpeakerVerificationDataset(clean_data_root)
    loader = SpeakerVerificationDataLoader(
        dataset,
        speakers_per_batch,
        utterances_per_speaker,
        num_workers=4,
    )

    # Setup the device on which to run the forward pass and the loss. These can be different,
    # because the forward pass is faster on the GPU whereas the loss is often (depending on your
    # hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIXME: currently, the gradient is None if loss_device is cuda
    loss_device = torch.device("cpu")

    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    init_step = 1

    # Configure file path for the model
    model_dir = models_dir / run_id
    model_dir.mkdir(exist_ok=True, parents=True)
    state_fpath = model_dir / "encoder.pt"

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print("Found existing model \"%s\", loading it and resuming training." % run_id)
            checkpoint = torch.load(state_fpath)
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." % run_id)
    else:
        print("Starting the training from scratch.")
    model.train()

    # Initialize the visualization environment
    vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(dataset)
    vis.log_params()
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop
    profiler = Profiler(summarize_every=10, disabled=False)
    for step, speaker_batch in enumerate(loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        profiler.tick("Data to %s" % device)
        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")
        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        sync(loss_device)
        profiler.tick("Loss")

        # Backward pass
        model.zero_grad()
        loss.backward()
        profiler.tick("Backward pass")
        model.do_gradient_ops()
        optimizer.step()
        profiler.tick("Parameter update")

        # Update visualizations
        # learning_rate = optimizer.param_groups[0]["lr"]
        vis.update(loss.item(), eer, step)

        # Draw projections and save them to the backup folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            projection_fpath = model_dir / f"umap_{step:06d}.png"
            embeds = embeds.detach().cpu().numpy()
            vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
            vis.save()

        # Overwrite the latest version of the model
        if save_every != 0 and step % save_every == 0:
            print("Saving the model (step %d)" % step)
            torch.save({
                "step": step + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, state_fpath)

        # Make a backup
        if backup_every != 0 and step % backup_every == 0:
            print("Making a backup (step %d)" % step)
            backup_fpath = model_dir / f"encoder_{step:06d}.bak"
            torch.save({
                "step": step + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, backup_fpath)

        profiler.tick("Extras (visualizations, saving)")
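For context, this is roughly how the entry point above might be invoked from a driver script. The paths and interval values below are hypothetical placeholders (none of them are defined in this commit), and a visdom server must already be running unless no_visdom is set to True:

    from pathlib import Path

    from encoder.train import train

    # Hypothetical invocation; adjust the paths and intervals to your setup.
    train(
        run_id="my_run",                                  # checkpoints go to <models_dir>/my_run/encoder.pt
        clean_data_root=Path("datasets/SV2TTS/encoder"),  # placeholder: preprocessed training data
        models_dir=Path("encoder/saved_models"),          # placeholder: where checkpoints are written
        umap_every=100,       # draw UMAP projections every 100 steps (0 disables)
        save_every=500,       # overwrite encoder.pt every 500 steps (0 disables)
        backup_every=7500,    # write encoder_<step>.bak every 7500 steps (0 disables)
        vis_every=10,         # update the visdom plots every 10 steps
        force_restart=False,  # resume from an existing encoder.pt if present
        visdom_server="http://localhost",
        no_visdom=False,      # set True to train without a visdom server
    )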
    	
        encoder/visualizations.py
    ADDED
        @@ -0,0 +1,179 @@
from datetime import datetime
from time import perf_counter as timer

import numpy as np
import umap
import visdom

from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset


colormap = np.array([
    [76, 255, 0],
    [0, 127, 70],
    [255, 0, 0],
    [255, 217, 38],
    [0, 135, 255],
    [165, 0, 165],
    [255, 167, 255],
    [0, 255, 255],
    [255, 96, 38],
    [142, 76, 0],
    [33, 0, 127],
    [0, 0, 0],
    [183, 183, 183],
], dtype=float) / 255   # use the builtin float: np.float was removed in NumPy 1.24


class Visualizations:
    def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
        # Tracking data
        self.last_update_timestamp = timer()
        self.update_every = update_every
        self.step_times = []
        self.losses = []
        self.eers = []
        print("Updating the visualizations every %d steps." % update_every)

        # If visdom is disabled TODO: use a better paradigm for that
        self.disabled = disabled
        if self.disabled:
            return

        # Set the environment name
        now = str(datetime.now().strftime("%d-%m %Hh%M"))
        if env_name is None:
            self.env_name = now
        else:
            self.env_name = "%s (%s)" % (env_name, now)

        # Connect to visdom and open the corresponding window in the browser
        try:
            self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
        except ConnectionError:
            raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
                            "start it.")
        # webbrowser.open("http://localhost:8097/env/" + self.env_name)

        # Create the windows
        self.loss_win = None
        self.eer_win = None
        # self.lr_win = None
        self.implementation_win = None
        self.projection_win = None
        self.implementation_string = ""

    def log_params(self):
        if self.disabled:
            return
        from encoder import params_data
        from encoder import params_model
        param_string = "<b>Model parameters</b>:<br>"
        for param_name in (p for p in dir(params_model) if not p.startswith("__")):
            value = getattr(params_model, param_name)
            param_string += "\t%s: %s<br>" % (param_name, value)
        param_string += "<b>Data parameters</b>:<br>"
        for param_name in (p for p in dir(params_data) if not p.startswith("__")):
            value = getattr(params_data, param_name)
            param_string += "\t%s: %s<br>" % (param_name, value)
        self.vis.text(param_string, opts={"title": "Parameters"})

    def log_dataset(self, dataset: SpeakerVerificationDataset):
        if self.disabled:
            return
        dataset_string = ""
        dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
        dataset_string += "\n" + dataset.get_logs()
        dataset_string = dataset_string.replace("\n", "<br>")
        self.vis.text(dataset_string, opts={"title": "Dataset"})

    def log_implementation(self, params):
        if self.disabled:
            return
        implementation_string = ""
        for param, value in params.items():
            implementation_string += "<b>%s</b>: %s\n" % (param, value)
            implementation_string = implementation_string.replace("\n", "<br>")
        self.implementation_string = implementation_string
        self.implementation_win = self.vis.text(
            implementation_string,
            opts={"title": "Training implementation"}
        )

    def update(self, loss, eer, step):
        # Update the tracking data
        now = timer()
        self.step_times.append(1000 * (now - self.last_update_timestamp))
        self.last_update_timestamp = now
        self.losses.append(loss)
        self.eers.append(eer)
        print(".", end="")

        # Update the plots every <update_every> steps
        if step % self.update_every != 0:
            return
        time_string = "Step time:  mean: %5dms  std: %5dms" % \
                      (int(np.mean(self.step_times)), int(np.std(self.step_times)))
        print("\nStep %6d   Loss: %.4f   EER: %.4f   %s" %
              (step, np.mean(self.losses), np.mean(self.eers), time_string))
        if not self.disabled:
            self.loss_win = self.vis.line(
                [np.mean(self.losses)],
                [step],
                win=self.loss_win,
                update="append" if self.loss_win else None,
                opts=dict(
                    legend=["Avg. loss"],
                    xlabel="Step",
                    ylabel="Loss",
                    title="Loss",
                )
            )
            self.eer_win = self.vis.line(
                [np.mean(self.eers)],
                [step],
                win=self.eer_win,
                update="append" if self.eer_win else None,
                opts=dict(
                    legend=["Avg. EER"],
                    xlabel="Step",
                    ylabel="EER",
                    title="Equal error rate"
                )
            )
            if self.implementation_win is not None:
                self.vis.text(
                    self.implementation_string + ("<b>%s</b>" % time_string),
                    win=self.implementation_win,
                    opts={"title": "Training implementation"},
                )

        # Reset the tracking
        self.losses.clear()
        self.eers.clear()
        self.step_times.clear()

    def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, max_speakers=10):
        import matplotlib.pyplot as plt

        max_speakers = min(max_speakers, len(colormap))
        embeds = embeds[:max_speakers * utterances_per_speaker]

        n_speakers = len(embeds) // utterances_per_speaker
        ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
        colors = [colormap[i] for i in ground_truth]

        reducer = umap.UMAP()
        projected = reducer.fit_transform(embeds)
        plt.scatter(projected[:, 0], projected[:, 1], c=colors)
        plt.gca().set_aspect("equal", "datalim")
        plt.title("UMAP projection (step %d)" % step)
        if not self.disabled:
            self.projection_win = self.vis.matplot(plt, win=self.projection_win)
        if out_fpath is not None:
            plt.savefig(out_fpath)
        plt.clf()

    def save(self):
        if not self.disabled:
            self.vis.save([self.env_name])
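Since draw_projections only touches visdom when the instance is not disabled, the class can also be used on its own to render a projection to disk. A minimal sketch follows, with random embeddings standing in for real encoder output; the embedding size and the speaker/utterance counts are illustrative placeholders:

    import numpy as np

    from encoder.visualizations import Visualizations

    # Minimal sketch: save a UMAP projection to a file without a visdom server.
    # The array is a random placeholder shaped like the embeddings produced during
    # training: (n_speakers * utterances_per_speaker, embedding_size).
    vis = Visualizations(env_name="debug", update_every=10, disabled=True)
    fake_embeds = np.random.rand(10 * 8, 256).astype(np.float32)  # 10 speakers x 8 utterances
    vis.draw_projections(fake_embeds, utterances_per_speaker=8, step=0, out_fpath="umap_debug.png")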