Shokoufeh commited on
Commit
2aa6704
1 Parent(s): a2cde0e

Add custom pipeline file

Browse files
Files changed (1) hide show
  1. custom_pipeline.py +82 -0
custom_pipeline.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ from transformers import Pipeline
4
+ from librosa import resample
5
+ from soundfile import write
6
+ from sgmse.model import ScoreModel
7
+ from sgmse.util.other import pad_spec
8
+
9
+ class CustomSpeechEnhancementPipeline(Pipeline):
10
+ def __init__(self, model, target_sr=16000, pad_mode="zero_pad", args=None):
11
+ """
12
+ Custom pipeline for speech enhancement using ScoreModel.
13
+
14
+ Args:
15
+ model: The speech enhancement model loaded from a checkpoint (ScoreModel).
16
+ target_sr: Target sample rate for the input audio (default is 16 kHz).
17
+ pad_mode: Padding mode for spectrogram (default is "zero_pad").
18
+ args: Parsed arguments (device, corrector, corrector_steps, snr, etc.).
19
+ """
20
+ super().__init__(model=model)
21
+ self.target_sr = target_sr
22
+ self.pad_mode = pad_mode
23
+ self.args = args
24
+
25
+ def preprocess(self, audio_path):
26
+ # Load the audio file
27
+ y, sr = torchaudio.load(audio_path)
28
+
29
+ # Resample if necessary
30
+ if sr != self.target_sr:
31
+ y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=self.target_sr))
32
+
33
+ # Normalize the audio
34
+ norm_factor = y.abs().max()
35
+ y = y / norm_factor
36
+
37
+ # Prepare the input for the model by transforming to the frequency domain
38
+ Y = torch.unsqueeze(self.model._forward_transform(self.model._stft(y.to(self.args.device))), 0)
39
+ Y = pad_spec(Y, mode=self.pad_mode)
40
+
41
+ return Y, norm_factor, y.size(1) # Return input spec, normalization factor, and original length
42
+
43
+ def _forward(self, model_inputs):
44
+ Y, norm_factor, T_orig = model_inputs
45
+
46
+ # Perform reverse sampling using the model's PC sampler
47
+ sampler = self.model.get_pc_sampler(
48
+ 'reverse_diffusion',
49
+ self.args.corrector,
50
+ Y.to(self.args.device),
51
+ N=self.args.N,
52
+ corrector_steps=self.args.corrector_steps,
53
+ snr=self.args.snr
54
+ )
55
+
56
+ # Get the enhanced speech sample
57
+ sample, _ = sampler()
58
+
59
+ # Convert back to time domain
60
+ x_hat = self.model.to_audio(sample.squeeze(), T_orig)
61
+
62
+ # Renormalize the audio
63
+ x_hat = x_hat * norm_factor
64
+
65
+ return x_hat
66
+
67
+ def postprocess(self, model_outputs):
68
+ # Convert the enhanced output back to NumPy for further processing or saving
69
+ return model_outputs.cpu().numpy()
70
+
71
+ def pad_spec(self, Y):
72
+ """
73
+ Apply padding to the spectrogram as per the model's required padding mode.
74
+
75
+ Args:
76
+ Y: Input spectrogram tensor.
77
+
78
+ Returns:
79
+ Padded spectrogram.
80
+ """
81
+ # Implement padding as per the provided mode
82
+ return torch.nn.functional.pad(Y, (0, 0, 0, 1), mode=self.pad_mode)