File size: 3,800 Bytes
4b8361a
0030bc6
4b8361a
 
 
0030bc6
4b8361a
0030bc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b8361a
 
0030bc6
 
 
 
 
 
 
4b8361a
0030bc6
 
 
 
 
 
 
 
 
 
4b8361a
0030bc6
 
 
 
 
 
 
 
4b8361a
0030bc6
 
 
 
 
4b8361a
 
0030bc6
4b8361a
0030bc6
4b8361a
0030bc6
4b8361a
 
 
 
0030bc6
 
 
4b8361a
0030bc6
 
 
4b8361a
0030bc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b8361a
 
0030bc6
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torchaudio
from torchaudio import transforms as taT, functional as taF
import torch.nn as nn

# Babble-noise recording (VOiCES corpus) loaded once by AudioTrainingPipeline.get_noise()
# and mixed into every training waveform. Path is relative to the working directory.
NOISE_PATH = "data/augmentation/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo.wav"

class AudioTrainingPipeline(torch.nn.Module):
    def __init__(self, 
            input_freq=16000,
            resample_freq=16000,
            expected_duration=6,
            freq_mask_size=10,
            time_mask_size=80,
            mask_count = 2,
            snr_mean=6.0):
        super().__init__()
        self.input_freq = input_freq
        self.snr_mean = snr_mean
        self.mask_count = mask_count
        self.noise = self.get_noise()
        self.resample = taT.Resample(input_freq,resample_freq)
        self.preprocess_waveform = WaveformPreprocessing(resample_freq * expected_duration)
        self.audio_to_spectrogram = AudioToSpectrogram(
            sample_rate=resample_freq,
        )
        self.freq_mask = taT.FrequencyMasking(freq_mask_size)
        self.time_mask = taT.TimeMasking(time_mask_size)
        

    def get_noise(self) -> torch.Tensor:
        noise, sr = torchaudio.load(NOISE_PATH)
        if noise.shape[0] > 1:
            noise = noise.mean(0, keepdim=True)
        if sr != self.input_freq:
            noise = taF.resample(noise,sr, self.input_freq)
        return noise

    def add_noise(self, waveform:torch.Tensor) -> torch.Tensor:
        num_repeats = waveform.shape[1] // self.noise.shape[1] + 1
        noise = self.noise.repeat(1,num_repeats)[:, :waveform.shape[1]]
        noise_power = noise.norm(p=2)
        signal_power = waveform.norm(p=2)
        snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0)
        snr = torch.exp(snr_db / 10)
        scale = snr * noise_power / signal_power
        noisy_waveform = (scale * waveform + noise) / 2
        return noisy_waveform

    def forward(self, waveform:torch.Tensor) -> torch.Tensor:
        try:
            waveform = self.resample(waveform)
        except:
            print("oops")
        waveform = self.preprocess_waveform(waveform)
        waveform = self.add_noise(waveform)
        spec = self.audio_to_spectrogram(waveform)

        # Spectrogram augmentation
        for _ in range(self.mask_count):
            spec = self.freq_mask(spec)
            spec = self.time_mask(spec)
        return spec


class WaveformPreprocessing(torch.nn.Module):
    """Normalize a waveform to mono and a fixed sample count.

    Multi-channel input is averaged down to one channel; the result is then
    zero-padded or truncated to exactly ``expected_sample_length`` samples.
    """

    def __init__(self, expected_sample_length: int):
        super().__init__()
        self.expected_sample_length = expected_sample_length

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return a ``(1, expected_sample_length)`` mono waveform."""
        # Collapse any extra channels by averaging them.
        mono = waveform.mean(0, keepdim=True) if waveform.shape[0] > 1 else waveform
        return self._rectify_duration(mono)

    def _rectify_duration(self, waveform: torch.Tensor) -> torch.Tensor:
        # Pad with trailing zeros when short; truncate when long.
        target = self.expected_sample_length
        current = waveform.shape[1]
        if current == target:
            return waveform
        if current < target:
            return torch.nn.functional.pad(
                waveform, (0, target - current), mode="constant", value=0.0
            )
        return waveform[:, :target]


class AudioToSpectrogram(torch.nn.Module):
    """Turn a waveform into a log-scaled (decibel) mel spectrogram."""

    def __init__(
        self,
        sample_rate=16000,
    ):
        super().__init__()
        # TODO: Change mels to 64
        self.spec = taT.MelSpectrogram(sample_rate=sample_rate, n_mels=128, n_fft=1024)
        self.to_db = taT.AmplitudeToDB()

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the dB-scaled mel spectrogram of ``waveform``."""
        return self.to_db(self.spec(waveform))